Skip to content
Snippets Groups Projects
Commit 781f0b13 authored by oscarchaufour's avatar oscarchaufour
Browse files

Update read_cifar.py

cifar extraction
parent 5f1d0fcf
No related branches found
No related tags found
No related merge requests found
......@@ -5,8 +5,44 @@ Created on Fri Oct 20 16:04:49 2023
@author: oscar
"""
def read_cifar_batch() :
return
import os
import numpy as np
import pickle
import pickle
def read_cifar_batch(batch_path):
with open(batch_path, 'rb') as f:
batch = pickle.load(f, encoding='bytes')
data = np.array(batch.get(b'data'))
labels = np.array(batch.get(b'labels'))
return data, labels
def read_cifar (batch_dir) :
data_batches = []
label_batches = []
for i in range(1,6) :
batch_filename = f'data_batch_{i}'
batch_path = os.path.join(batch_dir, batch_filename)
data, labels = read_cifar_batch(batch_path)
data_batches.append(data)
label_batches.append(labels)
test_batch_filename = 'test_batch'
test_batch_path = os.path.join(batch_dir, test_batch_filename)
data_test, labels_test = read_cifar_batch(test_batch_path)
data_batches.append(data_test)
label_batches.append(labels_test)
data = np.concatenate(data_batches, axis=0)
labels = np.concatenate(label_batches, axis=0)
return data, labels
if __name__ == "__main__":
pass
file = "./data/cifar-10-python/data_batch_1"
read_cifar_batch(file)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment