Skip to content
Snippets Groups Projects
Commit 51ab6a8a authored by Duperret Loris's avatar Duperret Loris
Browse files

Delete read_cifar.py

parent ac4c1f46
No related branches found
No related tags found
No related merge requests found
import numpy as np
import pickle
import os
def read_cifar_batch(batch_path):
with open(batch_path, 'rb') as file:
# Load the batch data
batch_data = pickle.load(file, encoding='bytes')
# Extract data and labels from the batch
data = batch_data[b'data'] # CIFAR-10 data
labels = batch_data[b'labels'] # Class labels
# Convert data and labels to the desired data types
data = np.array(data, dtype=np.float32)
labels = np.array(labels, dtype=np.int64)
return data, labels
def read_cifar(directory_path):
data_batches = []
label_batches = []
# Iterate through the batch files in the directory
for batch_file in ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5', 'test_batch']:
batch_path = os.path.join(directory_path, batch_file)
with open(batch_path, 'rb') as file:
# Load the batch data
batch_data = pickle.load(file, encoding='bytes')
# Extract data and labels from the batch
data = batch_data[b'data'] # CIFAR-10 data
labels = batch_data[b'labels'] # Class labels
data_batches.append(data)
label_batches.extend(labels)
# Combine all batches into a single data matrix and label vector
data = np.concatenate(data_batches, axis=0)
labels = np.array(label_batches, dtype=np.int64)
# Convert data to the desired data type
data = data.astype(np.float32)
return data, labels
def split_dataset(data, labels, split):
# Check if the split parameter is within the valid range (0 to 1)
if split < 0 or split > 1:
raise ValueError("Split must be a float between 0 and 1.")
# Get the number of samples in the dataset
num_samples = len(data)
# Calculate the number of samples for training and testing
num_train_samples = int(num_samples * split)
num_test_samples = num_samples - num_train_samples
# Create a random shuffle order for the indices
shuffle_indices = np.random.permutation(num_samples)
# Use the shuffled indices to split the data and labels
data_train = data[shuffle_indices[:num_train_samples]]
labels_train = labels[shuffle_indices[:num_train_samples]]
data_test = data[shuffle_indices[num_train_samples:]]
labels_test = labels[shuffle_indices[num_train_samples:]]
return data_train, labels_train, data_test, labels_test
if __name__ == '__main__':
batch_path = "data/cifar-10-python\cifar-10-batches-py\data_batch_1" # Update with your path
data, labels = read_cifar_batch(batch_path)
print("Data shape:", data.shape)
print("Labels shape:", labels.shape)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment