
JAX: JIT compatible sparse matrix slicing


I have a boolean sparse matrix that I represent by the row indices and column indices of its True values.

import numpy as np
import jax
from jax import numpy as jnp
N = 10000
M = 1000
# Data setup: random boolean matrix with roughly 1% True entries
X = np.random.randint(0, 100, size=(N, M)) == 0
# Row and column indices of the True entries
rows, cols = np.nonzero(X)
rows = jax.device_put(rows)
cols = jax.device_put(cols)

I want to get a column slice of the matrix, like X[:, 3], using only the row and column indices.

I managed to do that with jnp.isin as below, but this is not JIT compatible because rows[cols == m] has a data-dependent shape.

def not_jit_compatible_slice(rows, cols, m):
  return jnp.isin(jnp.arange(N), rows[cols == m])
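A quick way to see the failure is to wrap the function in jax.jit so that m becomes a traced argument. The sketch below uses a small random matrix (the sizes and seed are arbitrary choices for illustration) and only checks that tracing raises. In recent JAX versions the error should be jax.errors.NonConcreteBooleanIndexError, but the code catches a generic Exception rather than rely on the exact class.

```python
import numpy as np
import jax
from jax import numpy as jnp

N, M = 50, 8  # small, arbitrary sizes for illustration
rng = np.random.default_rng(0)
X = rng.integers(0, 5, size=(N, M)) == 0
rows, cols = map(jnp.asarray, np.nonzero(X))

def not_jit_compatible_slice(rows, cols, m):
    return jnp.isin(jnp.arange(N), rows[cols == m])

# Eager call works: the mask cols == m is concrete, so its result shape is known.
eager = not_jit_compatible_slice(rows, cols, 3)

# Under jit, m is traced, so rows[cols == m] has no static shape and tracing fails.
try:
    jax.jit(not_jit_compatible_slice)(rows, cols, 3)
    jit_failed = False
except Exception as err:  # typically jax.errors.NonConcreteBooleanIndexError
    jit_failed = True
    error_name = type(err).__name__
```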

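I could make it JIT compatible by using the three-argument form of jnp.where, which replaces entries outside column m with -1 instead of removing them, so the array keeps the static length len(rows). But this is much slower than the previous version: the padded array has one entry per nonzero (roughly N*M/100 = 100,000 here) instead of only the roughly 100 in column m, so isin has far more values to compare against.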

def jit_compatible_but_slow_slice(rows, cols, m):
  return jnp.isin(jnp.arange(N), jnp.where(cols == m, rows, -1))
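As a sanity check, the sentinel version does trace under jit with m as a traced argument, and it matches the dense column slice X[:, m] for every column. The sketch below uses small, arbitrary sizes; -1 works as the sentinel here because it never equals any value in jnp.arange(N).

```python
import numpy as np
import jax
from jax import numpy as jnp

N, M = 50, 8  # small, arbitrary sizes for illustration
rng = np.random.default_rng(0)
X = rng.integers(0, 5, size=(N, M)) == 0
rows, cols = map(jnp.asarray, np.nonzero(X))

@jax.jit
def jit_compatible_but_slow_slice(rows, cols, m):
    # Entries outside column m become -1, which matches nothing in arange(N).
    # The array keeps the static length len(rows), so jit can trace it.
    return jnp.isin(jnp.arange(N), jnp.where(cols == m, rows, -1))

# Compare against the dense column slice for every column.
matches = [
    bool(jnp.array_equal(jit_compatible_but_slow_slice(rows, cols, m), X[:, m]))
    for m in range(M)
]
```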

Is there a fast, JIT-compatible way to achieve the same output?


Solution

  • You can do a bit better than the first answer by using the mode argument of .at[].set() to drop out-of-bounds indices, which eliminates the final slice:

    # Nonzeros outside column 3 get index N (out of bounds), which mode='drop' discards
    out = jnp.zeros(N, bool).at[jnp.where(cols==3, rows, N)].set(True, mode='drop')
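Here is a minimal sketch of the same idea wrapped in a jitted function, with m passed as a traced argument instead of the hard-coded 3 and checked against X[:, m] for every column. The name column_slice and the small sizes are illustrative assumptions. The index array keeps the static length len(rows), and each nonzero outside column m is sent to index N, one past the end, where mode='drop' discards it. N is used rather than -1 because, as far as I know, JAX normalizes negative indices NumPy-style before the scatter, so -1 would write into the last row; the last lines demonstrate that pitfall. I have not benchmarked this against the isin versions here.

```python
import numpy as np
import jax
from jax import numpy as jnp

N, M = 50, 8  # small, arbitrary sizes for illustration
rng = np.random.default_rng(0)
X = rng.integers(0, 5, size=(N, M)) == 0
rows, cols = map(jnp.asarray, np.nonzero(X))

@jax.jit
def column_slice(rows, cols, m):  # hypothetical helper name
    # Rows of nonzeros in column m keep their index; all others are sent to N,
    # one past the end of the output, so mode='drop' discards those updates.
    idx = jnp.where(cols == m, rows, N)
    return jnp.zeros(N, dtype=bool).at[idx].set(True, mode='drop')

matches = [bool(jnp.array_equal(column_slice(rows, cols, m), X[:, m])) for m in range(M)]

# Pitfall: with -1 as the sentinel, negative-index normalization maps it to N - 1,
# so the last row is marked True whenever some nonzero lies outside column 3.
wrong = jnp.zeros(N, dtype=bool).at[jnp.where(cols == 3, rows, -1)].set(True, mode='drop')
```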