Skip to content
Snippets Groups Projects
Commit cad0a258 authored by pierre-cau's avatar pierre-cau
Browse files

split_data

parent 2d858147
No related branches found
No related tags found
No related merge requests found
......@@ -4,3 +4,13 @@ from utils import *
if __name__ == "__main__":
    # Fraction of the samples assigned to the training set. Named once so the
    # value passed to split_dataset and the value reported below cannot drift apart.
    split_factor = 0.8

    # Load CIFAR data
    data, labels = read_cifar(r"../data/cifar-10-batches-py")
    print(f"Data shape: {data.shape}, Labels shape: {labels.shape}\n")

    # Split the dataset
    data_train, labels_train, data_test, labels_test = split_dataset(data, labels, split_factor)
    print(f"Split the dataset with a {split_factor} split factor:")
    print(f" - Training data shape: {data_train.shape}, Training labels shape: {labels_train.shape}")
    print(f" - Testing data shape: {data_test.shape}, Testing labels shape: {labels_test.shape}")
\ No newline at end of file
......@@ -3,6 +3,7 @@
# Relative path pointing three directory levels up.
# NOTE(review): presumably the project root as seen from this package — confirm
# which directory this is resolved against before relying on it.
ROOT = r"../../.."
from .read_cifar import *
from .split_data import *
# Author: Pierre CAU
# Date: 2024
import numpy as np
def split_dataset(data, labels, split, *, seed=None):
    """
    Randomly shuffle a dataset and split it into a training set and a test set.

    Samples and labels are reordered with the same random permutation, so every
    sample keeps its label. The first ``int(n_samples * split)`` shuffled
    samples form the training set and the remaining ones form the test set.

    Parameters
    ----------
    data : np.ndarray
        Array of data samples, one sample per entry along the first axis.
    labels : np.ndarray
        Array of labels corresponding to the data samples. Its first axis must
        have the same length as the first axis of ``data``.
    split : float
        Fraction of the samples assigned to the training set, strictly
        between 0 and 1 (e.g. 0.8 keeps 80% for training and 20% for testing).
    seed : int, optional
        Seed of a private random generator used for the shuffle. Passing the
        same seed yields the same split on every call and leaves NumPy's
        global random state untouched. When None (the default), the global
        ``np.random`` state is used.

    Returns
    -------
    data_train : np.ndarray
        Training data.
    labels_train : np.ndarray
        Corresponding labels for the training data.
    data_test : np.ndarray
        Testing data.
    labels_test : np.ndarray
        Corresponding labels for the testing data.

    Raises
    ------
    AssertionError
        If ``data`` and ``labels`` do not have the same number of samples, or
        if ``split`` is not strictly between 0 and 1.
    """
    # NOTE: these checks are skipped when Python runs with -O; they are kept as
    # asserts so callers relying on AssertionError keep working.
    assert data.shape[0] == labels.shape[0], "Data and labels must have the same number of samples"
    assert 0 < split < 1, "Split must be a float between 0 and 1"

    n_samples = data.shape[0]

    # Use a private generator only when a seed is requested; otherwise fall
    # back to the module-level functions driven by the global random state.
    rng = np.random if seed is None else np.random.RandomState(seed)

    # Shuffle an index array rather than the arrays themselves, so one
    # permutation reorders data and labels in unison.
    indices = np.arange(n_samples)
    rng.shuffle(indices)

    # Number of samples that go to the training set (truncated towards zero)
    split_index = int(n_samples * split)
    train_indices = indices[:split_index]
    test_indices = indices[split_index:]

    # Index each subset once instead of copying the whole shuffled dataset
    # and slicing it afterwards.
    data_train = data[train_indices]
    labels_train = labels[train_indices]
    data_test = data[test_indices]
    labels_test = labels[test_indices]
    return data_train, labels_train, data_test, labels_test
# Example usage
if __name__ == "__main__":
    # Synthetic dataset: 100 random samples of 10 values each, with one 0/1 label per sample
    data = np.random.rand(100, 10)
    labels = np.random.randint(0, 2, size=100)
    print(f"Data shape: {data.shape}, Labels shape: {labels.shape}")

    # Run the split, then unpack the four returned arrays
    parts = split_dataset(data, labels, 0.8)
    data_train, labels_train, data_test, labels_test = parts

    # Report the resulting shapes
    print(f"Training data shape: {data_train.shape}, Training labels shape: {labels_train.shape}")
    print(f"Testing data shape: {data_test.shape}, Testing labels shape: {labels_test.shape}")
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment