Skip to content
Snippets Groups Projects
Select Git revision
  • 5dbc3e3559712103f0b87b641bfaebecec0e6cee
  • master default protected
2 results

graph-parcours-dijkstra.py

Blame
  • Forked from Vuillemot Romain / INF-TC1
    Source project has a limited visibility.
    read_cifar.py 1.74 KiB
    import pickle
    import numpy as np
    from sklearn.model_selection import train_test_split 
    import random
    
    
    def unpickle(file):
        """Deserialize a pickled batch file and return its contents.

        CIFAR batch files are pickled with bytes keys, hence encoding='bytes'.
        """
        with open(file, 'rb') as handle:
            contents = pickle.load(handle, encoding='bytes')
        return contents
    
    
    def read_cifar_batch(path):
        """Load a single CIFAR batch file.

        Returns a tuple (data, labels) where data is float32 and labels
        are int64, as read from the batch's b'data' / b'labels' entries.
        """
        raw = unpickle(path)
        images = np.float32(raw[b'data'])
        targets = np.int64(raw[b'labels'])
        return images, targets
    
    
    def read_cifar(folder_path):
        """Load the whole CIFAR-10 set from *folder_path*.

        Concatenates the five training batches (data_batch_1..5) and the
        test batch into one (data, labels) pair.

        Bug fixes vs. the previous version:
        - np.append returns a NEW array; its results were discarded, so
          only the test batch was ever returned. We now collect the parts
          and concatenate once.
        - range(1, 5) skipped data_batch_5; range(1, 6) covers all five.
        - The test-batch path was hard-coded; it now uses folder_path.
        """
        data_parts = []
        label_parts = []
        for i in range(1, 6):
            batch_data, batch_labels = read_cifar_batch(folder_path + "/data_batch_" + str(i))
            data_parts.append(batch_data)
            label_parts.append(batch_labels)

        test_data, test_labels = read_cifar_batch(folder_path + "/test_batch")
        data_parts.append(test_data)
        label_parts.append(test_labels)

        return np.concatenate(data_parts), np.concatenate(label_parts)
    
    
    def split_dataset(data, labels, split):
        """Randomly split (data, labels) into train and test subsets.

        Args:
            data: sample array (first axis indexes samples).
            labels: label array aligned with data.
            split: fraction of samples assigned to the training set (0..1).

        Returns:
            (data_train, labels_train, data_test, labels_test)

        Fixes vs. the previous version: the computed split index was
        unused, and the shuffled zip of (data, labels) was built then
        discarded (dead work). We now perform the seeded shuffle-and-split
        directly with numpy: one permutation applied to both arrays keeps
        every sample paired with its label. A fixed seed keeps the split
        deterministic, matching the original's random_state=1 intent.
        """
        data = np.asarray(data)
        labels = np.asarray(labels)

        split_index = int(split * len(data))

        rng = np.random.default_rng(1)  # fixed seed -> reproducible split
        order = rng.permutation(len(data))
        train_idx, test_idx = order[:split_index], order[split_index:]

        return data[train_idx], labels[train_idx], data[test_idx], labels[test_idx]
        
    
    
    if __name__ == "__main__":
        # Load the full CIFAR-10 dataset and display it.
        data, labels = read_cifar("./data/cifar-10-batches-py")
        print(data)
        print(labels)

        # Build the train/test partition (21% of samples go to training)
        # and display each resulting piece in order.
        train_fraction = 0.21
        data_train, labels_train, data_test, labels_test = split_dataset(data, labels, train_fraction)
        for part in (data_train, labels_train, data_test, labels_test):
            print(part)