python numpy jax

Is there a more efficient equivalent of np.sum(np.cumprod(1 / (1 + y*x)))?


I have a 1D NumPy array, such as x:

import numpy as np
x = np.array([0.05, 0.06, 0.06, 0.04])

I'm showing a small array, but in reality x can be very large. On x, I want to perform the following calculation:

y = 1./12.
np.sum(np.cumprod(1 / (1 + y * x)))

Because x is very large, I want to do this more efficiently. I tried np.exp(np.cumsum(np.log(1 / (1 + y * x)))).sum(), but it is slower. Is there a more efficient NumPy/JAX function?
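For reference, the log-space rewrite from the question relies on log(1/(1+yx)) = -log1p(yx), so the cumulative product is the exponential of a negated cumulative sum. A minimal sketch comparing the two forms (the use of `np.log1p` is an assumption for accuracy, not part of the original attempt). The log-space version still makes several passes over the data and adds a `log` and an `exp` per element, which is a plausible reason it was slower:

```python
import numpy as np

x = np.array([0.05, 0.06, 0.06, 0.04])
y = 1.0 / 12.0

# Direct form from the question: running product of the factors 1/(1+y*x), then sum.
direct = np.sum(np.cumprod(1 / (1 + y * x)))

# Log-space form: prod_j 1/(1+y*x_j) == exp(-sum_j log1p(y*x_j)).
# Multiplications become additions, but each element now also needs a log and an exp.
log_space = np.exp(-np.cumsum(np.log1p(y * x))).sum()

print(direct, log_space)  # equal up to floating-point rounding
```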


Solution

  • You pretty much vectorized as much as possible.

    What you could save here, if you were writing in C, is the numerous for loops involved (they are implicit: vectorization means they run inside numpy's C code). Your code means that you:

    • Firstly, iterate over all elements of x to multiply them by y. In pure Python that would be the list comprehension [y*a for a in x].
    • Secondly, iterate again over all elements of the former result to add 1 to each of them. So [1+z for z in [y*a for a in x]].
    • Thirdly, iterate again over all elements of the former result to invert each of them. So [1/u for u in [1+z for z in [y*a for a in x]]].
    • Fourthly, iterate again to compute the cumulative product. So p=1; [p:=v*p for v in [1/u for u in [1+z for z in [y*a for a in x]]]].
    • Fifthly, iterate again to compute the sum of the former result.
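Written out in NumPy, each of those steps is a separate call that walks the whole array and, except for the final sum, allocates a new temporary array. A minimal sketch that names the intermediates (`five_passes` and `t1`..`t4` are illustrative names, not from the original):

```python
import numpy as np

def five_passes(x, y):
    """Same result as np.sum(np.cumprod(1 / (1 + y * x))), one line per pass."""
    t1 = y * x            # pass 1: multiply every element (new array)
    t2 = 1 + t1           # pass 2: add 1 to every element (new array)
    t3 = 1 / t2           # pass 3: invert every element (new array)
    t4 = np.cumprod(t3)   # pass 4: running product (new array)
    return t4.sum()       # pass 5: reduce to a single number

x = np.array([0.05, 0.06, 0.06, 0.04])
y = 1.0 / 12.0
print(five_passes(x, y))
```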

    So, sure, all those for loops are in C, hence very fast. But there are many (non-nested) of them, and each of them does very little. So the time spent on the iteration itself (the for(int i=0; i<arr_len; i++) that occurs somewhere in numpy's C code) is not negligible compared to the body of that iteration (the result[i] = y*x[i] that this loop repeats).

    If you were writing this in pure Python:

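    def cumsumprod(x,y):
        z=[y*a for a in x]           # pass 1: multiply
        u=[1+a for a in z]           # pass 2: add 1
        v=[1/a for a in u]           # pass 3: invert
        p=1; w=[p:=p*a for a in v]   # pass 4: cumulative product (":=" needs Python 3.8+)
        s=0
        for a in w: s+=a             # pass 5: sum
        return s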
    

    That would be much less efficient than this other pure-Python implementation:

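    def ff(x,y):
        s=0                # running sum of the cumulative products
        p=1                # current cumulative product
        for a in x:
            p/=(1+y*a)     # fold the next factor into the product
            s+=p           # and add it to the sum right away
        return s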
    

    Same computation, but one for loop instead of five.
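In terms of the math, a single loop suffices because each term of the cumulative product is the previous term divided by one more factor. With the running product $p_k$ starting from $p_0 = 1$:

```latex
S = \sum_{k=1}^{n} \prod_{j=1}^{k} \frac{1}{1 + y\,x_j},
\qquad
p_k = \frac{p_{k-1}}{1 + y\,x_k}, \quad p_0 = 1,
\qquad
S = \sum_{k=1}^{n} p_k .
```

So one running product `p` and one accumulator `s` are all the state needed, which is exactly what `ff` keeps.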

    To be quantitative: on your example, your code takes 9.8 μs on my machine, my first Python code 3.6 μs, and my second 1.9 μs.

    And yes, with such small data, both pure Python versions are faster than numpy. If the array has size 1000 instead, those timings become 17.7, 464 and 288 μs. But the point is that my second code is, unsurprisingly, faster than my first. And your numpy code is the equivalent of my first code, but in C.
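Absolute timings depend on the machine, the Python/NumPy versions and the array size, so they are worth re-measuring locally. A minimal sketch of such a comparison with the standard `timeit` module; the `bench` helper, `numpy_version` and the random test data are hypothetical, not the author's setup:

```python
import timeit
import numpy as np

def numpy_version(x, y):
    return np.sum(np.cumprod(1 / (1 + y * x)))

def ff(x, y):
    # Single pass: running product p and running sum s.
    s = 0.0
    p = 1.0
    for a in x:
        p /= 1 + y * a
        s += p
    return s

def bench(n, repeat=5, number=100):
    """Best time per call, in microseconds, for a random array of length n (hypothetical helper)."""
    rng = np.random.default_rng(0)
    x = rng.uniform(0.0, 0.1, n)
    y = 1.0 / 12.0
    results = {}
    for name, func in [("numpy", numpy_version), ("pure python", ff)]:
        # timeit.repeat returns total seconds for `number` calls, once per repetition
        best = min(timeit.repeat(lambda: func(x, y), repeat=repeat, number=number))
        results[name] = best / number * 1e6
    return results

for n in (4, 1000):
    print(n, bench(n))
```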

    And that is even an understatement, since I only used the example of for loops. They are not the only redundant thing numpy does. For example, it also allocates a new array for each intermediate operation.
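Within NumPy itself, the temporaries can be avoided by passing `out=` so that every step writes into one preallocated buffer. This is only a sketch (`cumsumprod_inplace` is a hypothetical name): it removes the extra allocations but still makes five passes over the data, so any speedup is likely modest and should be measured:

```python
import numpy as np

def cumsumprod_inplace(x, y, buf=None):
    """np.sum(np.cumprod(1 / (1 + y * x))) computed in a single reusable work buffer."""
    if buf is None:
        buf = np.empty_like(x, dtype=np.float64)
    np.multiply(x, y, out=buf)     # buf = y * x
    np.add(buf, 1.0, out=buf)      # buf = 1 + y * x
    np.reciprocal(buf, out=buf)    # buf = 1 / (1 + y * x)
    np.cumprod(buf, out=buf)       # running product, written back in place
    return buf.sum()

x = np.array([0.05, 0.06, 0.06, 0.04])
y = 1.0 / 12.0
work = np.empty_like(x)            # allocate once, reuse across calls
print(cumsumprod_inplace(x, y, work))
```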

    Not that you did anything wrong. You did exactly what you are supposed to do with numpy. That is just how numpy works: it provides many vectorized operations that we can chain, each of them being a for loop over all our data. We pay the price of several unnecessary for loops in exchange for having them in C, when pure Python would be much slower. That is pretty much the best you can get from numpy.

    If you want more, one way is numba. Numba lets you write otherwise naive code and still have it run fast, because it compiles that code to machine code (via LLVM).

    Just add @jit in front of my previous pure-Python code:

    from numba import jit
    
    @jit(nopython=True)
    def ff(x,y):
        s=0
        p=1
        for a in x:
            p/=(1+y*a)
            s+=p
        return s
    

    And you get something that is both compiled and free of the unnecessary (but, with numpy, unavoidable) extra operations that chaining many vectorized numpy operations involves.

    This function takes 0.25 μs on your example array, so much better than the 1.9 μs of the same code in pure Python, thanks to compilation.

    And for an array of size 1000, where numpy beats pure Python, it takes 3.4 μs. So it is not only much better than pure Python's 276 μs, but also better than numpy's 14.9 μs, thanks to the simplicity of the algorithm.
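One caveat when timing a Numba function: the first call of a jitted function triggers compilation, so it should be kept out of the measurement. A minimal sketch using `numba.njit` (shorthand for `jit(nopython=True)`), with the accumulators initialized as floats so their types stay consistent. The fallback decorator is a hypothetical convenience so the snippet still runs, without any speedup, where Numba is not installed:

```python
import time
import numpy as np

try:
    from numba import njit  # njit is shorthand for jit(nopython=True)
except ImportError:
    # Hypothetical fallback: run as plain Python if Numba is unavailable (no speedup).
    def njit(func):
        return func

@njit
def ff(x, y):
    s = 0.0
    p = 1.0
    for a in x:
        p /= 1.0 + y * a   # next cumulative-product term
        s += p             # add it to the running sum
    return s

x = np.random.default_rng(0).uniform(0.0, 0.1, 1000)
y = 1.0 / 12.0

t0 = time.perf_counter()
first = ff(x, y)    # with Numba: includes JIT compilation time
t1 = time.perf_counter()
second = ff(x, y)   # later calls reuse the compiled machine code
t2 = time.perf_counter()

print(f"first call: {(t1 - t0) * 1e6:.1f} us, second call: {(t2 - t1) * 1e6:.1f} us")
```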

    Long story short: numba lets you write plain, simple, naive algorithms over numpy arrays and have them compiled.