Skip to content
Snippets Groups Projects
Select Git revision
  • 96a3e90b8ff74caa7a4c972adb8cc82f24ddeec6
  • main default protected
  • adam
  • thomas
4 results

chess.js

Blame
  • read_cifar.py 1.57 KiB
    import numpy as np
    import pickle
    
    def read_cifar_batch(file):
        """Load a single pickled CIFAR batch file.

        Args:
            file: Path to one batch file (e.g. ``data_batch_1`` or ``test_batch``).

        Returns:
            A ``(data, labels)`` tuple: the pickle's ``b'data'`` entry converted
            to a float32 array, and its ``b'labels'`` entry converted to an
            int64 array.
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the file;
        # only load batch files obtained from a trusted source.
        with open(file, 'rb') as fo:
            # encoding='bytes' decodes legacy (Python 2) str keys as bytes,
            # which is why the lookups below use b'data' / b'labels'.
            batch = pickle.load(fo, encoding='bytes')
        data = np.asarray(batch[b'data'], dtype=np.float32)
        labels = np.asarray(batch[b'labels'], dtype=np.int64)
        return data, labels
    
    #vect1= read_cifar_batch("data/cifar-10-batches-py/data_batch_1")
    
    #print(vect1)
    
    def read_cifar(directory):
        """Load every CIFAR-10 batch in *directory* and stack them row-wise.

        The five training batches (``data_batch_1`` .. ``data_batch_5``) are read
        first, then ``test_batch``, each through ``read_cifar_batch``. The
        results are concatenated along axis 0 in that order, so training rows
        precede test rows.

        Args:
            directory: Folder containing the extracted CIFAR-10 batch files.

        Returns:
            A ``(data, labels)`` tuple of the concatenated arrays.
        """
        batch_names = [f'data_batch_{k}' for k in range(1, 6)]
        batch_names.append('test_batch')

        batches = [read_cifar_batch(f'{directory}/{name}') for name in batch_names]
        data_parts, label_parts = zip(*batches)

        stacked_data = np.concatenate(data_parts, axis=0)
        stacked_labels = np.concatenate(label_parts, axis=0)
        return stacked_data, stacked_labels
    
    def split_dataset(data, labels, split, seed=None):
        """Randomly partition a dataset into a train part and a test part.

        Args:
            data: Array whose first axis indexes samples.
            labels: Array-like with one entry per sample, aligned with ``data``.
            split: Fraction of samples, in ``[0, 1]``, assigned to the train part.
            seed: Optional seed. When given, the shuffle uses a private
                ``np.random.default_rng(seed)`` so the split is reproducible
                without touching NumPy's global random state. When ``None``
                (the default), the global ``np.random`` state is used.

        Returns:
            ``(data_train, labels_train, data_test, labels_test)``.

        Raises:
            ValueError: If ``split`` is outside ``[0, 1]`` or if ``data`` and
                ``labels`` have different numbers of samples.
        """
        if not 0 <= split <= 1:
            raise ValueError(f"split must be in [0, 1], got {split!r}")

        data_size = data.shape[0]
        if len(labels) != data_size:
            raise ValueError(
                f"data has {data_size} samples but labels has {len(labels)}"
            )

        # Round away floating-point noise before truncating: plain
        # int(100 * 0.29) yields 28 because 100 * 0.29 == 28.999999999999996.
        train_size = int(round(data_size * split, 9))

        indices = np.arange(data_size)
        if seed is None:
            np.random.shuffle(indices)
        else:
            np.random.default_rng(seed).shuffle(indices)

        indices_train = indices[:train_size]
        indices_test = indices[train_size:]
        return (data[indices_train], labels[indices_train],
                data[indices_test], labels[indices_test])
    
        
    
    
    if __name__ == "__main__":
        # Demo: load the full CIFAR-10 set (train + test batches) and make a
        # 60/40 random train/test split, then report the resulting sizes.
        data, labels = read_cifar("data/cifar-10-batches-py")
        data_train, labels_train, data_test, labels_test = split_dataset(data, labels, 0.6)
        # Print shapes rather than the arrays themselves, which are far too
        # large to inspect usefully on stdout.
        print("train:", data_train.shape, labels_train.shape)
        print("test:", data_test.shape, labels_test.shape)