Tags: python, sparse-matrix, jax

Indexing a BCOO in Jax


I came across another problem in my attempts to learn JAX: I have a sparse BCOO array and an array holding indices, and I need to obtain all values of the BCOO array at these indices. Ideally, the returned array would be a sparse BCOO as well. The usual slicing syntax does not seem to work. Is there a standard way to achieve this? For example:

import jax.numpy as jnp
from jax.experimental import sparse

indices = jnp.array([
    1,1,0
])

full_array = jnp.array(
        [
            [
                [0,0,0],
                [2,2,2],
                [0,0,0],
                [0,0,0]
            ],
            [
                [1,1,1],
                [0,0,0],
                [0,0,0],
                [0,0,0]
            ],
            [
                [1,1,1],
                [0,0,0],
                [0,0,0],
                [0,0,0]
            ]
        ]
)
full_array[jnp.arange(3),indices]
# results in:
#    [2,2,2],
#    [0,0,0],
#    [1,1,1] 

sparse_array = sparse.bcoo_fromdense(full_array)

# Trying the same thing on a sparse array:
sparse_array[jnp.arange(3),indices]
# raises a NotImplementedError
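For reference, the dense expression uses integer-array (advanced) indexing: `jnp.arange(3)` and `indices` are broadcast together, so output row k is `full_array[k, indices[k]]`. A minimal sketch of that equivalence, using NumPy as a stand-in for `jax.numpy` (both follow the same integer-array indexing rules in this case):

```python
import numpy as np

full_array = np.array([
    [[0, 0, 0], [2, 2, 2], [0, 0, 0], [0, 0, 0]],
    [[1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
    [[1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
])
indices = np.array([1, 1, 0])

# Advanced indexing: the two index arrays are paired element-wise,
# so output row k is full_array[k, indices[k], :].
fancy = full_array[np.arange(3), indices]

# The same selection written as an explicit loop over the first axis.
looped = np.stack([full_array[k, indices[k]] for k in range(len(indices))])

print(fancy)
# [[2 2 2]
#  [0 0 0]
#  [1 1 1]]
print(np.array_equal(fancy, looped))  # True
```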

Solution

    [Edit: 2022-11-15] As of JAX version 0.3.25, this kind of sparse indexing is supported directly:

    import jax
    print(jax.__version__)
    # 0.3.25
    
    sparse_array = sparse.bcoo_fromdense(full_array)
    result = sparse_array[jnp.arange(3),indices]
    
    print(result.todense())
    # [[2 2 2]
    #  [0 0 0]
    #  [1 1 1]]
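The result can stay sparse because this kind of gather only has to look at stored entries: output row k consists of the stored elements whose first two coordinates equal (k, indices[k]). A BCOO keeps its nonzeros in `data` and `indices` arrays; the NumPy sketch below mimics that with a plain coordinate list. It only illustrates the idea and is not how `jax.experimental.sparse` implements indexing; `dense_to_coo`, `coo_gather_rows` and `coo_to_dense` are hypothetical helpers.

```python
import numpy as np

def dense_to_coo(dense):
    """Hypothetical helper: coordinates (nnz, ndim) and values of the nonzeros."""
    coords = np.argwhere(dense != 0)
    data = dense[tuple(coords.T)]
    return coords, data

def coo_gather_rows(coords, data, shape, batch_idx, row_idx):
    """Hypothetical helper: COO form of dense[batch_idx, row_idx] for a 3-D array.

    Output row k collects only the stored entries whose first two
    coordinates equal (batch_idx[k], row_idx[k]).
    """
    out_coords, out_data = [], []
    for k, (b, r) in enumerate(zip(batch_idx, row_idx)):
        mask = (coords[:, 0] == b) & (coords[:, 1] == r)
        for col, val in zip(coords[mask, 2], data[mask]):
            out_coords.append((k, col))
            out_data.append(val)
    out_shape = (len(batch_idx), shape[2])
    out_coords = np.array(out_coords, dtype=int).reshape(-1, 2)
    out_data = np.array(out_data, dtype=data.dtype)
    return out_coords, out_data, out_shape

def coo_to_dense(coords, data, shape):
    """Hypothetical helper: scatter the stored values back into a dense array."""
    dense = np.zeros(shape, dtype=data.dtype)
    dense[tuple(coords.T)] = data
    return dense

full_array = np.array([
    [[0, 0, 0], [2, 2, 2], [0, 0, 0], [0, 0, 0]],
    [[1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
    [[1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
])
coords, data = dense_to_coo(full_array)
g_coords, g_data, g_shape = coo_gather_rows(
    coords, data, full_array.shape, np.arange(3), np.array([1, 1, 0]))

print(len(data), len(g_data))  # 9 6
print(coo_to_dense(g_coords, g_data, g_shape))
# [[2 2 2]
#  [0 0 0]
#  [1 1 1]]
```

Only 6 of the 9 stored values are copied, and the all-zero middle row needs no storage at all.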
    

    Original answer:

    Thanks for the question. Unfortunately, general indexing support has not yet been added to jax.experimental.sparse. Currently, the only supported indexing operations are static scalars and slices; for example:

    print(sparse_array[0].todense())
    # [[0 0 0]
    #  [2 2 2]
    #  [0 0 0]
    #  [0 0 0]]
    

    Given this, you may be able to build the operation you want using concatenation. For example:

    result = sparse.sparsify(jnp.vstack)([
        sparse_array[0][1],  # only single indices supported currently
        sparse_array[1][1],
        sparse_array[2][0],
    ])
    print(result.todense())
    # [[2 2 2]
    #  [0 0 0]
    #  [1 1 1]]
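If the index array is concrete (known outside of `jax.jit`), the concatenation workaround can be generated programmatically instead of being written out by hand: convert each index to a Python int so that every lookup is a static scalar index. A sketch under that assumption; `gather_rows_static` is a hypothetical helper name, and it relies on the same `sparse.sparsify(jnp.vstack)` pattern used in the answer:

```python
import jax.numpy as jnp
from jax.experimental import sparse

def gather_rows_static(sp, idx):
    """Hypothetical helper: emulate sp[arange(len(idx)), idx] with static indexing only.

    idx must be concrete: each entry is converted to a Python int, so this
    would fail if idx were a tracer inside jax.jit.
    """
    static_idx = [int(j) for j in idx]
    # One static scalar lookup per output row; each row is a 1-D BCOO.
    rows = [sp[k][j] for k, j in enumerate(static_idx)]
    return sparse.sparsify(jnp.vstack)(rows)

full_array = jnp.array([
    [[0, 0, 0], [2, 2, 2], [0, 0, 0], [0, 0, 0]],
    [[1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
    [[1, 1, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0]],
])
indices = jnp.array([1, 1, 0])
sparse_array = sparse.bcoo_fromdense(full_array)

result = gather_rows_static(sparse_array, indices)
print(result.todense())
```

Because the loop runs in Python, it unrolls one slice per index, so it is best suited to short, fixed index arrays.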
    

    Admittedly, supporting only static indices is not very convenient, but we hope to add more indexing support in the future.