python · numpy · optimization · numba

Optimize this Python function - apply a linear transformation based on parity of index


I have this simple Python function:

import numpy as np

def fast_transform(img, offset, factor):
    rep = (img.shape[0]//2, img.shape[1]//2)
    out = (img.astype(np.float32) - np.tile(offset, rep)) * np.tile(factor, rep)
    return out

The function gets an image (as an N×M numpy ndarray) and two 2×2 arrays (offset and factor). It then applies a basic linear transformation to every pixel in the image based on its parity in each dimension: out[i,j] = (img[i,j] - offset[i%2,j%2]) * factor[i%2,j%2]
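
To make the indexing concrete, here is that rule checked for a single pixel (a small sketch that just restates the formula above with made-up values):

img = np.arange(16, dtype=np.float32).reshape(4, 4)
offset = np.array([[1.0, 2.0], [3.0, 4.0]])
factor = np.array([[0.5, 1.5], [2.5, 3.5]])

i, j = 3, 2  # odd row, even column -> offset[1, 0] and factor[1, 0]
expected = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
assert np.isclose(fast_transform(img, offset, factor)[i, j], expected)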

As you can see, I used np.tile to try to speed up the function, but this isn't fast enough for my needs (and I suspect the temporary arrays created by np.tile make it sub-optimal). I tried numba, but it doesn't support np.tile yet.

Can you help me optimize this function as much as possible? I am sure there is some simple way to do it I am missing.


Solution

  • If you're willing to use another library, you can use JAX to make your numpy function ~7x faster (though this may not be ideal if your array shapes vary, since JAX recompiles the function for every new shape):

    from jax import config
    config.update("jax_enable_x64", True)  # use 64-bit floats so the result dtype matches the float64 NumPy output
    import jax
    import jax.numpy as jnp
    import numpy as np
    
    @jax.jit
    def fast_transform_jax(img, offset, factor):
        rep = (img.shape[0]//2, img.shape[1]//2)
        out = (img.astype(np.float32) - jnp.tile(offset, rep)) * jnp.tile(factor, rep)
        return out
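
    If the rest of your pipeline expects a plain NumPy ndarray, the JAX result can be converted back (a small usage sketch; result is just a name picked for the example):

    result = fast_transform_jax(img, offset, factor)  # returns a JAX array
    result = np.asarray(result)                       # copy back into a NumPy ndarray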
    

    Slight modifications to the numba functions in @Andrej's answer so that they pass allclose with the OP's function:

    import numba as nb
    
    @nb.njit
    def fast_transform_numba(img, offset, factor):
        img = img.astype(np.float32)                 # mimic the float32 cast in the original
        out = np.empty(img.shape, dtype=np.float64)  # the original output is float64 after promotion
        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
        return out
    
    @nb.njit(parallel=True)
    def fast_transform_numba_parallel(img, offset, factor):
        img = img.astype(np.float32)
        out = np.empty(img.shape, dtype=np.float64)
        # Only the outer prange loop is actually parallelized; the nested one runs as a plain range.
        for i in nb.prange(img.shape[0]):
            for j in nb.prange(img.shape[1]):
                out[i, j] = (img[i, j] - offset[i % 2, j % 2]) * factor[i % 2, j % 2]
        return out
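
    The OP's hunch about the np.tile temporaries can also be checked with a tile-free NumPy variant that handles the four parity classes through strided slices (a sketch of mine, not included in the timings below; fast_transform_slices is a made-up name):

    def fast_transform_slices(img, offset, factor):
        img = img.astype(np.float32)                 # mimic the float32 cast in the original
        out = np.empty(img.shape, dtype=np.float64)  # match the float64 output of the other versions
        for a in range(2):
            for b in range(2):
                # every pixel with i % 2 == a and j % 2 == b shares one offset/factor entry
                out[a::2, b::2] = (img[a::2, b::2] - offset[a, b]) * factor[a, b]
        return out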
    

    Timings (block_until_ready() forces the asynchronous JAX call to finish before the timer stops):

    rng = np.random.default_rng()
    
    N, M = 1000, 1000
    img = rng.random((N, M)) * 50
    offset = rng.random((2, 2)) * 40
    factor = rng.random((2, 2)) * 30
    
    assert np.allclose(fast_transform(img, offset, factor), fast_transform_numba(img, offset, factor))
    assert np.allclose(fast_transform(img, offset, factor), fast_transform_numba_parallel(img, offset, factor))
    assert np.allclose(fast_transform(img, offset, factor), fast_transform_jax(img, offset, factor))
    
    %timeit fast_transform(img, offset, factor)
    %timeit fast_transform_numba(img, offset, factor)
    %timeit fast_transform_numba_parallel(img, offset, factor)
    %timeit fast_transform_jax(img, offset, factor).block_until_ready()
    

    Output:

    3.59 ms ± 332 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    1.39 ms ± 37.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
    871 µs ± 47.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
    521 µs ± 36.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)