
Jax scan with dynamic number of iterations


I want to perform a scan with a dynamic number of iterations. To do that, I'm willing to recompile the function each time iters_to_do changes.

To avoid a huge slowdown, I'll be using a recompilation_cache, but that's beside the point.
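A minimal sketch of the recompile-per-value idea, assuming the iteration count arrives as a plain Python int (not a JAX array). The names `repeat_double` and `n_steps` are hypothetical. It also shows the signature `jax.lax.scan` expects: `body_fun(carry, x) -> (new_carry, y)`, with `scan` returning a `(carry, ys)` tuple.

```python
from functools import partial

import jax
import jax.numpy as jnp


# n_steps is static: it must be a plain, hashable Python value at call time.
# Each distinct value triggers a new trace/compile, which jit then caches.
@partial(jax.jit, static_argnums=(1,))
def repeat_double(init, n_steps):
    def body_fun(carry, _):
        # scan body: (carry, x) -> (new_carry, per-step output)
        return carry * 2, None

    final_carry, _ = jax.lax.scan(body_fun, init, xs=None, length=n_steps)
    return final_carry


init = jnp.ones((5,))
# Called with concrete Python ints, not with a traced or vmapped array.
results = [repeat_double(init, n) for n in (1, 2, 3)]
```

This works only as long as the caller has a concrete count in hand; once the count is itself a traced value, static marking no longer applies.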

However, even though I mark the argument as static in @partial(jax.jit, ...), I still get a concretization error:

@partial(jax.jit, static_argnums=(3,))
def iterate_for_steps(self,
                        interim_thought: Array, 
                        mask: Array,
                        iters_to_do: int, 
                        input_arr: Array, 
                        key: PRNGKeyArray) -> Array:

    # These are constants
    input_arr = input_arr.astype(jnp.bfloat16)
    interim_thought = interim_thought.astype(jnp.bfloat16)
    
    def body_fun(i: int, thought: Array) -> Array:
        latent = jnp.concatenate([thought, input_arr], axis=-1).astype(jnp.bfloat16)
        latent = self.main_block(latent, input_arr, mask, key).astype(jnp.bfloat16)
        latent = jax.vmap(self.post_ln)(latent).astype(jnp.bfloat16)  # LN to keep scales tidy

        return latent
    
    iters_to_do = iters_to_do.astype(int).item()
    final_val = jax.lax.scan(body_fun, interim_thought, xs=None, length=iters_to_do)
    
    return final_val

Full traceback is here.
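Note that the `body_fun(i, thought)` above has the `(index, carry) -> carry` shape that `jax.lax.fori_loop` expects, rather than the `(carry, x) -> (carry, y)` shape `scan` expects. A hedged alternative sketch, assuming only the final carry is needed and it keeps a fixed shape and dtype: `fori_loop` accepts a traced upper bound, so no per-value recompilation is required. `iterate_fori` is a hypothetical name and the doubling body stands in for the real block.

```python
import jax
import jax.numpy as jnp


@jax.jit
def iterate_fori(init, n_steps):
    # Same (i, carry) -> carry signature as body_fun in the question.
    def body_fun(i, thought):
        return thought * 2

    # n_steps is a traced array here; with a non-concrete upper bound,
    # fori_loop lowers to lax.while_loop, so one compilation serves all values.
    return jax.lax.fori_loop(0, n_steps, body_fun, init)


init = jnp.ones((5,))
out3 = iterate_fori(init, jnp.int32(3))
out1 = iterate_fori(init, jnp.int32(1))  # same compiled function, new count
```

The trade-off: with a traced trip count the loop becomes a `while_loop`, which supports forward-mode but not reverse-mode differentiation.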

I've tried marking several different arguments as static via @partial, but to no avail.

I'm not sure how to approach debugging this. A Python debugger gives me no help beyond the fact that it's definitely a tracer.
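One way to see what `jit` actually hands the function, sketched under the assumption that printing during tracing is acceptable: a plain Python `print` runs at trace time and shows argument types, while `jax.debug.print` runs at execution time and shows values. `inspect_args` is a hypothetical helper. If an argument you marked static shows up (or fails) as a tracer, that typically means it was already traced before reaching `jit`, for example by an enclosing `jax.vmap`.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(0,))
def inspect_args(n, x):
    # Runs once per trace: n arrives as-is (a Python int), x as a tracer.
    print("n ->", type(n).__name__, "| x ->", type(x).__name__)
    # Runs on every call and prints the runtime value.
    jax.debug.print("x sum = {s}", s=x.sum())
    return x * n


y = inspect_args(3, jnp.ones((2,)))
```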

MRE

from functools import partial
import jax
import jax.numpy as jnp

init = jnp.ones((5,))
iterations = jnp.array([1, 2, 3])

@partial(jax.jit, static_argnums=(0,))
def iterate_for_steps(iters: int):
    def body_fun(carry):
        return carry * 2
    
    iters = iters.astype(int)
    output = jax.lax.scan(body_fun, init, xs=None, length=iters)
    
    return output

print(jax.vmap(iterate_for_steps)(iterations))
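In the MRE, `jax.vmap` turns `iters` into a batched tracer before `jit` ever sees it, so `static_argnums` cannot make it concrete, and `scan`'s `length` must be a concrete int because shapes cannot depend on traced values. A hedged sketch of a vmap-compatible variant, swapping `scan` for `jax.lax.while_loop` and assuming only the final carry is needed (stacked per-step outputs would have a data-dependent shape):

```python
import jax
import jax.numpy as jnp

init = jnp.ones((5,))
iterations = jnp.array([1, 2, 3])


def iterate_for_steps(iters):
    # iters is traced (and batched under vmap), so it is NOT marked static.
    def cond_fun(state):
        i, _ = state
        return i < iters

    def body_fun(state):
        i, carry = state
        return i + 1, carry * 2

    _, output = jax.lax.while_loop(cond_fun, body_fun, (jnp.int32(0), init))
    return output


out = jax.jit(jax.vmap(iterate_for_steps))(iterations)
```

Under `vmap`, the batched loop keeps running until every element's condition is false, and elements that finish early keep their final carry unchanged.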

Solution

  • One can use Equinox's while_loop implementation (internal as of right now), which can also handle a dynamic number of iterations and uses checkpointing to reduce memory usage.

    Note that it can be used as a drop-in replacement for JAX's native while_loop. One can also use Equinox's eqx.internal.scan to get similar checkpointing with scan.
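A minimal sketch of the drop-in idea: the loop function is passed in as a parameter, so the native `jax.lax.while_loop` (used below so the example runs with JAX alone) could be swapped for Equinox's version. The Equinox call in the comment is an assumption about an internal API (`max_steps` and `kind="checkpointed"` keywords) and may differ between versions; `run_loop` and the bound of 16 are hypothetical.

```python
import jax
import jax.numpy as jnp


def run_loop(init, n_steps, while_loop=jax.lax.while_loop):
    # n_steps may be a traced value; the loop stops once i reaches it.
    def cond_fun(state):
        i, _ = state
        return i < n_steps

    def body_fun(state):
        i, carry = state
        return i + 1, carry * 2

    _, out = while_loop(cond_fun, body_fun, (jnp.int32(0), init))
    return out


# Assumed Equinox swap (internal API, may change):
#   import functools, equinox as eqx
#   loop = functools.partial(eqx.internal.while_loop,
#                            max_steps=16, kind="checkpointed")
#   out = run_loop(init, n_steps, while_loop=loop)

out = jax.jit(run_loop)(jnp.ones((5,)), jnp.int32(3))
```

The main reason to prefer the checkpointed version is differentiation: `jax.lax.while_loop` does not support reverse-mode autodiff, whereas a bounded, checkpointed loop (given a static `max_steps`) can be backpropagated through while keeping memory in check.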