python, math, sum, vectorization

Fastest way to evaluate the sum of a function over a rectangular grid


Consider a function f(m,n). I want to sum this function over a rectangular grid for m in range(1,M) and n in range(1,N). What is the most efficient way to do this? Can I get an order-of-magnitude (or larger) speedup over what I have tried so far?

I am using numpy.ogrid and vectorisation.

import numpy as np
from timeit import default_timer as timer

def f(m,n):
    return 1 / np.power(m**2 + n**2, 2)

def sum(M, N):
    m, n = np.ogrid[1:M, 1:N]
    return np.sum(f(m, n))

start = timer()
sum(10000,10000)
end = timer()
print(f"Run time is {(end - start) * 1000} ms")

Returns: Run time is 515.3399159999999 ms
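
For readers less familiar with np.ogrid, here is a minimal sketch of what the call above produces and why the vectorised expression covers the whole grid. The sizes M = 4 and N = 6 are assumptions chosen only to keep the output small; they are not the sizes timed above.

```python
import numpy as np

M, N = 4, 6  # illustrative sizes only, not the ones timed in the question

# ogrid returns "open" grids: a column vector for m and a row vector for n
m, n = np.ogrid[1:M, 1:N]
print(m.shape, n.shape)  # (3, 1) (1, 5)

# Broadcasting combines them into the full (M-1) x (N-1) grid
grid = m**2 + n**2
print(grid.shape)  # (3, 5)

# The vectorised sum matches an explicit double loop over the same ranges
loop_total = sum(1 / (i**2 + j**2) ** 2 for i in range(1, M) for j in range(1, N))
vec_total = (1 / grid.astype(np.float64) ** 2).sum()
print(np.isclose(loop_total, vec_total))  # True
```

The full grid is only materialised when m and n are combined, which is why the 10000 x 10000 case allocates large temporaries.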


Solution

  • You could use **2 or np.square instead of np.power(..., 2), as np.power is not optimized to turn an integer power into a simple multiplication.

    The fastest way to do it is to use numba to compile this down to machine code, and add multithreading while you are at it.

    import numpy as np
    from timeit import default_timer as timer
    
    def f(m, n):
        return 1 / np.square(m**2 + n**2, dtype=np.int64)
    
    def sum(M, N):
        m, n = np.ogrid[1:M, 1:N]
        return f(m, n).sum()
    
    from numba import njit, prange
    
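    # An explicit signature makes numba compile when the decorator runs,
    # so the timing below should not include compilation.
    # parallel=True lets prange split the outer loop across threads.
    @njit("f8(i8,i8)", error_model="numpy", fastmath=True, parallel=True)
    def sum2(M, N):
        total_sum = 0.0
        for i in prange(1, M):
            for j in range(1, N):
                total_sum += 1 / (i**2 + j**2)**2
        return total_sum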
    
    
    start = timer()
    res1 = sum(10000,10000)
    end = timer()
    print(f"Run time numpy is {(end - start) * 1000} ms")
    
    start = timer()
    res2 = sum2(10000,10000)
    end = timer()
    print(f"Run time numba is {(end - start) * 1000} ms")
    assert np.isclose(res1, res2)
    
    Run time numpy is 970.1803999999998 ms
    Run time numba is 44.137399999999886 ms
    

    As you can see, numba can achieve the order-of-magnitude speedup: roughly 22 times faster than the numpy version in this run.

    For anyone about to comment that a threaded sum is not safe: numba converts this sum into a reduction, so it is safe.
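
To check the first point (np.power versus **2 versus np.square) on your own machine, the sketch below times the three spellings with NumPy only. The helper names and the grid size of 2000 are assumptions for illustration, and the arrays are cast to int64 so the result does not depend on the platform's default integer width. Timings vary with hardware and NumPy version, so no expected numbers are shown.

```python
import numpy as np
from timeit import timeit

# Hypothetical helper names; only the squaring step differs between them.
def with_power(m, n):
    return (1 / np.power(m**2 + n**2, 2)).sum()

def with_operator(m, n):
    return (1 / (m**2 + n**2) ** 2).sum()

def with_square(m, n):
    return (1 / np.square(m**2 + n**2, dtype=np.int64)).sum()

# Smaller grid than in the answer so the comparison finishes quickly;
# int64 avoids overflow on platforms whose default integer is 32-bit.
m, n = (a.astype(np.int64) for a in np.ogrid[1:2000, 1:2000])

results = {}
for func in (with_power, with_operator, with_square):
    results[func.__name__] = func(m, n)
    elapsed = timeit(lambda: func(m, n), number=5) / 5
    print(f"{func.__name__}: {elapsed * 1000:.1f} ms per call")

# All three variants should agree up to floating-point rounding.
values = list(results.values())
print(np.allclose(values, values[0]))
```

If the three timings come out close on your setup, the gain in the answer is coming almost entirely from numba rather than from the choice of squaring function.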