Mask a numpy array after a given value


I have two NumPy arrays like these:

a = np.array([False, False, False, False, False, True, False, False])

b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

I need to sum over b, but not over the full array: only over the elements that come before the first index at which a is True.

In other words, I want to do 1+2+3+4+5=15 instead of 1+2+3+4+5+6+7+8=36

I need an efficient solution. I think I need to mask all elements of b from the first True in a onward and set them to 0.

Side note: my code uses jax.numpy rather than plain NumPy, but I assume that doesn't really matter.
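The "set them to 0" idea can be sketched in NumPy as follows, assuming a and b are NumPy arrays of the same length (the names from_first_true and masked_b are only illustrative). The mask covers the first True and everything after it, and np.where replaces those entries of b with 0 instead of removing them, so the result keeps the shape of b.

```python
import numpy as np

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

# np.cumsum(a) counts the True values seen so far, so it is > 0 at the
# first True and at every position after it.
from_first_true = np.cumsum(a) > 0

# Zero out those elements instead of dropping them; shape stays (8,).
masked_b = np.where(from_first_true, 0, b)

total = masked_b.sum()
print(masked_b)  # [1 2 3 4 5 0 0 0]
print(total)     # 15
```

The JAX side does matter here: inside jax.jit, boolean indexing such as b[mask] is rejected because the result's shape would depend on the values in a. Since np.where keeps a fixed output shape, the same expression written with jnp in place of np should be usable under jit.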


Solution

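  • You can use a cumulative sum. np.cumsum(a) is 0 up to the first True and at least 1 from there on, so np.cumsum(a)==0 selects exactly the elements of b before the first True: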

    np.sum(b[np.cumsum(a)==0])
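To see why the cumulative-sum mask works, here is a small trace of the intermediate values plus two edge cases. The helper sum_before_first_true is a hypothetical wrapper around the one-liner above; it converts its inputs with np.asarray so plain lists are accepted too.

```python
import numpy as np

def sum_before_first_true(a, b):
    """Hypothetical helper: sum the elements of b before the first True in a.

    If a contains no True value, every element of b is summed.
    """
    a = np.asarray(a, dtype=bool)
    b = np.asarray(b)
    before_first_true = np.cumsum(a) == 0  # True only before the first True
    return np.sum(b[before_first_true])

a = np.array([False, False, False, False, False, True, False, False])
b = np.array([1, 2, 3, 4, 5, 6, 7, 8])

print(np.cumsum(a))                 # [0 0 0 0 0 1 1 1]
print(sum_before_first_true(a, b))  # 15

# Edge cases: no True at all, and True in the very first position.
print(sum_before_first_true([False] * 8, b))           # 36
print(sum_before_first_true([True] + [False] * 7, b))  # 0
```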