# -*- coding: utf-8 -*-
"""
Created on Fri Oct 20 16:04:49 2023

@author: oscar
"""

import os
import numpy as np
import pickle

import pickle

def read_cifar_batch(batch_path):
    """Read a single pickled CIFAR batch file.

    Parameters
    ----------
    batch_path : str
        Path to the pickled batch file.

    Returns
    -------
    tuple of numpy.ndarray
        ``(data, labels)`` built from the ``b'data'`` and ``b'labels'``
        entries of the unpickled dict. A missing key yields ``np.array(None)``
        because the lookup uses ``dict.get``.
    """
    # NOTE: pickle.load can execute arbitrary code while deserializing;
    # only point this at trusted dataset files.
    with open(batch_path, 'rb') as handle:
        # encoding='bytes' keeps the Python-2-era dict keys as bytes, hence
        # the b'...' lookups below.
        contents = pickle.load(handle, encoding='bytes')

    images = np.array(contents.get(b'data'))
    targets = np.array(contents.get(b'labels'))
    return images, targets

def read_cifar(batch_dir):
    """Load and stack every CIFAR batch found in ``batch_dir``.

    The five training batches (``data_batch_1`` .. ``data_batch_5``) are read
    in order, followed by ``test_batch``. Each file is read exactly once.

    Parameters
    ----------
    batch_dir : str
        Directory containing ``data_batch_1`` ... ``data_batch_5`` and
        ``test_batch``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(data, labels)``: all batches concatenated along axis 0 (training
        batches first, then the test batch), with labels aligned row-for-row.
    """
    batch_filenames = [f'data_batch_{i}' for i in range(1, 6)]
    # The test batch is loaded once, after the training batches; it must not
    # be read per training batch, or its samples end up duplicated.
    batch_filenames.append('test_batch')

    data_batches = []
    label_batches = []
    for batch_filename in batch_filenames:
        batch_path = os.path.join(batch_dir, batch_filename)
        batch_data, batch_labels = read_cifar_batch(batch_path)
        data_batches.append(batch_data)
        label_batches.append(batch_labels)

    # Concatenate once after all batches are collected, instead of on every
    # iteration, which would be quadratic in the number of batches.
    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return data, labels

def split_dataset(data, labels, split):
    """Randomly partition samples into a training part and a test part.

    Parameters
    ----------
    data : numpy.ndarray
        Samples indexed along axis 0.
    labels : numpy.ndarray
        Labels aligned with ``data`` along axis 0.
    split : float
        Fraction of samples assigned to the training part. The training size
        is ``int(len(data) * split)``, so it is truncated toward zero.

    Returns
    -------
    tuple
        ``(data_train, labels_train, data_test, labels_test)``.

    Notes
    -----
    The shuffle uses NumPy's global random state (``np.random.shuffle``);
    call ``np.random.seed`` beforehand for a reproducible split.
    """
    total = data.shape[0]
    cut = int(total * split)

    permutation = np.arange(total)
    np.random.shuffle(permutation)
    train_idx, test_idx = permutation[:cut], permutation[cut:]

    return (
        data[train_idx],
        labels[train_idx],
        data[test_idx],
        labels[test_idx],
    )

if __name__ == "__main__":
    # Script entry point: load the CIFAR batches from the default location,
    # then keep 80% of the samples for training and the rest for testing.
    cifar_dir = "./data/cifar-10-python/"
    all_data, all_labels = read_cifar(cifar_dir)
    data_train, labels_train, data_test, labels_test = split_dataset(
        all_data, all_labels, 0.8
    )