
Patch creation methods for deep learning on very big data with relatively low amounts of memory (32GB)


I am trying to train a deep-learning semantic segmentation model on satellite imagery. As a first step, I ran a test on a small AOI using patchify and rasterio without any issues (sketched below). However, I am now trying to generate many more patches to train the model on and have increased my AOI in an effort to do so. For context, I previously had an ndarray of approximately 41848x14555x9 (x, y, n_bands); I'm now looking to increase this to 84632x37000x9 (x, y, n_bands).
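
For reference, the small-AOI test split the stacked array into tiles with a patchify call roughly like the one below (the array name, patch size and step are illustrative placeholders, not my exact values):

from patchify import patchify

# split a (rows, cols, 9) array into non-overlapping 256x256x9 tiles;
# the result has shape (n_rows, n_cols, 1, 256, 256, 9)
patches = patchify(small_aoi_array, (256, 256, 9), step=256)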

Unfortunately, numpy fails to even load the array into memory via rasterio, because the array (roughly 128 GB) is far larger than the 32 GB of RAM I have available. The error message is below:

numpy.core._exceptions._ArrayMemoryError: Unable to allocate 120. GiB for an array with shape (9, 42170, 84632) and data type float32
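
For reference, that figure is just the array shape multiplied by the 4-byte size of a float32:

import numpy as np

# 9 bands x 42170 rows x 84632 columns of 4-byte float32 values
n_bytes = 9 * 42170 * 84632 * np.dtype(np.float32).itemsize
print(n_bytes / 2**30)  # ~119.7, i.e. the "120. GiB" numpy reports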

I've attempted to use a combination of rioxarray/xarray and np.memmap, since they provide lazy arrays. However, populating the memmap is incredibly slow, because I have to iterate over the bands and one spatial axis so that each slice fits into the memory I have, i.e.:

import numpy as np
import rioxarray

image_io = rioxarray.open_rasterio('/path/to/image_stack.tif')
# memmap with the same (band, y, x) layout as the DataArray
raster = np.memmap(memmap_name, dtype=np.float32, mode='w+', shape=image_io.shape)
for i in range(len(image_io.band)):
    for j in range(len(image_io.y)):
        raster[i, j, :] = image_io[i, j, :].values

I think the most important thing I should be asking is: is this even possible with the hardware resources I have available?

If it is, is there a better way to do this than with the approach I've listed above?

I'm not set on using patchify, but it seems to be the go-to library for generating smaller image tiles. Thanks in advance for any advice!


Solution

  • You might find Xarray + Dask + Xbatcher to be a productive alternative here. Pseudo-code below:

    import rioxarray
    import xbatcher

    # open the GeoTIFF as an xarray.DataArray backed by a lazy dask array;
    # you can tune the chunk sizes to the size of your patch
    chunks = {'band': bc, 'x': xc, 'y': yc}
    da = rioxarray.open_rasterio('/path/to/image_stack.tif', chunks=chunks)

    # create the xbatcher.BatchGenerator; this will let you iterate through your
    # DataArray in smaller-than-memory patches
    bgen = xbatcher.BatchGenerator(da, input_dims={'band': bc, 'x': xc, 'y': yc})
    for patch in bgen:
        # handle patch
        ...
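
  • Each patch is a small xarray.DataArray, so only that patch's chunks are computed and pulled into memory. As a minimal sketch of the "handle patch" step, assuming you simply want to dump each patch to disk as a float32 NumPy file (the output directory and file-naming scheme are illustrative, not part of xbatcher):

    import os
    import numpy as np

    os.makedirs('patches', exist_ok=True)
    for k, patch in enumerate(bgen):
        # .values computes just this patch and returns a (band, yc, xc) NumPy array
        arr = patch.values.astype(np.float32)
        np.save(os.path.join('patches', f'patch_{k:06d}.npy'), arr)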