import numpy as np
import pickle
from sklearn.model_selection import train_test_split 
import pandas as pd


import pickle

def read_cifar_batch(batch_path):
    """Load a single pickled CIFAR batch file.

    Args:
        batch_path: Path of the batch file; it is opened in binary mode.

    Returns:
        A ``(data, labels)`` tuple where ``data`` is the batch's ``b'data'``
        entry converted to a ``float32`` NumPy array and ``labels`` is its
        ``b'labels'`` entry converted to an ``int64`` NumPy array.
    """
    # NOTE: unpickling can execute arbitrary code, so only point this at
    # batch files obtained from a trusted source.
    with open(batch_path, "rb") as handle:
        # encoding="bytes" is why the dictionary is indexed with bytes keys.
        raw_batch = pickle.load(handle, encoding="bytes")

    images = np.array(raw_batch[b'data'], dtype=np.float32)
    targets = np.array(raw_batch[b'labels'], dtype=np.int64)
    return images, targets


def read_cifar(path):
    """Load every CIFAR batch found under ``path`` into one dataset.

    The files ``data_batch_1`` .. ``data_batch_5`` are read first, followed
    by ``test_batch``; their contents are stacked in that order.

    Args:
        path: Directory containing the batch files. It is joined to each
            file name with a ``/``.

    Returns:
        A ``(data, labels)`` tuple holding the data arrays and the label
        arrays of all six batches, each concatenated along axis 0.
    """
    # The five training batches, then the test batch, in reading order.
    batch_names = [f'data_batch_{i}' for i in range(1, 6)]
    batch_names.append('test_batch')

    loaded = [read_cifar_batch(f'{path}/{name}') for name in batch_names]

    # Merge the per-batch arrays into one big array each.
    data = np.concatenate([images for images, _ in loaded], axis=0)
    labels = np.concatenate([targets for _, targets in loaded], axis=0)

    return data, labels

def split_dataset(data, labels, split):
    """Split ``data`` and ``labels`` into training and test subsets.

    Args:
        data: Samples to split.
        labels: Labels matching ``data``.
        split: Fraction of the samples kept for training; ``1 - split`` is
            passed to ``train_test_split`` as ``test_size``.

    Returns:
        The tuple ``(X_train, X_test, y_train, y_test)``.
    """
    held_out_fraction = 1 - split

    # random_state is pinned to 0 so repeated calls yield the same split.
    partitions = train_test_split(
        data, labels, test_size=held_out_fraction, random_state=0
    )
    X_train, X_test, y_train, y_test = partitions

    return X_train, X_test, y_train, y_test


if __name__ == '__main__':
    # Demo run: load the whole dataset, split it 80/20 and print the pieces.
    #
    # The dataset directory is given relative to the working directory, so
    # the script is expected to be launched from the repository root. This
    # replaces an absolute path into one user's home directory that only
    # resolved on a single machine.
    cifar_dir = 'Data/cifar-10-batches-py'

    # read_cifar already reads data_batch_1, so the separate single-batch
    # load whose result was immediately overwritten has been removed.
    data, labels = read_cifar(cifar_dir)
    X_train, X_test, y_train, y_test = split_dataset(data, labels, 0.8)
    print(X_train, X_test, y_train, y_test)