Tags: python, memory, pytorch, hdf5, data-processing

Most efficient way to use a large data set for PyTorch?


Perhaps this question has been asked before, but I'm having trouble finding relevant info for my situation.

I'm using PyTorch to build a CNN for regression on image data. I don't have a formal, academic programming background, so many of my approaches are ad hoc and terribly inefficient. Often I can go back and clean up my code later, because the inefficiency doesn't significantly affect performance. In this case, however, my method for loading the image data takes a long time and uses a lot of memory. It also runs every time I want to test a change to the model.

Here is what I've done. I load the image data into NumPy arrays and save those arrays to an .npy file. Then, whenever I want to use the data for the model, I load everything in that file. I don't think the data set is really THAT large: it consists of 5000 three-channel color images of size 64x64. Yet my memory usage shoots up to 70%-80% (of 16 GB) while the data is loading, and each load takes 20-30 seconds.
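For scale, here is a quick back-of-the-envelope sketch of how much memory the final array needs. The question doesn't say which dtype the .npy file uses, so three common ones are compared. The helper `dataset_nbytes` is hypothetical and only multiplies the array dimensions by the dtype's item size.

```python
import numpy as np

# Shape described in the question: 5000 images, 3 colour channels, 64x64 pixels.
N_IMAGES, CHANNELS, HEIGHT, WIDTH = 5000, 3, 64, 64


def dataset_nbytes(dtype) -> int:
    """Bytes needed to hold the whole image array in memory with the given dtype."""
    n_values = N_IMAGES * CHANNELS * HEIGHT * WIDTH
    return n_values * np.dtype(dtype).itemsize


for dtype in ("uint8", "float32", "float64"):
    mib = dataset_nbytes(dtype) / 2**20  # bytes -> MiB
    print(f"{dtype:>8}: {mib:8.2f} MiB")
```

This works out to roughly 59 MiB for uint8, 234 MiB for float32 and 469 MiB for float64. Even the largest of these is far below 70%-80% of 16 GB, so the array alone should not account for the spike.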

My guess is that I'm being dumb about the way I'm loading it in, but frankly I'm not sure what the standard is. Should I, in some way, put the image data somewhere before I need it, or should the data be loaded directly from the image files? And in either case, what is the best, most efficient way to do that, independent of file structure?

I would really appreciate any help on this.


Solution

  • Here is a concrete example of reading the data lazily from HDF5 with h5py. It assumes you have already dumped the images into an HDF5 file (train_images.hdf5) using h5py.

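    import h5py

    # open read-only; nothing is read into memory yet
    hf = h5py.File('train_images.hdf5', 'r')

    # assumes the first key in the file is the image dataset
    group_key = list(hf.keys())[0]
    ds = hf[group_key]  # an h5py Dataset: a handle to data on disk, not an array

    # load only one example (returns a NumPy array)
    x = ds[0]

    # load a subset: slice the first n examples
    n = 100  # example value
    arr = ds[:n]

    # this would load the whole dataset into memory
    # and should be avoided for large data
    arr = ds[:]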
    

    In simple terms, ds can now be used as an iterator that reads images from disk on the fly. Each iteration loads a single image instead of the whole dataset. This should avoid both the long up-front load and the memory spike.

    for idx, img in enumerate(ds):
        # `img` is a NumPy array holding one image, read from disk on demand
        print(idx, img.shape)
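Since the goal is training in PyTorch, one common way to plug this lazy reading into a training loop is to wrap the HDF5 file in a `torch.utils.data.Dataset` and pass it to a `DataLoader`. The sketch below is written under several assumptions:

* h5py, NumPy and PyTorch are installed.
* The images are stored under a dataset named `"images"` with shape (N, 3, 64, 64).
* `npy_to_hdf5` and `H5ImageDataset` are hypothetical helper names.
* The regression targets are left out. They would also need to be returned from `__getitem__`.

```python
import os
import tempfile

import h5py
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


def npy_to_hdf5(npy_path: str, h5_path: str, key: str = "images") -> None:
    """One-off conversion: copy an existing .npy array into an HDF5 dataset.

    This still reads the whole .npy array once, but only during conversion.
    """
    arr = np.load(npy_path)
    with h5py.File(h5_path, "w") as f:
        f.create_dataset(key, data=arr)


class H5ImageDataset(Dataset):
    """Reads one image at a time from an HDF5 dataset instead of loading all of them."""

    def __init__(self, h5_path: str, key: str = "images"):
        self.h5_path = h5_path
        self.key = key
        self._file = None  # opened lazily on first access (see note below)
        with h5py.File(h5_path, "r") as f:
            self._length = f[key].shape[0]

    def __len__(self) -> int:
        return self._length

    def __getitem__(self, idx: int) -> torch.Tensor:
        if self._file is None:
            self._file = h5py.File(self.h5_path, "r")
        img = self._file[self.key][idx]  # reads only this image from disk
        # convert to float32 for the model; no normalisation is applied here
        return torch.from_numpy(np.asarray(img, dtype=np.float32))

    def close(self) -> None:
        if self._file is not None:
            self._file.close()
            self._file = None


if __name__ == "__main__":
    # Small synthetic stand-in for the real data: 10 uint8 images of shape 3x64x64.
    with tempfile.TemporaryDirectory() as tmp:
        npy_path = os.path.join(tmp, "train_images.npy")
        h5_path = os.path.join(tmp, "train_images.hdf5")
        np.save(npy_path, np.random.randint(0, 256, size=(10, 3, 64, 64), dtype=np.uint8))
        npy_to_hdf5(npy_path, h5_path)

        dataset = H5ImageDataset(h5_path)
        loader = DataLoader(dataset, batch_size=4, shuffle=True)
        for batch in loader:
            print(batch.shape, batch.dtype)  # batches of 4, 4, then the remaining 2
        dataset.close()
```

The file is opened inside `__getitem__` rather than in `__init__`. This choice is aimed at `DataLoader` with `num_workers > 0`: an open h5py file handle generally should not be shared across worker processes, so each worker opens its own handle on first access. The approach only works if the dataset has not already opened the file in the main process before the workers start.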