"""Utilities for loading CIFAR-10 pickle batches and splitting them into sets."""

import os
import pickle

import numpy as np

# Batch files read by ``read_cifar``, in the order they are concatenated.
_CIFAR_BATCH_FILES = (
    'data_batch_1',
    'data_batch_2',
    'data_batch_3',
    'data_batch_4',
    'data_batch_5',
    'test_batch',
)


def _load_raw_batch(batch_path):
    """Unpickle one batch file and return its raw ``(data, labels)`` entries.

    The values are returned exactly as stored under the ``b'data'`` and
    ``b'labels'`` keys; no dtype conversion is performed here.
    """
    with open(batch_path, 'rb') as file:
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing. Only point this at batch files from a trusted source.
        batch_data = pickle.load(file, encoding='bytes')
    return batch_data[b'data'], batch_data[b'labels']


def read_cifar_batch(batch_path):
    """Load a single CIFAR-10 batch file.

    Parameters
    ----------
    batch_path : str or path-like
        Path to one pickled batch file (e.g. ``data_batch_1``).

    Returns
    -------
    data : numpy.ndarray
        The batch's ``b'data'`` entry converted to ``float32``.
    labels : numpy.ndarray
        The batch's ``b'labels'`` entry converted to ``int64``.

    Raises
    ------
    OSError
        If the file cannot be opened.
    KeyError
        If the pickle has no ``b'data'`` or ``b'labels'`` entry.
    """
    raw_data, raw_labels = _load_raw_batch(batch_path)
    data = np.array(raw_data, dtype=np.float32)
    labels = np.array(raw_labels, dtype=np.int64)
    return data, labels


def read_cifar(directory_path):
    """Load every CIFAR-10 batch in a directory into one data/label pair.

    The five training batches and the test batch are read in a fixed order
    and concatenated along the first axis.

    Parameters
    ----------
    directory_path : str or path-like
        Directory containing ``data_batch_1`` ... ``data_batch_5`` and
        ``test_batch``.

    Returns
    -------
    data : numpy.ndarray
        All batches' data stacked along axis 0, as ``float32``.
    labels : numpy.ndarray
        All batches' labels in the same order, as ``int64``.

    Raises
    ------
    OSError
        If any of the expected batch files cannot be opened.
    KeyError
        If a pickle has no ``b'data'`` or ``b'labels'`` entry.
    """
    data_batches = []
    label_batches = []
    for batch_file in _CIFAR_BATCH_FILES:
        batch_path = os.path.join(directory_path, batch_file)
        raw_data, raw_labels = _load_raw_batch(batch_path)
        data_batches.append(raw_data)
        label_batches.extend(raw_labels)

    # Keep the batches in their stored dtype until after concatenation so the
    # float32 conversion happens once, on the combined array.
    data = np.concatenate(data_batches, axis=0).astype(np.float32)
    labels = np.array(label_batches, dtype=np.int64)
    return data, labels


def split_dataset(data, labels, split):
    """Randomly split a dataset into a training part and a test part.

    A single random permutation of the sample indices is drawn with
    ``numpy.random.permutation`` (global NumPy random state) and applied to
    both ``data`` and ``labels`` so that rows stay paired with their labels.

    Parameters
    ----------
    data : numpy.ndarray
        Samples, one per entry along the first axis.
    labels : numpy.ndarray
        Labels aligned with ``data`` along the first axis.
    split : float
        Fraction of samples assigned to the training part, in ``[0, 1]``.
        The training size is ``int(len(data) * split)`` (rounded down).

    Returns
    -------
    data_train, labels_train, data_test, labels_test : numpy.ndarray
        The shuffled training and test partitions.

    Raises
    ------
    ValueError
        If ``split`` is not within ``[0, 1]`` (this includes NaN).
    """
    # Chained comparison is False for NaN, so NaN is rejected here with the
    # same message instead of failing later inside int().
    if not 0 <= split <= 1:
        raise ValueError("Split must be a float between 0 and 1.")

    num_samples = len(data)
    num_train_samples = int(num_samples * split)

    shuffle_indices = np.random.permutation(num_samples)
    train_indices = shuffle_indices[:num_train_samples]
    test_indices = shuffle_indices[num_train_samples:]

    data_train = data[train_indices]
    labels_train = labels[train_indices]
    data_test = data[test_indices]
    labels_test = labels[test_indices]
    return data_train, labels_train, data_test, labels_test


if __name__ == '__main__':
    # Update with your path. Built with os.path.join so the separators are
    # correct on every platform and no backslash escapes appear in a literal.
    batch_path = os.path.join(
        "data", "cifar-10-python", "cifar-10-batches-py", "data_batch_1"
    )
    data, labels = read_cifar_batch(batch_path)
    print("Data shape:", data.shape)
    print("Labels shape:", labels.shape)