diff --git a/read_cifar.py b/read_cifar.py
index 2aeb021bed6dfd1767498bd1196c82e0eef67a2b..6e0c7cd6d67feb628273c9c8332a1c6029d8e4e4 100644
--- a/read_cifar.py
+++ b/read_cifar.py
@@ -1,5 +1,6 @@
 import numpy as np
 import pickle
+
 def read_cifar_batch(file):
     with open(file, 'rb') as fo:
         dict = pickle.load(fo, encoding='bytes')
@@ -12,24 +13,49 @@ def read_cifar_batch(file):
 #print(vect1)
 
 
 def read_cifar(directory):
-    all_data = []
-    all_labels = []
+    data = []
+    labels = []
     for i in range(1,6):
         data_v, labels_v = read_cifar_batch(f'{directory}/data_batch_{i}')
-        all_data.append(data_v)
-        all_labels.append(labels_v)
+        data.append(data_v)
+        labels.append(labels_v)
     data_v, labels_v = read_cifar_batch(f'{directory}/test_batch')
-    all_data.append(data_v)
-    all_labels.append(labels_v)
+    data.append(data_v)
+    labels.append(labels_v)
+
+    data = np.concatenate(data, axis = 0)
+    labels = np.concatenate(labels, axis = 0)
+
+    return(data, labels)
+
+def split_dataset(data, labels, split):
+
+    data_size = data.shape[0]
+    train_size = int(data_size * split)
+    indices = np.arange(data_size)
+    np.random.shuffle(indices)
+
+    indices_train = indices[:train_size]
+    indices_test = indices[train_size:]
+    data_train = data[indices_train]
+    labels_train = labels[indices_train]
+    data_test = data[indices_test]
+    labels_test = labels[indices_test]
+
+    return(data_train, labels_train, data_test, labels_test)
 
-
-    all_data = np.concatenate(all_data, axis = 0)
-    all_labels = np.concatenate(all_labels, axis = 0)
+
-    return(all_data, all_labels)
+
+if __name__ == "__main__":
+    #vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1")
+    #print(vect1)
+    #vect2= read_cifar("data/cifar-10-batches-py")
+    #print(vect2)
 
 
-#vect2= read_cifar("data/cifar-10-batches-py")
-#print(vect2)
+    pair = read_cifar("data/cifar-10-batches-py")
+    vect3= split_dataset(pair[0], pair[1], 0.6)
+    print(vect3)
 
 