Tags: python, numpy, vectorization, cumulative-sum

How to vectorise a triple-for-loop cumulative sum


I want to vectorise the triple sum

$$\sum_{i=1}^{I}\sum_{j=1}^{J}\sum_{m=1}^{J} a_{ijm}$$

such that I end up with a matrix

$$A \in \mathbb{R}^{I \times J}$$

where $A_{kl} = \sum_{i=1}^{k}\sum_{j=1}^{l}\sum_{m=1}^{l} a_{ijm}$ for $k = 1, \dots, I$ and $l = 1, \dots, J$, carrying forward the sums to avoid pointless recomputation.

I currently use this code:

    np.cumsum(np.cumsum(np.cumsum(a, axis=0), axis=1), axis=2).diagonal(axis1=1, axis2=2)

but it is inefficient: it computes the full three-dimensional cumulative sum, doing a lot of unnecessary work, only to extract the required matrix at the end with the diagonal method. I can't think of how to make this faster.
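
For reference, the definition can be sanity-checked against a direct, unvectorised implementation (a minimal sketch; the helper name naive, the seed and the test shapes are illustrative):

    import numpy as np

    def naive(a):
        # Evaluate A_{kl} straight from the definition (0-indexed slices)
        I, J, _ = a.shape
        A = np.empty((I, J))
        for k in range(I):
            for l in range(J):
                A[k, l] = a[:k+1, :l+1, :l+1].sum()
        return A

    rng = np.random.default_rng(0)
    a = rng.random((5, 6, 6))
    one_liner = np.cumsum(np.cumsum(np.cumsum(a, axis=0), axis=1),
                          axis=2).diagonal(axis1=1, axis2=2)
    assert np.allclose(naive(a), one_liner)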


Solution

  • You can use Numba to produce a very fast implementation. Here is the code:

    import numba as nb
    import numpy as np
    
    @nb.njit('(float64[:,:,::1],)', parallel=True)
    def compute(arr):
        ni, nj, nk = arr.shape
        assert nj == nk
        result = np.empty((ni, nj))
        # Parallel cumulative sum along axes 1 and 2, fused with the
        # extraction of the diagonal
        for i in nb.prange(ni):
            # tmp[k] holds the running sum of arr[i, :j+1, k]
            tmp = np.zeros(nk)
            for j in range(nj):
                for k in range(nk):
                    tmp[k] += arr[i, j, k]
                # Only the first j+1 columns contribute to the diagonal entry
                result[i, j] = np.sum(tmp[:j+1])
        # Cumulative sum along axis 0
        for i in range(1, ni):
            for j in range(nj):
                result[i, j] += result[i-1, j]
        return result
    
    result = compute(a)
    

    Here are performance results on my 6-core i5-9600KF with a 100x100x100 float64 input array:

    Initial code:      12.7 ms
    Chryophylaxs v1:    7.1 ms
    Chryophylaxs v2:    5.5 ms
    Numba:              0.2 ms
    
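    The timings can be reproduced with a simple harness along the following lines (a sketch, not necessarily the benchmark actually used; exact numbers vary by machine, and it reuses compute from above, which is compiled eagerly thanks to the explicit signature, so no warm-up call is strictly required):

    import timeit
    import numpy as np

    a = np.random.rand(100, 100, 100)

    def initial(a):
        return np.cumsum(np.cumsum(np.cumsum(a, axis=0), axis=1),
                         axis=2).diagonal(axis1=1, axis2=2)

    for name, fn in [('Initial code', initial), ('Numba', compute)]:
        t = min(timeit.repeat(lambda: fn(a), number=10, repeat=5)) / 10
        print(f'{name}: {t * 1e3:.2f} ms')
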

    This implementation is significantly faster than all the others: about 64 times faster than the initial implementation. It is also essentially optimal on my machine, since it completely saturates my RAM bandwidth just reading the input array (which is unavoidable). Note that it is better not to use multiple threads for very small arrays (see the sketch below).
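
    One way to act on that last remark (a sketch; the 16 KiB threshold and the names compute_seq/compute_auto are illustrative, not part of the original answer) is to compile a sequential twin of the kernel and dispatch on the input size:

    # Sequential twin: same body, threading disabled (prange degrades to range)
    compute_seq = nb.njit('(float64[:,:,::1],)')(compute.py_func)

    def compute_auto(arr):
        # Illustrative threshold: below a few KiB of input, the threading
        # overhead tends to dominate; tune per machine
        if arr.nbytes < 16 * 1024:
            return compute_seq(arr)
        return compute(arr)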

    Note that this code also uses far less memory: it only needs 8 * nk * num_threads bytes of temporary storage (4800 bytes here, with nk = 100 and 6 threads), as opposed to 16 * ni * nj * nk bytes for the initial solution (16 MB for the 100x100x100 input).