
jax.numpy.delete assume_unique_indices unexpected keyword argument


I cannot seem to get the assume_unique_indices argument of jax.numpy.delete working. According to the documentation, jnp.delete has a keyword argument, assume_unique_indices, that is supposed to make the function JIT-compatible when we are sure that the index array is an integer array and is guaranteed to contain unique entries.
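As far as I understand, the uniqueness guarantee is what makes JIT compatibility possible. If the indices are unique and in range, the output length (len(arr) - len(idx)) is known at trace time, and jit needs static output shapes. Below is a minimal sketch of that idea for 1-D arrays. It is not JAX's actual implementation. delete_unique is a hypothetical helper name, and the sketch assumes idx holds unique, non-negative, in-range integers. It only uses .at[...].set and jnp.nonzero with a static size, so it should not depend on the assume_unique_indices argument being available.

```python
import jax
import jax.numpy as jnp


@jax.jit
def delete_unique(arr, idx):
    # Hypothetical helper: assumes a 1-D arr and unique, non-negative,
    # in-range integer indices in idx (nothing here checks this).
    # Mark every position as kept, then clear the positions to delete.
    keep = jnp.ones(arr.shape[0], dtype=bool).at[idx].set(False)
    # Shapes are static under jit, so the number of surviving elements is a
    # plain Python int at trace time. This is only correct if idx is unique.
    n_keep = arr.shape[0] - idx.shape[0]
    (kept_positions,) = jnp.nonzero(keep, size=n_keep)
    return arr[kept_positions]


print(delete_unique(jnp.array([1, 2, 3, 4, 5]), jnp.array([0, 2, 4])))  # [2 4]
```

If idx contained duplicates, fewer positions would be cleared than n_keep assumes, and jnp.nonzero would pad the result with its fill value. That is why the real argument puts the burden of guaranteeing uniqueness on the caller.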

Here is a minimal reproducible example:

import jax
import jax.numpy as jnp

arr = jnp.array([1, 2, 3, 4, 5])
idx = jnp.array([0, 2, 4])

print(jax.__version__)

# Delete elements at indices idx
out = jax.numpy.delete(arr, idx, assume_unique_indices=True)

print(out) # [2 4]

The error message:

0.4.8
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-12-bf0277118922> in <cell line: 9>()
      7 
      8 # Delete elements at indices idx
----> 9 out = jax.numpy.delete(arr, idx, assume_unique_indices=False)
     10 
     11 print(out) # [2 4]

TypeError: delete() got an unexpected keyword argument 'assume_unique_indices'

Removing the assume_unique_indices argument makes the call work as expected.


Solution

  • assume_unique_indices was added in https://github.com/google/jax/pull/15671, after JAX version 0.4.8 was released. If you update to version 0.4.9 or newer, your code should work.
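After upgrading, the argument can be used inside a jitted function, which is the situation it is meant for. Below is a minimal sketch, assuming JAX 0.4.9 or newer. drop_indices is just an illustrative name. My understanding is that without the flag, jnp.delete needs the index array to be a concrete (non-traced) value, so passing a traced idx inside jit would fail.

```python
import jax
import jax.numpy as jnp


@jax.jit
def drop_indices(arr, idx):
    # Requires JAX >= 0.4.9. idx is a traced value here; the flag promises
    # jnp.delete that idx is an integer array with unique entries.
    return jnp.delete(arr, idx, assume_unique_indices=True)


arr = jnp.array([1, 2, 3, 4, 5])
idx = jnp.array([0, 2, 4])

print(jax.__version__)
print(drop_indices(arr, idx))  # [2 4]
```

The flag is a promise, not a check. If idx does contain duplicates, the result is not guaranteed to match numpy.delete, so only set it when uniqueness is known.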