I have a simple function that takes a JAX Array as input, searches for the first occurrence of 1, and replaces it with another JAX Array (selected by a second input):
import jax.numpy as jnp
from jax import vmap

rules_int = [
    jnp.array([0, 0]),
    jnp.array([1, 1, 1]),
]

# Even with the same size of inputs, the sizes of outputs can be different
def replace_first_one(arr, action):
    index = jnp.where(arr == 1)[0]
    if index.size == 0:
        return arr
    index = index[0]
    new_arr = jnp.concatenate([arr[:index], rules_int[action], arr[index+1:]])
    return new_arr

replace_first_one(jnp.array([1]), 0)
# result is Array([0, 0], dtype=int32)
But when I use vmap, I get an exception:
batch_arr = jnp.array([
    jnp.array([1, 4, 5, 1]),
    jnp.array([6, 1, 8, 1])
])
batch_actions = jnp.array([0, 1])  # Corresponding actions for each array

# Vectorize the function
vectorized_replace_first_one = vmap(replace_first_one, in_axes=(0, 0))
result = vectorized_replace_first_one(batch_arr, batch_actions)
index = jnp.where(arr == 1)[0]
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations. This BatchTracer with object id 140260750414512 was created on line:
I read on JAX docs:
JAX code used within transforms like jax.jit, jax.vmap, jax.grad, etc. requires all output arrays and intermediate arrays to have static shape: that is, the shape cannot depend on values within other arrays.
Could you suggest how to make this work?
Ideally, these rules should be applied recursively until no rule can be applied anymore (this is a string rewriting system).
As written, it is impossible to do this with vmap because the output of your function has a shape that depends on the value of action, and so the output would have to be a ragged array, which JAX does not support (see JAX Sharp Bits: Dynamic Shapes).
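The same static-shape constraint is what triggers the error on the jnp.where line in the question. As a small illustration (the helper name first_one_index and the fill value of -1 are my own choices for this sketch, not part of the original code), passing size= to jnp.where fixes the length of the result so it can be used under jit or vmap:

```python
import jax
import jax.numpy as jnp

arr = jnp.array([6, 1, 8, 1])

# Outside of any transformation, the result length depends on the data.
print(jnp.where(arr == 1)[0])  # one entry per match

@jax.jit
def first_one_index(a):
    # size=1 fixes the output length; fill_value=-1 is an assumed
    # marker for "no match" (the default fill value would be 0).
    return jnp.where(a == 1, size=1, fill_value=-1)[0][0]

idx_found = first_one_index(arr)                          # first 1 is at position 1
idx_missing = first_one_index(jnp.array([6, 7, 8, 9]))    # no 1 present
```

Without the size argument, the call inside first_one_index would fail in the same way as in the question, because the number of matches is not known at trace time.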
To make the function compatible with vmap, you'll have to adjust it so that it has static shape semantics: in particular, every entry of rules_int must have the same length, and you cannot return arr alone in cases where arr doesn't have any 1 entries. Making these changes and adjusting the logic to avoid dynamically-shaped intermediates, you could write something like this:
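import jax
import jax.numpy as jnp
from jax import vmap

# All rules have the same length, so the output shape is static.
rules_int = jnp.array([
    [0, 0],
    [1, 1],
])

def replace_first_one(arr, action):
    # size=1 gives a statically-shaped result: the position of the first 1
    # (or 0, the default fill value, if there is none).
    index = jnp.where(arr == 1, size=1)[0][0]
    arr_to_insert = rules_int[action]
    output_size = len(arr) - 1 + len(arr_to_insert)
    # Positions before index keep arr's values; later positions take arr
    # shifted right to make room for the inserted rule.
    new_arr = jnp.where(jnp.arange(output_size) < index,
                        jnp.concatenate([arr[:-1], arr_to_insert]),
                        jnp.concatenate([arr_to_insert, arr[1:]]))
    # Write the rule into its slot starting at index.
    return jax.lax.dynamic_update_slice(new_arr, arr_to_insert, (index,))

replace_first_one(jnp.array([1]), 0)
# Array([0, 0], dtype=int32)

batch_arr = jnp.array([
    jnp.array([1, 4, 5, 1]),
    jnp.array([6, 1, 8, 1])
])
batch_actions = jnp.array([0, 1])

vectorized_replace_first_one = vmap(replace_first_one, in_axes=(0, 0))
vectorized_replace_first_one(batch_arr, batch_actions)
# Array([[0, 0, 4, 5, 1],
#        [6, 1, 1, 8, 1]], dtype=int32)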
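One thing to watch for in this version: when arr contains no 1, jnp.where(..., size=1) pads its result with the default fill value of 0, so index becomes 0 and the first element gets overwritten. A minimal sketch of one way to handle that case is shown here. It assumes a hypothetical sentinel PAD = -1 that never occurs in real data, and a wrapper name (replace_first_one_checked) that I made up; it returns a flag alongside the array and pads the untouched input so both outcomes share the same static shape:

```python
import jax
import jax.numpy as jnp

PAD = -1  # hypothetical sentinel, assumed never to occur in real data

rules_int = jnp.array([
    [0, 0],
    [1, 1],
])

def replace_first_one(arr, action):
    # Same logic as the function above.
    index = jnp.where(arr == 1, size=1)[0][0]
    arr_to_insert = rules_int[action]
    output_size = len(arr) - 1 + len(arr_to_insert)
    new_arr = jnp.where(jnp.arange(output_size) < index,
                        jnp.concatenate([arr[:-1], arr_to_insert]),
                        jnp.concatenate([arr_to_insert, arr[1:]]))
    return jax.lax.dynamic_update_slice(new_arr, arr_to_insert, (index,))

def replace_first_one_checked(arr, action):
    found = jnp.any(arr == 1)
    replaced = replace_first_one(arr, action)
    # Pad the unchanged input to the output length so that both
    # branches of the select below have the same static shape.
    n_pad = replaced.shape[0] - arr.shape[0]
    unchanged = jnp.concatenate([arr, jnp.full((n_pad,), PAD, dtype=arr.dtype)])
    return jnp.where(found, replaced, unchanged), found

batch_arr = jnp.array([[1, 4, 5, 1],
                       [6, 7, 8, 9]])  # second row has no 1
batch_actions = jnp.array([0, 1])
out, found = jax.vmap(replace_first_one_checked)(batch_arr, batch_actions)
# out[0] has its first 1 replaced; out[1] is the input followed by PAD.
# found tells the caller which rows were actually rewritten.
```

The caller can then use the found flag (or the PAD entries) to tell rewritten rows apart from untouched ones.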
If adjusting the semantics of your function in this way to avoid dynamic shapes is not viable given your use-case, then your use-case is unfortunately not compatible with vmap or other JAX transformations.
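If a fixed upper bound on the string length is acceptable, one possible adjustment that keeps rules of different lengths is to store everything in padded, fixed-capacity buffers. The following is only a sketch under those assumptions: PAD, rule_lens, replace_first_one_padded, and apply_actions are names invented for illustration, the buffer capacity must be chosen large enough by the caller (overflowing entries are silently dropped), and the rules are applied for a fixed sequence of actions rather than "until nothing applies", since a rule such as 1 → 1,1,1 would never terminate.

```python
import jax
import jax.numpy as jnp

PAD = -1  # hypothetical sentinel marking unused slots

# Rules padded to a common length; true lengths are stored separately.
rules = jnp.array([[0, 0, PAD],
                   [1, 1, 1]])
rule_lens = jnp.array([2, 3])

def replace_first_one_padded(buf, action):
    """Replace the first 1 in a PAD-padded buffer with rules[action]."""
    capacity = buf.shape[0]
    has_one = jnp.any(buf == 1)
    index = jnp.argmax(buf == 1)  # 0 if there is no match; guarded by has_one
    rule = rules[action]
    rule_len = rule_lens[action]
    j = jnp.arange(capacity)
    # Candidate values for the "inside the rule" and "after the rule" regions.
    from_rule = rule[jnp.clip(j - index, 0, rule.shape[0] - 1)]
    from_tail = buf[jnp.clip(j - rule_len + 1, 0, capacity - 1)]
    out = jnp.where(j < index, buf,
                    jnp.where(j < index + rule_len, from_rule, from_tail))
    # Leave the buffer untouched when it contains no 1.
    return jnp.where(has_one, out, buf)

def apply_actions(buf, actions):
    """Apply one rewrite per entry of actions, in order."""
    def step(b, a):
        return replace_first_one_padded(b, a), None
    final, _ = jax.lax.scan(step, buf, actions)
    return final

batch_bufs = jnp.array([[1, 4, 5, 1, PAD, PAD, PAD, PAD],
                        [6, 1, 8, 1, PAD, PAD, PAD, PAD]])
batch_actions = jnp.array([[0, 0],
                           [1, 0]])
result = jax.vmap(apply_actions)(batch_bufs, batch_actions)
```

Because every shape here is fixed (buffer capacity, padded rule length, number of steps), the whole thing stays compatible with vmap and jit; the variable lengths live in the data (PAD entries and rule_lens) rather than in the array shapes.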