# -*- coding: utf-8 -*-
"""
Created on Fri Oct 20 16:04:49 2023

@author: oscar
"""

import os
import numpy as np
import pickle

import pickle

def read_cifar_batch(batch_path):
    """Load one pickled batch file and return its data and labels as arrays.

    The file at ``batch_path`` is unpickled into a mapping; the entries stored
    under the byte-string keys ``b'data'`` and ``b'labels'`` are each wrapped
    in a NumPy array.

    Args:
        batch_path: Path of the pickled batch file to open in binary mode.

    Returns:
        A ``(data, labels)`` tuple of NumPy arrays.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling, so
    only point this at batch files from a trusted source.
    """
    with open(batch_path, 'rb') as handle:
        # encoding='bytes' is why the keys below are looked up as bytes.
        payload = pickle.load(handle, encoding='bytes')

    # .get() is kept on purpose: a missing key yields np.array(None) rather
    # than raising KeyError.
    images, targets = (np.array(payload.get(key)) for key in (b'data', b'labels'))
    return images, targets

def read_cifar(batch_dir, num_batches=1, include_test=False):
    """Load CIFAR batch files from ``batch_dir`` into one data/labels pair.

    The files ``data_batch_1`` ... ``data_batch_<num_batches>`` are read with
    ``read_cifar_batch`` and stacked along the first axis, in that order.
    When ``include_test`` is true the file ``test_batch`` is appended last.

    Args:
        batch_dir: Directory containing the batch files.
        num_batches: How many numbered training batches to load, starting at
            ``data_batch_1``. The default of 1 keeps the previous behaviour
            of loading only the first batch.
        include_test: Also load ``test_batch`` from the same directory.

    Returns:
        A ``(data, labels)`` tuple holding the concatenation of every loaded
        batch along axis 0.

    Raises:
        ValueError: If ``num_batches`` is negative, or if no batch at all
            was selected (``num_batches == 0`` and ``include_test`` false).
    """
    if num_batches < 0:
        raise ValueError(f"num_batches must be >= 0, got {num_batches}")

    batch_names = [f'data_batch_{i}' for i in range(1, num_batches + 1)]
    if include_test:
        batch_names.append('test_batch')
    if not batch_names:
        # np.concatenate cannot combine an empty list; fail with a clear message.
        raise ValueError("no batch selected: set num_batches >= 1 or include_test=True")

    data_batches = []
    label_batches = []
    for batch_filename in batch_names:
        batch_path = os.path.join(batch_dir, batch_filename)
        data, labels = read_cifar_batch(batch_path)
        data_batches.append(data)
        label_batches.append(labels)

    # Concatenate once after the loop instead of rebuilding the full arrays
    # on every iteration.
    data = np.concatenate(data_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    return data, labels

# def read_cifar(directory_path):
#     batches = os.listdir(directory_path)
#     data=None
#     labels=None

#     for batch in batches:
#         batch_path = os.path.join(directory_path, batch)
#         if not batch_path.endswith(".meta"):
#             data_batch,labels_batch=read_cifar_batch(batch_path)
#             if data is None:
#                 data=data_batch
#                 labels=labels_batch
#             else:
#                 data=np.concatenate((data,data_batch))
#                 labels=np.concatenate((labels,labels_batch))
#     return(data, labels)

def split_dataset(data, labels, split):
    """Randomly partition ``data`` and ``labels`` into training and test parts.

    Row indices are shuffled with NumPy's global random state
    (``np.random.shuffle``), so the result is reproducible only if the caller
    seeds that state beforehand. The same shuffled indices are applied to
    ``data`` and ``labels`` so each row stays paired with its label.

    Args:
        data: Array indexed by sample along its first axis.
        labels: Array with one entry per row of ``data``.
        split: Fraction of samples, between 0 and 1 inclusive, assigned to
            the training part. The training size is
            ``int(data.shape[0] * split)`` (truncated toward zero).

    Returns:
        A ``(data_train, labels_train, data_test, labels_test)`` tuple.

    Raises:
        ValueError: If ``split`` is outside ``[0, 1]`` or if ``labels`` does
            not have one entry per row of ``data``.
    """
    number_total = data.shape[0]

    if len(labels) != number_total:
        raise ValueError(
            f"data has {number_total} samples but labels has {len(labels)}"
        )
    # A negative split would slice from the wrong end and a split above 1
    # would silently produce an empty test set; reject both. NaN also fails
    # this comparison and is rejected.
    if not 0 <= split <= 1:
        raise ValueError(f"split must be in [0, 1], got {split}")

    number_train = int(number_total * split)
    indices = np.arange(number_total)
    np.random.shuffle(indices)
    indices_train = indices[:number_train]
    indices_test = indices[number_train:]

    return (
        data[indices_train],
        labels[indices_train],
        data[indices_test],
        labels[indices_test],
    )

if __name__ == "__main__":
    # Manual smoke run: load the CIFAR data found under the local data
    # directory, then hold out 20% of the samples as a test set.
    cifar_dir = "./data/cifar-10-python/"
    data, labels = read_cifar(cifar_dir)
    data_train, labels_train, data_test, labels_test = split_dataset(
        data, labels, split=0.8
    )