
Simplest equivalent implementation of numpy.ma.notmasked_edges for use in JAX

I have a square numpy.ndarray and a numpy boolean mask of the same shape. I want to find the first element in each row of the array that is not masked.

My code currently relies on numpy.ma.notmasked_edges, which does exactly what I need. However, I now need to migrate my code to JAX, which has not implemented numpy.ma within jax.numpy.

What would be the simplest way to find the index of the first unmasked element in each row, calling only NumPy functions that have been implemented in JAX (which excludes numpy.ma)?

The code I'm trying to reproduce is something like:

import numpy as np
my_array = np.random.rand(5,5)
mask = (my_array < 0.5)
my_masked_array = np.ma.masked_array(my_array, mask=mask)
first_unmasked_idx = np.ma.notmasked_edges(my_masked_array, axis=1)[0]

I'm sure there are many ways to do this, but I'm looking for the least unwieldy way.


  • Here's a JAX implementation of notmasked_edges, which takes a boolean mask and returns the same indices returned by np.ma.notmasked_edges:

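    import jax.numpy as jnp

    def notmasked_edges(mask, axis=None):
      # mask uses the numpy.ma convention: True marks a masked (invalid) entry.
      mask = jnp.asarray(mask)
      assert mask.dtype == bool
      if axis is None:
        mask = mask.ravel()
        axis = 0
      # Shape of the remaining (non-reduced) axes.
      shape = list(mask.shape)
      del shape[axis]
      # Lines along `axis` with no unmasked entry; numpy.ma leaves these out.
      alltrue = mask.all(axis=axis).ravel()
      # Index grids for the remaining axes, flattened and filtered.
      indices = jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing='ij')
      indices = [jnp.ravel(ind)[~alltrue] for ind in indices]
      # argmin over booleans returns the position of the first False, i.e. the
      # first unmasked entry; on the flipped mask it locates the last one.
      first = indices.copy()
      first.insert(axis, jnp.argmin(mask, axis=axis).ravel()[~alltrue])
      last = indices.copy()
      last.insert(axis, mask.shape[axis] - 1 - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel()[~alltrue])
      return [tuple(first), tuple(last)]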

    This will not be compatible with JIT, because the size of the output arrays depends on the values of the mask (rows that have no unmasked value are left out).
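If you want to keep numpy.ma's behaviour of dropping fully masked rows and still use jit, one option is to pad the result to a static length using the size and fill_value arguments of jnp.nonzero. Below is a minimal sketch for the 2-D, axis=1 case only; the function name first_unmasked_padded and the -1 padding value are my own choices for illustration, not part of NumPy or JAX.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def first_unmasked_padded(mask):
    """First unmasked column of each row that has one, padded to a static length.

    mask: 2-D bool array, True = masked (numpy.ma convention).
    Returns (rows, cols), each of length mask.shape[0]; unused slots hold -1
    (the -1 padding is an arbitrary choice for this sketch).
    """
    n_rows = mask.shape[0]  # a plain Python int at trace time
    keep = ~mask.all(axis=1)  # rows with at least one unmasked entry
    # Static-size stand-in for boolean indexing: pad the row list with -1.
    (rows,) = jnp.nonzero(keep, size=n_rows, fill_value=-1)
    # argmin over a bool row gives the position of its first False (unmasked) entry.
    cols = jnp.argmin(mask, axis=1)
    # Look up the column for each kept row; padded slots become -1.
    cols = jnp.where(rows >= 0, cols[rows], -1)
    return rows, cols


mask = np.array([[True, False, False, True],
                 [False, False, True, True],
                 [True, True, True, True]])
rows, cols = first_unmasked_padded(mask)
print(rows, cols)  # rows: 0, 1, -1 and cols: 1, 0, -1
```

The trade-off is that callers must ignore the trailing -1 slots themselves, since the number of valid rows is only known at run time.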

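    If you want a JIT-compatible version, you can remove the [~alltrue] indexing. Rows that have no unmasked value are then kept, with 0 returned as their first index and mask.shape[axis] - 1 as their last index. Because axis drives Python-level logic inside the function, it has to be static under jit, e.g. jax.jit(notmasked_edges_v2, static_argnames='axis'):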

    def notmasked_edges_v2(mask, axis=None):
      mask = jnp.asarray(mask)
      assert mask.dtype == bool
      if axis is None:
        mask = mask.ravel()
        axis = 0
      shape = list(mask.shape)
      del shape[axis]
      indices = jnp.meshgrid(*(jnp.arange(n) for n in shape), indexing='ij')
      indices = [jnp.ravel(ind) for ind in indices]
      first = indices.copy()
      first.insert(axis, jnp.argmin(mask, axis=axis).ravel())
      last = indices.copy()
      last.insert(axis, mask.shape[axis] - 1 - jnp.argmin(jnp.flip(mask, axis=axis), axis=axis).ravel())
      return [tuple(first), tuple(last)]
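Since notmasked_edges_v2 reports 0 and the last index for fully masked rows, those rows look the same as rows whose genuine edges sit at those positions. If you need to tell them apart, one sketch for the 2-D, axis=1 case is to overwrite fully masked rows with a sentinel. The name row_edges_with_sentinel and the -1 sentinel are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def row_edges_with_sentinel(mask):
    """First and last unmasked column per row of a 2-D mask, or -1 if the row is fully masked.

    mask: bool array, True = masked (numpy.ma convention).
    """
    n_cols = mask.shape[1]  # a plain Python int at trace time
    has_unmasked = ~mask.all(axis=1)
    # argmin finds the first False (unmasked) entry; running it on the
    # flipped rows finds the last one, counted from the right-hand end.
    first = jnp.argmin(mask, axis=1)
    last = n_cols - 1 - jnp.argmin(jnp.flip(mask, axis=1), axis=1)
    # Replace the 0 / n_cols - 1 placeholders of fully masked rows with -1.
    first = jnp.where(has_unmasked, first, -1)
    last = jnp.where(has_unmasked, last, -1)
    return first, last


mask = np.array([[True, False, False, True],
                 [False, False, True, True],
                 [True, True, True, True]])
first, last = row_edges_with_sentinel(mask)
print(first, last)  # first: 1, 0, -1 and last: 2, 1, -1
```

Every output has a fixed shape (one entry per row), so this works under jit without any static arguments.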

    Here's an example:

    import numpy as np
    mask = np.array([[True, False, False, True],
                     [False, False, True, True],
                     [True, True, True, True]])
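    arr = np.ma.masked_array(np.zeros(mask.shape), mask=mask)  # placeholder data; only the mask matters here
    print(np.ma.notmasked_edges(arr, axis=1))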
    # [(array([0, 1]), array([1, 0])), (array([0, 1]), array([2, 1]))]
    print(notmasked_edges(mask, axis=1))
    # [(Array([0, 1], dtype=int32), Array([1, 0], dtype=int32)),
    #  (Array([0, 1], dtype=int32), Array([2, 1], dtype=int32))]
    print(notmasked_edges_v2(mask, axis=1))
    # [(Array([0, 1, 2], dtype=int32), Array([1, 0, 0], dtype=int32)),
    #  (Array([0, 1, 2], dtype=int32), Array([2, 1, 3], dtype=int32))]
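For the narrower case in the question, the first unmasked value in each row, you may not need the index tuples at all. A minimal sketch, assuming a 2-D array, axis=1, and NaN as the fill for fully masked rows; the function name first_unmasked_values is hypothetical, and the small fixed array stands in for the question's random one.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def first_unmasked_values(x, mask):
    """First unmasked value in each row of x, or NaN if the whole row is masked.

    x: 2-D float array; mask: bool array of the same shape, True = masked.
    """
    # Column of the first False (unmasked) entry in each row; 0 for fully masked rows.
    first = jnp.argmin(mask, axis=1)
    # Gather x[i, first[i]] for every row i.
    vals = jnp.take_along_axis(x, first[:, None], axis=1)[:, 0]
    # Fully masked rows have no valid value, so report NaN for them.
    return jnp.where(mask.all(axis=1), jnp.nan, vals)


x = np.array([[0.75, 0.25, 0.5],
              [0.25, 0.125, 0.625],
              [0.0, 0.25, 0.375]])
mask = x < 0.5  # same masking rule as in the question
print(first_unmasked_values(x, mask))  # expected values: 0.75, 0.625, nan
```

Choosing NaN keeps the output shape fixed, so the function stays jit-compatible; use jnp.argmin(mask, axis=1) alone if you only need the indices.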