
How can I implement the following function (which involves dynamic shapes) in JAX?


I have a simple function that takes a JAX array as input, searches for the first occurrence of 1, and replaces it with another JAX array (selected by a second input):

import jax.numpy as jnp
from jax import vmap

rules_int = [
    jnp.array([0,0]),
    jnp.array([1,1,1]),
]
# Even when the inputs have the same size, the outputs can have different sizes

def replace_first_one(arr, action):
    index = jnp.where(arr == 1)[0]
    if index.size == 0:
        return arr
    index = index[0]
    new_arr = jnp.concatenate([arr[:index], rules_int[action], arr[index+1:]])
    return new_arr

replace_first_one(jnp.array([1]), 0)
# result is Array([0, 0], dtype=int32)

But when I use vmap, I get an exception:

batch_arr = jnp.array([
    jnp.array([1, 4, 5, 1]),
    jnp.array([6, 1, 8, 1])
])

batch_actions = jnp.array([0, 1])  # Corresponding actions for each array

# Vectorize the function
vectorized_replace_first_one = vmap(replace_first_one, in_axes=(0, 0))
result = vectorized_replace_first_one(batch_arr, batch_actions)

index = jnp.where(arr == 1)[0]
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. This BatchTracer with object id 140260750414512 was created on line:

I read in the JAX docs:

JAX code used within transforms like jax.jit, jax.vmap, jax.grad, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.

How can I make this work?

Ideally, these rules should be applied repeatedly until no rule applies anymore (as in a string rewriting system).


Solution

  • As written, it is impossible to do this with vmap because the output of your function has a shape that depends on the value of action, and so the output would have to be a ragged array, which JAX does not support (see JAX Sharp Bits: Dynamic Shapes).

    To make the function compatible with vmap, you'll have to adjust it so that it has static shape semantics: in particular, every entry of rules_int must have the same length, and you cannot return arr alone in cases where arr doesn't have any 1 entries. Making these changes and adjusting the logic to avoid dynamically-shaped intermediates, you could write something like this:

    import jax
    import jax.numpy as jnp
    from jax import vmap
    
    rules_int = jnp.array([
        [0,0],
        [1,1],
    ])
    
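    def replace_first_one(arr, action):
        # size=1 gives the result of jnp.where a static shape
        index = jnp.where(arr == 1, size=1)[0][0]
        arr_to_insert = rules_int[action]
        output_size = len(arr) - 1 + len(arr_to_insert)
        # Positions before index keep arr's values; the remaining positions
        # take arr's values shifted right to make room for the insertion
        new_arr = jnp.where(jnp.arange(output_size) < index,
                            jnp.concatenate([arr[:-1], arr_to_insert]),
                            jnp.concatenate([arr_to_insert, arr[1:]]))
        # Overwrite the slots starting at index with the replacement values
        return jax.lax.dynamic_update_slice(new_arr, arr_to_insert, (index,))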
    
    replace_first_one(jnp.array([1]), 0)
    # Array([0, 0], dtype=int32)
    
    batch_arr = jnp.array([
        jnp.array([1, 4, 5, 1]),
        jnp.array([6, 1, 8, 1])
    ])
    
    batch_actions = jnp.array([0, 1])
    
    vectorized_replace_first_one = vmap(replace_first_one, in_axes=(0, 0))
    vectorized_replace_first_one(batch_arr, batch_actions)
    # Array([[0, 0, 4, 5, 1],
    #        [6, 1, 1, 8, 1]], dtype=int32)
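
One caveat with the static-shape version: when arr contains no 1, jnp.where(..., size=1) pads its result with the fill value (0 by default), so index becomes 0 and the first element gets replaced anyway. Because the output length is fixed, the function cannot fall back to returning arr unchanged, but it can report whether a match existed. The sketch below shows one way to expose that; first_one_index is a hypothetical helper, not something from the original code.

```python
import jax
import jax.numpy as jnp

def first_one_index(arr):
    """Return (index, found) with static shapes (hypothetical helper)."""
    is_one = arr == 1
    # With size=1 the result always has one entry; fill_value is used
    # when arr contains no 1 at all.
    index = jnp.where(is_one, size=1, fill_value=0)[0][0]
    return index, jnp.any(is_one)

batch = jnp.array([[6, 1, 8, 1],
                   [2, 3, 4, 5]])
indices, found = jax.vmap(first_one_index)(batch)
# indices is [1, 0]; found is [True, False]
```

The caller can then use found to decide whether to trust the corresponding output row.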
    

    If adjusting the semantics of your function in this way to avoid dynamic shapes is not viable given your use-case, then your use-case is unfortunately not compatible with vmap or other JAX transformations.
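
If rules of different lengths or repeated rewriting matter, one further way to adjust the semantics is to keep every sequence in a fixed-capacity buffer padded with a sentinel, so that the input and output always have the same shape. The following is only a rough sketch under stated assumptions: PAD, rules_padded, rule_lens, replace_first_one_padded and rewrite_until_done are hypothetical names, the sentinel -1 is assumed never to appear in real data, and anything pushed past the end of the buffer is silently dropped.

```python
import jax
import jax.numpy as jnp

PAD = -1  # hypothetical sentinel; assumed never to occur in real data

# Rules padded to a common length, with their true lengths stored separately.
rules_padded = jnp.array([[0, 0, PAD],
                          [1, 1, 1]])
rule_lens = jnp.array([2, 3])

def replace_first_one_padded(buf, action):
    """Rewrite the first 1 in a fixed-capacity buffer; output shape == input shape."""
    n = buf.shape[0]
    is_one = buf == 1
    found = jnp.any(is_one)
    index = jnp.where(is_one, size=1)[0][0]  # 0 if no 1 is present
    rule = rules_padded[action]
    k = rule_lens[action]
    pos = jnp.arange(n)
    # Gather indices are clipped to stay in bounds; the clipped picks are
    # masked out by the jnp.where conditions below.
    from_rule = rule[jnp.clip(pos - index, 0, rule.shape[0] - 1)]
    from_tail = buf[jnp.clip(pos - k + 1, 0, n - 1)]
    out = jnp.where(pos < index, buf,
                    jnp.where(pos < index + k, from_rule, from_tail))
    # Leave the buffer untouched when there is nothing to replace.
    return jnp.where(found, out, buf)

def rewrite_until_done(buf, action, max_steps=16):
    """Apply one rule repeatedly until no 1 remains or max_steps is reached."""
    def cond(state):
        cur, step = state
        return jnp.any(cur == 1) & (step < max_steps)

    def body(state):
        cur, step = state
        return replace_first_one_padded(cur, action), step + 1

    final, _ = jax.lax.while_loop(cond, body, (buf, jnp.int32(0)))
    return final

batch_buf = jnp.array([[1, 4, 5, 1, PAD, PAD, PAD, PAD],
                       [6, 1, 8, 1, PAD, PAD, PAD, PAD]])
batch_actions = jnp.array([1, 0])
result = jax.vmap(replace_first_one_padded)(batch_buf, batch_actions)
# [[ 1,  1,  1,  4,  5,  1, -1, -1],
#  [ 6,  0,  0,  8,  1, -1, -1, -1]]

done = rewrite_until_done(jnp.array([1, 4, 1, PAD, PAD, PAD]), 0)
# [ 0,  0,  4,  0,  0, -1]
```

The max_steps bound is there because a rule such as 1 -> 1,1,1 never removes the 1 it matched, so the loop would otherwise not terminate.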