
Shuffling and importing a few rows of a saved NumPy file

I have 2 saved .npy files:

X_train - (18873, 224, 224, 3) - 21.2GB
Y_train - (18873,) - 148KB

X_train contains cat and dog images (cats in the first half and dogs in the second half, unshuffled), and Y_train holds the matching labels, 1 for cats and 0 for dogs. Thus Y_train is [1,1,1,1,1,1,.........,0,0,0,0,0,0].

I want to randomly load, say, 256 images (roughly 50% cats and 50% dogs) into X and their labels into Y. Since the data is large, I cannot load X_train into RAM.

So I tried the following (1st approach):

import numpy as np
X_train = np.load('Processed/X_train.npy', mmap_mode='r')
X = np.random.shuffle(X_train)
X = X[:256, :, :, :]
Y_train = np.load('Processed/Y_train.npy', mmap_mode='r')
Y = np.random.shuffle(Y_train)
Y = Y[:256]

This gives the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-68-8b2a13921b8d> in <module>
      2 np.random.seed(666555)
      3 X_train = np.load('Processed/X_train.npy', mmap_mode='r')
----> 4 X = np.random.shuffle(X_train)
      5 X = X[:256, :, :, :]
      6 Y_train = np.load('Processed/Y_train.npy', mmap_mode='r')

mtrand.pyx in numpy.random.mtrand.RandomState.shuffle()

mtrand.pyx in numpy.random.mtrand.RandomState.shuffle()

ValueError: assignment destination is read-only
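Two separate things go wrong in the first approach. np.random.shuffle shuffles its argument in place and returns None, so X would end up as None even if the shuffle succeeded; and np.load(..., mmap_mode='r') returns a read-only memmap, which an in-place shuffle cannot write to. A small sketch reproducing both on a tiny temporary file (the file is a hypothetical stand-in for Processed/X_train.npy):

```python
import os
import tempfile

import numpy as np

# Hypothetical small file standing in for Processed/X_train.npy
path = os.path.join(tempfile.mkdtemp(), 'demo.npy')
np.save(path, np.arange(20).reshape(10, 2))

# 1) shuffle works in place and returns None, so X = np.random.shuffle(...) makes X None
a = np.arange(10)
result = np.random.shuffle(a)
print(result)  # None

# 2) mmap_mode='r' gives a read-only memmap, so an in-place shuffle cannot write to it
X_train = np.load(path, mmap_mode='r')
print(X_train.flags.writeable)  # False
try:
    np.random.shuffle(X_train)
    error = None
except ValueError as exc:
    error = str(exc)  # exact wording differs between NumPy versions; it mentions "read-only"
print(error)
```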

I have also tried (2nd approach):

import numpy as np
X = np.memmap('Processed/X_train.npy', 'float64', shape = (256, 224, 224, 3), mode = 'c')
Y = np.memmap('Processed/Y_train.npy', 'float64', shape = (256), mode = 'c')
X = np.random.shuffle(X)
Y = np.random.shuffle(Y)

This outputs:


With the 2nd approach I only get cat images, because np.memmap reads only the first 256 images, so shuffling them afterwards is of no use.
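A sketch of one way around the second approach, assuming NumPy 1.17 or newer (for np.random.default_rng). np.memmap maps raw bytes, so with its default offset of 0 it also treats the .npy header as data and uses whatever dtype you pass rather than the one stored in the file; np.load(..., mmap_mode='r') parses the header instead. Indexing that memmap with an array of random row numbers copies only those rows into RAM, so the sample is no longer limited to the first 256 images. load_random_rows and the demo files are hypothetical names for illustration:

```python
import os
import tempfile

import numpy as np


def load_random_rows(x_path, y_path, n_samples, seed=None):
    """Hypothetical helper: n_samples random rows from two aligned .npy files."""
    X_all = np.load(x_path, mmap_mode='r')  # header parsed: correct dtype and shape
    Y_all = np.load(y_path, mmap_mode='r')
    rng = np.random.default_rng(seed)
    # One sorted array of random row numbers, used for both X and Y
    idx = np.sort(rng.choice(len(X_all), size=n_samples, replace=False))
    # Fancy indexing reads only the selected rows into an ordinary in-memory array
    return np.asarray(X_all[idx]), np.asarray(Y_all[idx])


# Demo on small hypothetical files: row i of X is filled with the value i,
# rows 0-9 are labeled 1 and rows 10-19 are labeled 0
tmp = tempfile.mkdtemp()
x_path = os.path.join(tmp, 'X_demo.npy')
y_path = os.path.join(tmp, 'Y_demo.npy')
np.save(x_path, np.arange(20, dtype=np.float32).reshape(20, 1, 1, 1)
        * np.ones((1, 2, 2, 3), dtype=np.float32))
np.save(y_path, (np.arange(20) < 10).astype(np.int64))

X, Y = load_random_rows(x_path, y_path, n_samples=8, seed=0)
print(X.shape, Y.shape)  # (8, 2, 2, 3) (8,)
```

The rows come back in file order because the indices are sorted (which keeps disk reads moving forward); apply one shared permutation to X and Y if the order within the batch matters. A plain random sample is only about 50/50 cats and dogs on average, not exactly.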

Please tell me how to do this, with any method that works.


  • Your shuffling procedure is not correct. With this strategy you shuffle X and Y independently, in different orders, so after the shuffle the images in X no longer match the labels in Y. Here is a demonstrative example:

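    import numpy as np
    xxx = np.asarray([1,2,3,4,5,6,7,8,9])
    yyy = np.asarray([1,2,3,4,5,6,7,8,9])
    np.random.shuffle(xxx) # shuffles xxx in place
    np.random.shuffle(yyy) # independent shuffle of yyy
    print((yyy == xxx).all()) # False (except by rare chance)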

    Here is the correct procedure: shuffle a single array of indices and apply it to both arrays:

    import numpy as np
    xxx = np.asarray([1,2,3,4,5,6,7,8,9])
    yyy = np.asarray([1,2,3,4,5,6,7,8,9])
    idx = np.arange(0,len(xxx))
    np.random.shuffle(idx) # shuffle the indices once
    print((yyy[idx] == xxx[idx]).all()) # True

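    This way you also avoid the None problem: np.random.shuffle shuffles its argument in place and returns None, so X = np.random.shuffle(X_train) sets X to None. Shuffling an index array instead also means the read-only memmap is never modified.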
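Applied to the original files, the shared-index idea can also guarantee an exact 50/50 split. A minimal sketch, assuming NumPy 1.17 or newer and that Y_train is small enough to load fully (148KB here): it draws n_per_class random row numbers per label, reads only those rows from the memory-mapped X, and reorders X and Y with one shared permutation. balanced_sample and the demo files are hypothetical:

```python
import os
import tempfile

import numpy as np


def balanced_sample(x_path, y_path, n_per_class, seed=None):
    """Hypothetical helper: n_per_class random rows per label, X and Y kept aligned."""
    rng = np.random.default_rng(seed)
    Y_all = np.load(y_path)                 # labels are small: load them fully
    X_all = np.load(x_path, mmap_mode='r')  # images are large: memory-map, read lazily
    # Random row numbers for each label, sampled without replacement
    picks = [rng.choice(np.flatnonzero(Y_all == label), size=n_per_class, replace=False)
             for label in np.unique(Y_all)]
    idx = np.sort(np.concatenate(picks))    # sorted so reads move forward through the file
    X = np.asarray(X_all[idx])              # copies only the selected rows into RAM
    Y = Y_all[idx]
    perm = rng.permutation(len(idx))        # one permutation shared by X and Y
    return X[perm], Y[perm]


# Demo on small hypothetical files: 20 rows, row i filled with the value i,
# first 10 labeled 1 (cats) and last 10 labeled 0 (dogs)
tmp = tempfile.mkdtemp()
x_path = os.path.join(tmp, 'X_demo.npy')
y_path = os.path.join(tmp, 'Y_demo.npy')
np.save(x_path, np.arange(20, dtype=np.float32).reshape(20, 1, 1, 1)
        * np.ones((1, 2, 2, 3), dtype=np.float32))
np.save(y_path, (np.arange(20) < 10).astype(np.int64))

X, Y = balanced_sample(x_path, y_path, n_per_class=4, seed=0)
print(X.shape, np.bincount(Y))  # (8, 2, 2, 3) [4 4]
```

With the real files, balanced_sample('Processed/X_train.npy', 'Processed/Y_train.npy', n_per_class=128) would return 256 images, 128 per class. rng.choice raises a ValueError if a class has fewer than n_per_class rows.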