Search code examples
python-3.x machine-learning computer-vision batch-processing

Reading Cifar10 dataset in batches


I am trying to read the CIFAR-10 dataset, which is distributed in batches at https://www.cs.toronto.edu/~kriz/cifar.html. I am trying to load each batch with pickle, read its 'data' part, and put it in a data frame, but I am getting this error:

KeyError                                  Traceback (most recent call last)
<ipython-input-24-8758b7a31925> in <module>()
----> 1 unpickle('datasets/cifar-10-batches-py/test_batch')

<ipython-input-23-04002b89d842> in unpickle(file)
      3     fo = open(file, 'rb')
      4     dict = pickle.load(fo, encoding ='bytes')
----> 5     X = dict['data']
      6     fo.close()
      7     return dict

KeyError: 'data'.

I am using IPython, and here is my code:

def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return its dictionary.

    The batch is loaded with ``encoding='bytes'``, so entries that were
    pickled as 8-bit strings come back with ``bytes`` keys (``b'data'``,
    ``b'labels'``, ...). Looking up the ``str`` key ``'data'`` on such a
    dictionary is what raised ``KeyError: 'data'``; index the result with
    ``b'data'`` instead.

    Args:
        file: Path to a pickled batch file.

    Returns:
        The unpickled dictionary, with its keys exactly as stored.

    Raises:
        KeyError: If the batch holds neither a ``b'data'`` nor a ``'data'``
            entry.
    """
    # NOTE: pickle can execute arbitrary code; only load files you trust.
    # The context manager closes the file even if unpickling fails.
    with open(file, 'rb') as fo:
        batch = pickle.load(fo, encoding='bytes')
    # Accept either key type so both bytes-keyed and str-keyed pickles work.
    if b'data' not in batch and 'data' not in batch:
        raise KeyError('data')
    return batch

# Load the test batch from the local CIFAR-10 dataset directory.
unpickle('datasets/cifar-10-batches-py/test_batch')

Solution

  • You can read the CIFAR-10 dataset with the code given below. Just make sure that you give the right directory, i.e. the one where the batches are placed.

    import tensorflow as tf
    import pandas as pd
    import numpy as np
    import math
    import timeit
    import matplotlib.pyplot as plt
    from six.moves import cPickle as pickle
    import os
    import platform
    from subprocess import check_output
    # Human-readable names for the ten classes. Presumably position i names
    # integer label i -- verify against the dataset's own label metadata.
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    
    %matplotlib inline
    
    
    # Image height and width used to build input_shape below.
    img_rows, img_cols = 32, 32
    # Shape of a single image as (rows, cols, 3); the 3 is presumably the
    # colour-channel count. The loaders below keep images as flat rows.
    input_shape = (img_rows, img_cols, 3)
    def load_pickle(f):
        """Unpickle one object from the open binary file *f*.

        The keyword arguments passed to ``pickle.load`` depend on the major
        version of the running interpreter: none on Python 2, and
        ``encoding='latin1'`` on Python 3.

        Args:
            f: File-like object opened for binary reading.

        Returns:
            The unpickled object.

        Raises:
            ValueError: If the interpreter's major version is neither 2 nor 3.
        """
        version = platform.python_version_tuple()
        # Extra pickle.load keyword arguments, keyed by major version string.
        load_kwargs = {'2': {}, '3': {'encoding': 'latin1'}}
        major = version[0]
        if major not in load_kwargs:
            raise ValueError("invalid python version: {}".format(version))
        return pickle.load(f, **load_kwargs[major])
    
    def load_CIFAR_batch(filename):
        """Load a single CIFAR batch file.

        Args:
            filename: Path to one pickled batch file.

        Returns:
            A tuple ``(X, Y)``. ``X`` is the batch's ``'data'`` entry
            reshaped to ``(N, 3072)``, one flat row per image. ``Y`` is the
            ``'labels'`` entry converted to a NumPy array.
        """
        # Only the unpickling needs the file handle; release it right after.
        with open(filename, 'rb') as f:
            datadict = load_pickle(f)
        # Let NumPy infer the row count (-1) instead of hard-coding 10000, so
        # batches of any size load. 3072 = 32 * 32 * 3 values per image.
        X = datadict['data'].reshape(-1, 3072)
        Y = np.array(datadict['labels'])
        return X, Y
    
    def load_CIFAR10(ROOT):
        """Load the five training batches and the test batch found under ROOT.

        Args:
            ROOT: Directory containing ``data_batch_1`` .. ``data_batch_5``
                and ``test_batch``.

        Returns:
            A tuple ``(Xtr, Ytr, Xte, Yte)``: the training data and labels
            concatenated across the five batches, then the test data and
            labels.
        """
        batch_paths = [
            os.path.join(ROOT, 'data_batch_%d' % (index, ))
            for index in range(1, 6)
        ]
        # Each entry is the (data, labels) pair of one training batch.
        batches = [load_CIFAR_batch(path) for path in batch_paths]
        Xtr = np.concatenate([data for data, _ in batches])
        Ytr = np.concatenate([labels for _, labels in batches])

        test_path = os.path.join(ROOT, 'test_batch')
        Xte, Yte = load_CIFAR_batch(test_path)
        return Xtr, Ytr, Xte, Yte
    def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000,
                         cifar10_dir='../input/cifar-10-batches-py/'):
        """Load CIFAR-10 and split it into train, validation and test sets.

        The validation split is taken from the training data immediately
        after the first ``num_training`` rows. All three image arrays are
        converted to float32 and divided by 255; labels are returned as
        loaded.

        Args:
            num_training: Number of leading training rows kept for training.
            num_validation: Number of rows after the training rows used for
                validation.
            num_test: Number of leading test rows to keep.
            cifar10_dir: Directory holding the CIFAR-10 batch files.

        Returns:
            A tuple ``(x_train, y_train, x_val, y_val, x_test, y_test)``.
        """
        # Load the raw CIFAR-10 data
        X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

        # Subsample the data. The validation rows must be picked before
        # X_train is cut down to its first num_training rows.
        val_idx = range(num_training, num_training + num_validation)
        train_idx = range(num_training)
        test_idx = range(num_test)

        X_val, y_val = X_train[val_idx], y_train[val_idx]
        X_train, y_train = X_train[train_idx], y_train[train_idx]
        X_test, y_test = X_test[test_idx], y_test[test_idx]

        # Scale every image split the same way. Previously the validation
        # images were returned raw while train and test were divided by 255.
        x_train = X_train.astype('float32') / 255
        x_val = X_val.astype('float32') / 255
        x_test = X_test.astype('float32') / 255

        return x_train, y_train, x_val, y_val, x_test, y_test
    
    
    # Load every split once so the rest of the notebook can use them.
    x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()


    # Report the shape of each split, one line per array.
    for caption, array in (
        ('Train data shape: ', x_train),
        ('Train labels shape: ', y_train),
        ('Validation data shape: ', x_val),
        ('Validation labels shape: ', y_val),
        ('Test data shape: ', x_test),
        ('Test labels shape: ', y_test),
    ):
        print(caption, array.shape)