python numpy cluster-computing gdal

Is there a faster method for iterating over a very big 2D numpy array than using np.where?


I have a huge 2D numpy array filled with integer values, which I read from a .tif image via gdal's GetRasterBand(). The pixel values of the image are unique cluster identification numbers, so all pixels inside one cluster have the same value. In my script I want to check whether each cluster has at least as many pixels as a specific threshold. If the cluster size reaches the threshold, I want to keep the cluster and give all of its pixels the value 1. If a cluster has fewer pixels than the threshold, all pixels of this cluster should get the value 0.

My code works so far, but it is very slow, and because I want to vary the threshold, it takes forever. I would really appreciate your help. Thank you.

    import numpy as np
    from osgeo import gdal

    # Import GeoTIFF via GDAL and convert to a NumPy array
    # (image is the path to the .tif file, defined elsewhere)
    data = gdal.Open(image)
    raster = data.GetRasterBand(1).ReadAsArray()

    # Different thresholds for iteration
    thresh = [0, 10, 25, 50, 100, 1000, 2000]

    for threshold in thresh:
        clusteredRaster = np.array(raster.copy(), dtype=int)

        for clump in np.unique(raster):  # unique IDs of the clusters in the image
            # Look up the pixels in the unmodified raster, so that the 0/1 values
            # already written cannot be mistaken for cluster IDs 0 or 1
            idx = np.where(raster == clump)

            if idx[0].size >= threshold:
                clusteredRaster[idx] = 1
            else:
                clusteredRaster[idx] = 0

[ClusterImage][1]

The image shows the clusters. Each color stands for a specific cluster number. I want to delete the small clusters (below a specific size) and keep only the big ones.

  [1]: https://i.sstatic.net/miEKg.png

Solution

  • I found an easy solution based on your helpful answers! The idea is to get the unique cluster IDs together with their sizes and then fill in the correct values directly, thus avoiding the inner loop over the clusters. It reduces the run time from initially 142 seconds per iteration to 0.52 seconds and reproduces the same results.

    import numpy as np
    from osgeo import gdal

    data = gdal.Open(image)
    raster = data.GetRasterBand(1).ReadAsArray()

    thresh = [0, 10, 25, 50, 100, 1000, 2000]
    for threshold in thresh:
        # Create a new all-zero raster with the same dimensions as the input raster
        clusteredRaster = np.zeros(raster.shape, dtype=np.uint8)

        # Get the unique cluster IDs and count how often each one occurs
        clumps, counts = np.unique(raster, return_counts=True)

        # Keep only the clumps whose size reaches the threshold
        biggerClumps = clumps[counts >= threshold]

        # Fill in ones for the relevant cluster IDs
        clusteredRaster[np.isin(raster, biggerClumps)] = 1
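
A possible further refinement, offered as a sketch rather than a tested drop-in: the cluster IDs and their sizes do not depend on the threshold, so `np.unique` only needs to run once. With `return_inverse=True` it also returns, for every pixel, the index of its cluster ID, which lets a per-cluster keep/drop decision be mapped back onto the raster with one indexing step instead of `np.isin`. The function name `threshold_clusters` and the small `toy` array are made up for illustration; the array stands in for the band read via GDAL.

```python
import numpy as np

def threshold_clusters(raster, thresholds):
    """Return {threshold: uint8 mask}; 1 marks pixels of clusters with >= threshold pixels."""
    # One pass over the raster: cluster IDs, per-pixel index into those IDs, pixel counts
    clumps, inverse, counts = np.unique(raster, return_inverse=True, return_counts=True)
    # The shape of `inverse` differs between NumPy versions, so normalise it
    inverse = inverse.reshape(raster.shape)

    masks = {}
    for threshold in thresholds:
        keep = counts >= threshold                       # one boolean per cluster ID
        masks[threshold] = keep[inverse].astype(np.uint8)  # broadcast back to pixels
    return masks

# Hypothetical toy raster: cluster sizes are 1 -> 3, 2 -> 2, 3 -> 3, 4 -> 1
toy = np.array([[1, 1, 2],
                [1, 3, 2],
                [4, 3, 3]])
masks = threshold_clusters(toy, [0, 2, 3])
print(masks[3])
```

For the real data, `threshold_clusters(raster, thresh)` would return one mask per threshold. Whether this is noticeably faster than the `np.isin` version depends on the raster size and the number of clusters, so it is worth timing on the actual image.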