Tags: python, numpy, jax

Selecting all elements of subsets if at least one element is selected (JAX)


I have the following problem, for which I cannot manage to write a solution in JAX that is jittable and efficient.

I have a set of elements. Some of these elements are included (on the basis of a condition, which is not important now). The included elements are denoted by 1, the not-included elements by 0. For example, the array arr = jnp.array([1, 0, 0, 0, 0, 0]) indicates that I have 6 elements, the first of them included based on my condition.

These elements are grouped into subsets. I have a second array that indicates where each subset starts in the first array arr. For example, the array subsets = jnp.array([0, 2]) indicates that the first subset starts at position 0 and the second subset starts at position 2.

Now, if one element is included based on arr, I would like to include all the elements in the same subset. In this example, the output should then be [1, 1, 0, 0, 0, 0].
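
To pin down the expected behaviour, here is a minimal reference sketch in plain NumPy (not jittable, meant only for checking outputs). The name reference_select_subsets is hypothetical, and the sketch assumes the start positions are strictly increasing and begin at 0.

```python
import numpy as np

def reference_select_subsets(arr, starts):
    # Hypothetical reference helper (plain NumPy, not jittable).
    # Assumes `starts` is strictly increasing and begins at 0.
    arr = np.asarray(arr)
    starts = np.asarray(starts)
    # Per-subset maximum: 1 if any element of the subset is included
    included = np.maximum.reduceat(arr, starts)
    # Length of each subset; the last one runs to the end of arr
    lengths = np.diff(np.append(starts, arr.shape[0]))
    # Repeat each subset's flag once per element of that subset
    return np.repeat(included, lengths)

print(reference_select_subsets([1, 0, 0, 0, 0, 0], [0, 2]))  # [1 1 0 0 0 0]
```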

I have tried with a jax.lax.fori_loop, but it is slow.

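import jax
import jax.numpy as jnp

@jax.jit
def select_subsets(arr, subsets):
    new_arr = arr.copy()
    n_resid = subsets.shape[0]
    indices = jnp.arange(arr.shape[0])

    def func(i, new_arr):
        # Boundaries of subset i: [start, stop)
        start = subsets[i]
        stop = subsets[i+1]
        # Keep only the entries of arr that fall inside subset i
        arr_sliced = jnp.where((indices >= start) & (indices < stop), arr, 0.0)
        sum_ = jnp.sum(arr_sliced)
        # If any element of the subset is included, set the whole subset to 1
        new_arr = jnp.where(sum_ > 0.5, jnp.where((indices >= start) & (indices < stop), 1, new_arr), new_arr)
        return new_arr

    # subsets has n_resid entries, hence n_resid - 1 subsets
    new_arr = jax.lax.fori_loop(0, n_resid-1, func, new_arr)
    return new_arr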

This function works if I use a subsets array whose last element equals the number of elements in arr, e.g. subsets = jnp.array([0, 2, 6]).
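
The boundary convention just described (start positions followed by the length of arr) can be sketched as follows. The names starts and boundaries are only illustrative.

```python
import jax.numpy as jnp

arr = jnp.array([1, 0, 0, 0, 0, 0])
starts = jnp.array([0, 2])  # subset start positions only

# Append len(arr) so that each consecutive pair
# (boundaries[i], boundaries[i + 1]) delimits one subset.
boundaries = jnp.concatenate([starts, jnp.array([arr.shape[0]])])
print(boundaries)  # [0 2 6]
```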

I then thought about writing a vectorized version (using jax.numpy operations), but I have not managed to do it.

Is there a JAX guru that can help me with this?

Thanks a lot!


Solution

  • Here's a vectorized version. It materializes a boolean mask of shape (len(subsets) - 1, len(arr)), which may use too much memory if both values are large. Like the loop version, it expects the last entry of subsets to be len(arr).

               
    import jax
    import jax.numpy as jnp

    @jax.jit
    def vectorized_select_subsets(arr, subsets):
        l, = arr.shape

        indices = jnp.arange(l)[None, :]

        # Broadcast to mask of shape (n_subsets, input_length)
        subset_masks = (
            (indices >= subsets[:-1, None])
            & (indices < subsets[1:, None])
        )

        # Shape (n_subsets,) array indicating whether each subset is included.
        # Casting arr to bool keeps the `&` valid for float or non-0/1 inputs.
        include_subset = jnp.any(subset_masks & arr[None, :].astype(bool), axis=1)

        # Reduce down columns
        result = jnp.any(subset_masks & include_subset[:, None], axis=0).astype(jnp.int32)
        return result
    

    I timed the vectorized version against the loop-based version on an array of length 512 with 32 subsets:

    Loop: 6254.647 it/s
    Vectorized: 37940.335 it/s
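
If the (n_subsets, len(arr)) mask is too large, one possible alternative is a segment reduction. This is a different technique from the answer above, not a rewrite of it, and it was not part of the timing shown. It is a minimal sketch that assumes subsets is sorted, starts at 0, and ends with len(arr). The function name segment_select_subsets is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def segment_select_subsets(arr, subsets):
    # Assumption: `subsets` is sorted, starts at 0, and ends with len(arr).
    n_subsets = subsets.shape[0] - 1  # static under jit, since shapes are static
    indices = jnp.arange(arr.shape[0])
    # Subset id of each element: index of the last start position <= its index
    seg_ids = jnp.searchsorted(subsets[:-1], indices, side="right") - 1
    # Per-subset maximum: 1 if any element of that subset is included
    included = jax.ops.segment_max(arr, seg_ids, num_segments=n_subsets)
    # Gather each subset's flag back onto its elements
    return included[seg_ids]

arr = jnp.array([1, 0, 0, 0, 0, 0])
subsets = jnp.array([0, 2, 6])
print(segment_select_subsets(arr, subsets))  # [1 1 0 0 0 0]
```

This keeps every intermediate array at length len(arr) or n_subsets, so memory grows linearly rather than with their product. Whether it is faster than the masked version depends on the sizes and the backend, so it is worth timing on your own data.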