I came across another problem while learning JAX: I have a sparse BCOO array and an array holding indices, and I need to obtain all values at these indices in the BCOO array. Ideally, the returned array would be a sparse BCOO as well. The usual indexing syntax doesn't seem to work. Is there a standard way to achieve this? For example:
import jax.numpy as jnp
from jax.experimental import sparse
indices = jnp.array([1, 1, 0])
full_array = jnp.array([
    [[0, 0, 0],
     [2, 2, 2],
     [0, 0, 0],
     [0, 0, 0]],
    [[1, 1, 1],
     [0, 0, 0],
     [0, 0, 0],
     [0, 0, 0]],
    [[1, 1, 1],
     [0, 0, 0],
     [0, 0, 0],
     [0, 0, 0]],
])
full_array[jnp.arange(3), indices]
# results in:
# [[2 2 2]
#  [0 0 0]
#  [1 1 1]]
sparse_array = sparse.bcoo_fromdense(full_array)
# Trying the same thing on a sparse array:
sparse_array[jnp.arange(3), indices]
# raises NotImplementedError
[Edit: 2022-11-15] As of jax version 0.3.25, this kind of sparse indexing is directly supported in JAX:
import jax
print(jax.__version__)
# 0.3.25
sparse_array = sparse.bcoo_fromdense(full_array)
result = sparse_array[jnp.arange(3),indices]
print(result.todense())
# [[2 2 2]
# [0 0 0]
# [1 1 1]]
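Since the question asked for the result to stay sparse, a quick check of the output type and a comparison against the dense computation may be useful. This is a sketch assuming a JAX version with this indexing support (0.3.25 or later); the data is the same as in the question, just built more compactly with `.at[].set()`.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Same data as in the question, built more compactly.
indices = jnp.array([1, 1, 0])
full_array = jnp.zeros((3, 4, 3), dtype=jnp.int32)
full_array = full_array.at[0, 1].set(2).at[1, 0].set(1).at[2, 0].set(1)

sparse_array = sparse.bcoo_fromdense(full_array)
result = sparse_array[jnp.arange(3), indices]

print(type(result).__name__)  # BCOO
print(result.shape)           # (3, 3)
# Compare against the equivalent dense indexing operation.
print(bool((result.todense() == full_array[jnp.arange(3), indices]).all()))  # True
```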
Original answer:
Thanks for the question. Unfortunately, general indexing support has not yet been added to jax.experimental.sparse. The types of indexing operations currently supported are limited to static scalars and slices; for example:
print(sparse_array[0].todense())
# [[0 0 0]
# [2 2 2]
# [0 0 0]
# [0 0 0]]
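Static slices behave the same way. As a small sketch (same data as in the question, rebuilt compactly), slicing off the first block along the leading axis also returns a BCOO:

```python
import jax.numpy as jnp
from jax.experimental import sparse

full_array = jnp.zeros((3, 4, 3), dtype=jnp.int32)
full_array = full_array.at[0, 1].set(2).at[1, 0].set(1).at[2, 0].set(1)
sparse_array = sparse.bcoo_fromdense(full_array)

# A static slice along the leading axis; the result is still a BCOO.
tail = sparse_array[1:]
print(tail.shape)  # (2, 4, 3)
print(tail.todense()[:, 0])
# [[1 1 1]
#  [1 1 1]]
```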
Given this, you may be able to build the operation you want using concatenation. For example:
result = sparse.sparsify(jnp.vstack)([
sparse_array[0][1], # only single indices supported currently
sparse_array[1][1],
sparse_array[2][0],
])
print(result.todense())
# [[2 2 2]
# [0 0 0]
# [1 1 1]]
Admittedly, supporting only static indices is not very convenient, but we hope to add more indexing support in the future.
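Another workaround that does not rely on indexing support is to operate directly on the BCOO's underlying `data` and `indices` arrays: keep each stored entry whose row matches the requested row for its block, and drop the gathered axis from its index. The following is a minimal sketch, not a library API. `gather_rows` is a hypothetical helper, and it assumes a 3D matrix created with the `bcoo_fromdense` defaults (no batch or dense dimensions) with no padding entries. Because the number of stored elements never changes, it also works with traced indices under `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse


def gather_rows(mat: sparse.BCOO, rows: jnp.ndarray) -> sparse.BCOO:
    """Hypothetical helper: sparse equivalent of dense(mat)[jnp.arange(n), rows].

    Assumes mat has shape (n, r, m) with n_batch=0 and n_dense=0, so that
    mat.indices has shape (nse, 3), and that it contains no padding entries.
    """
    n, _, m = mat.shape
    i, j, k = mat.indices[:, 0], mat.indices[:, 1], mat.indices[:, 2]
    # An entry survives if its row index j is the row selected for block i.
    keep = j == rows[i]
    # Dropped entries become explicit zeros, so nse (and all shapes) stay static.
    data = jnp.where(keep, mat.data, 0)
    # Remove the gathered axis: (i, j, k) -> (i, k). Zero-valued duplicates
    # are harmless because BCOO sums entries with duplicate indices.
    new_indices = jnp.stack([i, k], axis=1)
    return sparse.BCOO((data, new_indices), shape=(n, m))


indices = jnp.array([1, 1, 0])
full_array = jnp.zeros((3, 4, 3), dtype=jnp.int32)
full_array = full_array.at[0, 1].set(2).at[1, 0].set(1).at[2, 0].set(1)
sparse_array = sparse.bcoo_fromdense(full_array)

# The row indices can be traced values, so the helper can be jit-compiled.
result = jax.jit(gather_rows)(sparse_array, indices)
print(result.todense())
# [[2 2 2]
#  [0 0 0]
#  [1 1 1]]
```

The trade-off is that the result keeps every stored element of the input, with the dropped ones as explicit zeros, so its `nse` does not shrink.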