import numpy as np
import pickle
import os

def read_cifar_batch(batch_path):
    """Load one pickled CIFAR-10 batch file.

    The file is unpickled with ``encoding='bytes'``, so the resulting
    dictionary is indexed with byte-string keys.

    Note: ``pickle`` can execute arbitrary code while loading; only point
    this at batch files obtained from a trusted source.

    Args:
        batch_path: Path of the pickled batch file to read.

    Returns:
        A ``(data, labels)`` tuple where ``data`` is the ``b'data'`` entry
        as a ``float32`` array and ``labels`` is the ``b'labels'`` entry as
        an ``int64`` array.
    """
    with open(batch_path, 'rb') as handle:
        contents = pickle.load(handle, encoding='bytes')

    # Pull both entries out first so a missing key surfaces before any
    # dtype conversion is attempted.
    raw_pixels, raw_targets = contents[b'data'], contents[b'labels']

    pixels = np.array(raw_pixels, dtype=np.float32)
    targets = np.array(raw_targets, dtype=np.int64)
    return pixels, targets


def read_cifar(directory_path):
    """Load every CIFAR-10 batch in a directory into one dataset.

    The five ``data_batch_*`` files and ``test_batch`` are read in that
    order and stacked, so the returned arrays contain the training and
    test batches together.

    Note: ``pickle`` can execute arbitrary code while loading; only point
    this at batch files obtained from a trusted source.

    Args:
        directory_path: Directory that contains the six batch files.

    Returns:
        A ``(data, labels)`` tuple: ``data`` is the row-wise concatenation
        of every batch's ``b'data'`` entry as ``float32``; ``labels`` is
        the concatenation of every batch's ``b'labels'`` entry as ``int64``.
    """
    batch_names = (
        'data_batch_1',
        'data_batch_2',
        'data_batch_3',
        'data_batch_4',
        'data_batch_5',
        'test_batch',
    )

    pixel_chunks, all_targets = [], []
    for name in batch_names:
        with open(os.path.join(directory_path, name), 'rb') as handle:
            contents = pickle.load(handle, encoding='bytes')

        # Byte-string keys because the file was unpickled with encoding='bytes'.
        chunk, targets = contents[b'data'], contents[b'labels']
        pixel_chunks.append(chunk)
        all_targets.extend(targets)

    # Stack the per-batch matrices along the sample axis, then build the
    # label vector, and only afterwards cast the pixels to float32.
    stacked = np.concatenate(pixel_chunks, axis=0)
    label_vector = np.array(all_targets, dtype=np.int64)
    return stacked.astype(np.float32), label_vector

def split_dataset(data, labels, split):
    """Randomly partition a dataset into a training part and a test part.

    A single random permutation of the sample indices is drawn with
    ``np.random.permutation`` (so ``np.random.seed`` controls
    reproducibility) and applied to both ``data`` and ``labels``, keeping
    each sample paired with its label.

    Args:
        data: Samples, indexed along the first axis. Plain lists/tuples are
            converted to NumPy arrays; other inputs are used as given.
        labels: Labels aligned with ``data`` along the first axis. Plain
            lists/tuples are converted to NumPy arrays.
        split: Fraction of samples assigned to the training part, in
            ``[0, 1]``. The training size is ``int(len(data) * split)``.

    Returns:
        ``(data_train, labels_train, data_test, labels_test)``.

    Raises:
        ValueError: If ``split`` is outside ``[0, 1]`` (including NaN).
    """
    # Written as a chained comparison so that NaN, which compares False
    # against everything, is rejected here with the intended message.
    if not (0 <= split <= 1):
        raise ValueError("Split must be a float between 0 and 1.")

    # Fancy indexing below needs array semantics; Python sequences do not
    # support indexing by an index array.
    if isinstance(data, (list, tuple)):
        data = np.asarray(data)
    if isinstance(labels, (list, tuple)):
        labels = np.asarray(labels)

    num_samples = len(data)
    num_train_samples = int(num_samples * split)

    # One shared permutation keeps data rows and labels aligned.
    shuffle_indices = np.random.permutation(num_samples)
    train_indices = shuffle_indices[:num_train_samples]
    test_indices = shuffle_indices[num_train_samples:]

    return (
        data[train_indices],
        labels[train_indices],
        data[test_indices],
        labels[test_indices],
    )








if __name__ == '__main__':
    # Demo: load a single batch and report the array shapes.
    # The path is assembled with os.path.join so the separator matches the
    # host OS; backslashes hard-coded in a normal string literal are invalid
    # escape sequences and only act as path separators on Windows.
    batch_path = os.path.join(
        "data", "cifar-10-python", "cifar-10-batches-py", "data_batch_1"
    )  # Update with your path
    data, labels = read_cifar_batch(batch_path)
    print("Data shape:", data.shape)
    print("Labels shape:", labels.shape)