I want to perform a `scan` with a dynamic number of iterations. To accomplish that, I want to recompile the function each time `iters_to_do` changes.
To avoid a huge slowdown I'll be using a `recompilation_cache`, but that's beside the point.
However, even when I mark the argument as static with `@partial(jax.jit, ...)`, I still get a concretization error:
```python
@partial(jax.jit, static_argnums=(3,))  # index 3 is iters_to_do, counting self as 0
def iterate_for_steps(self,
                      interim_thought: Array,
                      mask: Array,
                      iters_to_do: int,
                      input_arr: Array,
                      key: PRNGKeyArray) -> Array:
    # These are constants
    input_arr = input_arr.astype(jnp.bfloat16)
    interim_thought = interim_thought.astype(jnp.bfloat16)

    def body_fun(thought: Array, _) -> tuple[Array, None]:
        latent = jnp.concatenate([thought, input_arr], axis=-1).astype(jnp.bfloat16)
        latent = self.main_block(latent, input_arr, mask, key).astype(jnp.bfloat16)
        latent = jax.vmap(self.post_ln)(latent).astype(jnp.bfloat16)  # LN to keep scales tidy
        return latent, None  # scan expects (new_carry, per-step output)

    iters_to_do = int(iters_to_do)  # must be a concrete Python int, not a tracer
    final_val, _ = jax.lax.scan(body_fun, interim_thought, xs=None, length=iters_to_do)
    return final_val
```
I've tried marking multiple arguments as static with `@partial`, but to no avail.
I'm not sure how to approach debugging this: with a Python debugger, I get no help beyond the fact that it's definitely a tracer.
A minimal reproducible example:

```python
from functools import partial

import jax
import jax.numpy as jnp

init = jnp.ones((5,))
iterations = jnp.array([1, 2, 3])

@partial(jax.jit, static_argnums=(0,))
def iterate_for_steps(iters: int):
    def body_fun(carry, _):
        return carry * 2, None  # scan expects (new_carry, per-step output)

    final, _ = jax.lax.scan(body_fun, init, xs=None, length=iters)
    return final

# vmap hands iterate_for_steps a traced value for `iters`, not a Python int
print(jax.vmap(iterate_for_steps)(iterations))
```
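For contrast, here is a minimal sketch (reusing the toy doubling body and `init` from the example above) of the case `static_argnums` is designed for: each call receives a concrete Python `int`, so it can set the `scan` length, at the cost of one compilation per distinct value. A plain Python loop replaces `jax.vmap` here, because `vmap` passes a single traced array rather than separate Python ints.

```python
from functools import partial

import jax
import jax.numpy as jnp

init = jnp.ones((5,))

@partial(jax.jit, static_argnums=(0,))
def iterate_for_steps(iters: int):
    # `iters` is a concrete Python int here, so it can be used as the scan length.
    def body_fun(carry, _):
        return carry * 2, None

    final, _ = jax.lax.scan(body_fun, init, xs=None, length=iters)
    return final

# One call (and one compilation) per distinct Python int.
outputs = jnp.stack([iterate_for_steps(n) for n in [1, 2, 3]])
print(outputs)
```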
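The underlying problem is that under `jax.vmap` the iteration count is a traced array, so it can be used neither as a static argument nor as a `scan` length, which must be known at trace time. A loop whose trip count is only known at run time needs a `while_loop` instead. One can use equinox's `while_loop` implementation (internal as of right now), which handles a dynamic number of iterations and, unlike `jax.lax.while_loop`, supports reverse-mode autodiff, using checkpointing to reduce memory usage.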
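If you only need forward evaluation, the native `jax.lax.while_loop` already accepts a traced trip count, so it can be vmapped over the iteration counts directly. A minimal sketch for the toy example above; keeping a step counter in the carry is my own choice of layout. Under `vmap`, the loop keeps running until every batch element's condition is false, and elements that have finished are left unchanged.

```python
import jax
import jax.numpy as jnp

init = jnp.ones((5,))

def iterate_for_steps(iters):
    # `iters` may be a tracer: while_loop only needs its value at run time.
    def cond_fun(carry):
        i, _ = carry
        return i < iters

    def body_fun(carry):
        i, x = carry
        return i + 1, x * 2  # same doubling step as the toy scan body

    _, final = jax.lax.while_loop(cond_fun, body_fun, (jnp.int32(0), init))
    return final

outputs = jax.jit(jax.vmap(iterate_for_steps))(jnp.array([1, 2, 3]))
print(outputs)  # rows of 2s, 4s and 8s
```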
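Note that equinox's `while_loop` can be used as a drop-in replacement for JAX's native `jax.lax.while_loop` (same `cond_fun`, `body_fun` and `init_val` arguments, plus keyword arguments such as `kind` and `max_steps`). One can also use equinox's `eqx.internal.scan` if they wish to leverage similar checkpointing with `scan`.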
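If you would rather stay within core JAX and still need reverse-mode gradients, one common workaround is to `scan` over a static upper bound and mask out the steps past the dynamic count, wrapping the body in `jax.checkpoint` so the backward pass recomputes it instead of storing its internals. This is a different (simpler, per-step) technique from equinox's checkpointing scheme and is only a sketch; `MAX_STEPS` is an assumed bound that must be at least the largest count you pass.

```python
import jax
import jax.numpy as jnp

MAX_STEPS = 8  # assumed static upper bound on the iteration count

def iterate_for_steps(x0, iters):
    # Applies `iters` doublings (iters may be traced) inside a fixed-length scan.
    @jax.checkpoint  # recompute the body on the backward pass to save memory
    def body_fun(carry, step):
        new = carry * 2.0
        # Only take the update while step < iters; otherwise keep the carry.
        return jnp.where(step < iters, new, carry), None

    final, _ = jax.lax.scan(body_fun, x0, jnp.arange(MAX_STEPS))
    return final

def loss(x0, iters):
    return iterate_for_steps(x0, iters).sum()

init = jnp.ones((5,))
iterations = jnp.array([1, 2, 3])
out = jax.vmap(iterate_for_steps, in_axes=(None, 0))(init, iterations)
grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(init, iterations)
print(out[:, 0], grads[:, 0])
```

Every call runs all `MAX_STEPS` iterations, so this trades extra compute for vmappability and differentiability; when the bound is loose, equinox's checkpointed loops are likely the better fit.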