I have the following problem, for which I cannot manage to write a solution in JAX that is jittable and efficient.
I have a set of elements. Some of these elements are included (on the basis of a condition, which is not important now). The included elements are denoted by 1, the not-included elements by 0. For example, the array arr = jnp.array([1, 0, 0, 0, 0, 0]) indicates that I have 6 elements, the first of them included based on my condition.
These elements are grouped into subsets. I have a second array that indicates where each subset starts in the first array arr. For example, the array subsets = jnp.array([0, 2]) indicates that the first subset starts at position 0 and the second subset starts at position 2.
Now, if one element is included based on arr, I would like to include all the elements in the same subset. In this example, the output should then be [1, 1, 0, 0, 0, 0].
I have tried using jax.lax.fori_loop, but it is slow.
import jax
import jax.numpy as jnp

@jax.jit
def select_subsets(arr, subsets):
    new_arr = arr.copy()
    n_resid = subsets.shape[0]
    indices = jnp.arange(arr.shape[0])

    def func(i, new_arr):
        # Subset i covers positions [start, stop)
        start = subsets[i]
        stop = subsets[i+1]
        # Keep only the entries of arr that fall inside subset i
        arr_sliced = jnp.where((indices >= start) & (indices < stop), arr, 0.0)
        sum_ = jnp.sum(arr_sliced)
        # If any element of the subset is included, include the whole subset
        new_arr = jnp.where(sum_ > 0.5, jnp.where((indices >= start) & (indices < stop), 1, new_arr), new_arr)
        return new_arr

    new_arr = jax.lax.fori_loop(0, n_resid-1, func, new_arr)
    return new_arr
This function works if the last element of subsets is equal to the number of elements in arr, for example subsets = jnp.array([0, 2, 6]).
I then thought about writing a vectorized version (using jax.numpy operations), but I cannot manage to do it.
Is there a JAX guru that can help me with this?
Thanks a lot!
Here's a vectorized version. Like your loop, it expects subsets to end with the number of elements in arr. It instantiates a mask with shape (len(subsets) - 1) x len(arr), which might be undesirable depending on how big those values are.
@jax.jit
def vectorized_select_subsets(arr, subsets):
    l, = arr.shape
    indices = jnp.arange(l)[None, :]

    # Broadcast to mask of shape (n_subsets, input_length)
    subset_masks = (
        (indices >= subsets[:-1, None])
        & (indices < subsets[1:, None])
    )

    # Shape (n_subsets,) array indicating whether each subset is included
    include_subset = jnp.any(subset_masks & arr[None, :], axis=1)

    # Reduce down columns
    result = jnp.any(subset_masks & include_subset[:, None], axis=0).astype(jnp.int32)
    return result
I timed this against the loop-based version on an array with length 512 and 32 subsets:
Loop: 6254.647 it/s
Vectorized: 37940.335 it/s
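If the (n_subsets, input_length) mask turns out to be too large, one possible alternative is to label each position with the index of its subset and reduce per subset. The sketch below is not part of the timing above and is untested for performance. It assumes that subsets is sorted, starts at 0, and ends with len(arr) (as in subsets = jnp.array([0, 2, 6])), and that arr holds integer 0/1 values. The function name segment_select_subsets is made up for this example.

```python
import jax
import jax.numpy as jnp

@jax.jit
def segment_select_subsets(arr, subsets):
    # Assumption: subsets is sorted, starts at 0 and ends with len(arr).
    n = arr.shape[0]
    # Static under jit, because it is derived from a shape.
    num_segments = subsets.shape[0] - 1

    # Subset index of each position: number of boundaries <= position, minus 1.
    segment_ids = jnp.searchsorted(subsets, jnp.arange(n), side="right") - 1

    # Per-subset maximum of arr: 1 if any element of the subset is included.
    included = jax.ops.segment_max(arr, segment_ids, num_segments=num_segments)

    # Broadcast the per-subset flag back to every position in that subset.
    return included[segment_ids]

arr = jnp.array([1, 0, 0, 0, 0, 0])
subsets = jnp.array([0, 2, 6])
print(segment_select_subsets(arr, subsets))  # [1 1 0 0 0 0]
```

This only allocates arrays of length len(arr) and len(subsets) - 1. Whether it is actually faster than the broadcasted mask likely depends on the sizes and the backend, so it is worth timing on your own data.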