Skip to content
Snippets Groups Projects
Commit 2d858147 authored by pierre-cau's avatar pierre-cau
Browse files

read cifar

parent 54d52fa2
No related branches found
No related tags found
No related merge requests found
from utils import *
if __name__ == "__main__":
# Load CIFAR data
data, labels = read_cifar(r"../data/cifar-10-batches-py")
\ No newline at end of file
# Author: Pierre CAU
# Date: 2024
ROOT = r"../../.."
from .read_cifar import *
File added
File added
File added
# Author: Pierre CAU
# Date: 2024
import os
import numpy as np
import pickle as pkl
def read_cifar_batch(file_path):
with open(file_path, 'rb') as file:
batch = pkl.load(file, encoding='bytes')
data = np.array(batch[b'data'], dtype=np.float32)
labels = np.array(batch[b'labels'], dtype=np.int64)
return data, labels
def read_cifar(data_dir):
"""
Read all CIFAR batches and return combined data and labels.
Parameters
----------
data_dir : str
Path to the directory containing the CIFAR batches.
Returns
-------
data : np.ndarray
Array of images of shape (batch_size, data_size) with dtype np.float32.
labels : np.ndarray
Array of labels of shape (batch_size,) with dtype np.int64.
"""
batch_files = [f for f in os.listdir(data_dir) if f.startswith("data_batch") or f.endswith("batch_test")]
all_data = []
all_labels = []
for batch_file in batch_files:
batch_path = os.path.join(data_dir, batch_file)
data, labels = read_cifar_batch(batch_path)
all_data.append(data)
all_labels.append(labels)
try :
data = np.vstack(all_data).astype(np.float32)
labels = np.hstack(all_labels).astype(np.int64)
except MemoryError:
for removed in range(1,len(all_data)):
done = False
try:
data = np.vstack(all_data[:-removed]).astype(np.float32)
labels = np.hstack(all_labels[:-removed]).astype(np.int64)
print(f"""
=============================================================================
Memory Error occurred when trying to load all the data.
Removed the last {removed} batches and successfully loaded the data.
Loaded batches: {batch_files[:-removed]}
Data shape: {data.shape}, Labels shape: {labels.shape}
Memory usage: {data.nbytes / 1024**2:.2f} MB
=============================================================================
""")
done = True
break
except MemoryError:
pass
if not done:
raise MemoryError("Not enough memory to load the data")
return data, labels
if __name__ == "__main__":
ROOT = os.path.join(os.path.dirname(__file__), "..", "..")
data_folder = r"data\cifar-10-batches-py"
data_folder = os.path.normpath(os.path.join(ROOT, data_folder)) # Normalize path
data_path = os.path.join(ROOT, data_folder) # Path to the data folder
data, labels = read_cifar(data_folder)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment