
multidimensional jax.isin()


I am trying to filter an array of triples. I want to keep a triple only if another array of triples contains at least one triple with the same first and third elements. For example:

import jax.numpy as jnp
array1 = jnp.array(
  [
    [0,1,2],
    [1,0,2],
    [0,3,3],
    [3,0,1],
    [0,1,1],
    [1,0,3],
  ]
)
array2 = jnp.array([[0,1,3],[0,3,2]])
# The mask for filtering array1 should look like this:
jnp.array([True,False,True,False,False,False])

What would be a computationally efficient way to compute this mask using JAX? I am looking forward to your input.


Solution

  • You can do this by reducing over a broadcasted equality check. The example below compares entire rows:

    import jax.numpy as jnp
    array1 = jnp.array(
      [
        [0,1,2],
        [1,0,2],
        [0,3,3],
        [3,0,1],
        [0,1,1],
        [1,0,3],
      ]
    )
    array2 = jnp.array([[0,1,2],[0,3,2]])  # note adjustment to match first entry of array1
    
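    # Compare every row of array1 with every row of array2: shape (6, 2, 3).
    # all(-1) requires all three entries to match; any(-1) checks whether
    # at least one row of array2 matched.
    mask = (array1[:, None] == array2[None, :]).all(-1).any(-1)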
    print(mask)
    # [ True False False False False False]
    

    XLA doesn't have a binary-search-like primitive, so the best approach in general is to generate the full equality matrix and reduce over it. On an accelerator such as a GPU or TPU, this kind of vectorized operation parallelizes well, so it runs quickly in practice.
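
The question asks for a match on the first and third elements only, whereas the snippet above compares whole rows. A minimal sketch of the same broadcast-and-reduce pattern restricted to columns 0 and 2, using array2 as originally given in the question, might look like this (the names `cols`, `a1` and `a2` are introduced here for illustration only):

```python
import jax.numpy as jnp

array1 = jnp.array(
    [
        [0, 1, 2],
        [1, 0, 2],
        [0, 3, 3],
        [3, 0, 1],
        [0, 1, 1],
        [1, 0, 3],
    ]
)
array2 = jnp.array([[0, 1, 3], [0, 3, 2]])  # array2 as given in the question

# Keep only the first and third columns (cols is an illustrative name).
cols = jnp.array([0, 2])
a1 = array1[:, cols]  # shape (6, 2)
a2 = array2[:, cols]  # shape (2, 2)

# Broadcast to shape (6, 2, 2), require both kept columns to match,
# then check whether at least one row of array2 matched.
mask = (a1[:, None] == a2[None, :]).all(-1).any(-1)
print(mask)
# Values match the mask requested in the question:
# True, False, True, False, False, False
```

The intermediate comparison has shape (len(array1), len(array2), 2), so memory use grows with the product of the two array lengths.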