Search code examples
pythonnumpyraster

Is there a more Numpy-esque way of interpolating my rasters?


I've written a little function to interpolate time sequences of irregularly sampled raster images so that they are evenly spaced in time (below). It works fine but I just know from looking at it that I'm missing some shortcuts. I'm looking for a Numpy ninja to give me so pro tips on how to punch up my syntax, and maybe get a little performance boost too.

Cheers!

import numpy as np

def interp_rasters(rasters, chrons, sampleFactor = 1):
    nFrames = round(len(chrons) * sampleFactor)
    interpChrons = np.linspace(np.min(chrons), np.max(chrons), nFrames)
    frames, rows, cols, channels = rasters.shape
    interpRasters = np.zeros((nFrames, rows, cols, channels), dtype = 'uint8')
    outs = []
    for row in range(rows):
        for col in range(cols):
            for channel in range(channels):
                pixelSeries = rasters[:, row, col, channel]
                interpRasters[:, row, col, channel] = np.interp(
                    interpChrons,
                    chrons,
                    pixelSeries
                    )
    return interpRasters

Solution

  • As the y values to be looked up have to be 1d I can't see a way of not looping through the np.arrays. If the rasters and interpRasters arrays are reshaped as how in the function one loop can be used, without explicit indexing. This gave around a 10% speed improvement for my made up test data.

    import numpy as np
    
    frames = 10
    rows = 5
    cols = 10
    channels = 3
    
    np.random.seed(1234)
    
    rasters = np.random.randint(0,256, size=(frames, rows, cols, channels))
    chrons = np.random.randint(0, 256, size  = 10 )
    
    
    # The original function.
    def interp_rasters(rasters, chrons, sampleFactor = 1):
        nFrames = round(len(chrons) * sampleFactor)
        interpChrons = np.linspace(np.min(chrons), np.max(chrons), nFrames)
        frames, rows, cols, channels = rasters.shape
        interpRasters = np.zeros((nFrames, rows, cols, channels), dtype = 'uint8')
        outs = []
        for row in range(rows):
            for col in range(cols):
                for channel in range(channels):
                    pixelSeries = rasters[:, row, col, channel]
                    interpRasters[:, row, col, channel] = np.interp(
                        interpChrons,
                        chrons,
                        pixelSeries
                        )
        return interpRasters
    
    def interp_rasters2(rasters, chrons, sampleFactor = 1):
        nFrames = round(len(chrons) * sampleFactor)
        interpChrons = np.linspace(np.min(chrons), np.max(chrons), nFrames)
        frames, rows, cols, channels = rasters.shape
        interpRasters = np.zeros((nFrames, rows, cols, channels), dtype = 'uint8')
    
        # Create reshaped arrays pointing to the same data 
        dat_in = rasters.reshape(frames, rows*cols*channels).T  
        # shape (r*c*c, frames)
    
        dat_out = interpRasters.reshape(nFrames, rows*cols*channels).T  
        # shape (r*c*c, frames)
    
        for pixelseries, row_out in zip(dat_in, dat_out):
            # Loop through all data in one loop.
            row_out[:] = np.interp( interpChrons, chrons, pixelseries )
        return interpRasters  
        # As dat_out and interpRasters share the same data return interpRasters
    
    print(np.isclose(interp_rasters(rasters, chrons), interp_rasters2(rasters, chrons)).all())
    # True  # The results are the same from the two functions.
    
    %timeit interp_rasters(rasters, chrons)
    # 568 µs ± 2.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    
    %timeit interp_rasters2(rasters, chrons)
    # 520 µs ± 239 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    

    Edit There's also np.apply_along_axis. This removes any explicit for loops, reduces the amount of code but is slower than the previous solutions.

    def interp_rasters3(rasters, chrons, sampleFactor = 1):
        nFrames = round(len(chrons) * sampleFactor)
        interpChrons = np.linspace(np.min(chrons), np.max(chrons), nFrames)
    
        def func( arr ):  # Define the function to apply along the axis
            return np.interp( interpChrons, chrons, arr )
    
        return np.apply_along_axis( func, 0, rasters ).astype( np.uint8 )
    
    print(np.isclose(interp_rasters(rasters, chrons), interp_rasters3(rasters, chrons)).all())
    # True
    

    I think I'd understand version 3 better than version 1 or 2 in 6 months time if speed isn't critical.

    HTH