
What is the fastest way of selecting a subset of a JAX matrix?


Let's say I have a 2D matrix and I want to plot its values in a histogram. For that, I need to do something like:

    list_1d = matrix_2d.reshape((-1,)).tolist()

And then I use the list to plot the histogram. So far so good; the catch is that there are items in the original matrix that I want to exclude. For simplicity, let's say I have a list like this:

    exclude = [(2, 5), (3, 4), (6, 1)]

So list_1d should contain every element of the matrix except the ones pointed to by exclude (each item of exclude is a (row, column) index pair).
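To make the indexing concrete: reshape((-1,)) flattens in row-major (C) order, so the pair (r, c) ends up at position r * n_cols + c of list_1d. A small NumPy sketch, where the 10x10 shape is an assumption chosen purely for illustration:

```python
import numpy as np

# Hypothetical shape, chosen only for illustration.
n_rows, n_cols = 10, 10
exclude = [(2, 5), (3, 4), (6, 1)]

# Split the (row, col) pairs into a row array and a column array.
rows, cols = np.array(exclude).T

# Row-major (C-order) flattening puts element (r, c) at r * n_cols + c.
flat_positions = np.ravel_multi_index((rows, cols), (n_rows, n_cols))
print(flat_positions.tolist())  # [25, 34, 61]
```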

By the way, matrix_2d is a JAX array, which means its contents are on the GPU.
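Because matrix_2d lives on the device, .tolist() copies every element to host memory and turns each one into a Python float. If the end goal is only the histogram, one option (a sketch, not what the question asks for) is to compute the bin counts with jnp.histogram where the data already is and copy just the counts and edges. The random 10x10 matrix, bins=10 and range=(0.0, 1.0) are illustrative assumptions, and this version does not yet exclude anything:

```python
import jax.numpy as jnp
from jax import random

# Illustrative data: a 10x10 matrix of uniform values in [0, 1).
matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))

# reshape((-1,)) gives the same row-major 1D array as ravel().
flat = matrix_2d.reshape((-1,))

# Bin counts are computed on the array's device; only 10 counts and
# 11 bin edges need to be copied to the host afterwards.
counts, edges = jnp.histogram(flat, bins=10, range=(0.0, 1.0))
print(int(counts.sum()))  # 100
```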


Solution

  • One way to do this is to create a mask array that you use to select the desired subset of the array. The mask indexing operation returns a 1D copy of the selected data:

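    import jax.numpy as jnp
    from jax import random

    matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
    exclude = [(2, 5), (3, 4), (6, 1)]

    # Turn the (row, col) pairs into a (rows, cols) tuple usable as an index.
    ind = tuple(jnp.array(exclude).T)
    # Boolean mask that is True everywhere except at the excluded positions.
    mask = jnp.ones_like(matrix_2d, dtype=bool).at[ind].set(False)

    # Boolean-mask indexing returns a flat 1D copy of the selected elements.
    list_1d = matrix_2d[mask].tolist()
    len(list_1d)
    # 97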
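Boolean-mask indexing such as matrix_2d[mask] produces an array whose length depends on the mask's values, so it works eagerly but cannot be traced inside jax.jit, which needs static shapes. If the selection has to happen inside a jitted function, one hedged alternative is to keep the full fixed-size array and give the excluded entries zero weight in jnp.histogram. masked_histogram is a hypothetical helper name, and bins=10 and range=(0.0, 1.0) are assumptions that suit the uniform test data:

```python
import jax
import jax.numpy as jnp
from jax import random

matrix_2d = random.uniform(random.PRNGKey(0), (10, 10))
exclude = [(2, 5), (3, 4), (6, 1)]

@jax.jit
def masked_histogram(x, rows, cols):
    # Hypothetical helper. Fixed-shape mask: True everywhere except
    # at the excluded (row, col) pairs.
    mask = jnp.ones_like(x, dtype=bool).at[rows, cols].set(False)
    # Excluded entries get weight 0, so they add nothing to any bin.
    weights = mask.ravel().astype(x.dtype)
    return jnp.histogram(x.ravel(), bins=10, range=(0.0, 1.0), weights=weights)

rows, cols = jnp.array(exclude).T
counts, edges = masked_histogram(matrix_2d, rows, cols)
print(int(counts.sum()))  # 97
```

The trade-off is that the excluded values never leave the array, so this only helps when bin counts are enough; when the actual 1D list of values is needed, the eager boolean-mask version above is simpler.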