jax.lax.scan operates on a function that takes two arguments: a carry and one element of a sequence of inputs. I am wondering how scan should be called if some inputs don't change across iterations of the scan. Naively, I could create a sequence of identical inputs, but this seems wasteful and redundant. More importantly, it isn't always possible, as scan can only scan over arrays. For example, one of the inputs I want to pass to my function is a train state (e.g. from flax.training import train_state) that contains my model and its parameters, which cannot be put into an array. These inputs may also change each time I call scan (e.g. the model parameters will change).
Any ideas on how best to do this?
Thanks.
In general, you have three possible approaches for this:
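1. Duplicate the value into a sequence of identical inputs and scan over it.
2. Put the value in the carry to carry it along to each step of the scan.
3. Close over the value in the function being scanned.

Here are three examples of a computation using these strategies: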
import jax
import jax.numpy as jnp

a = jnp.arange(5)
b = 2

# Strategy 1: duplicate b across sequence
def f(carry, xs):
    a, b = xs
    result = a * b
    return carry + result, result

b_seq = jnp.full_like(a, b)
total, cumulative = jax.lax.scan(f, 0, (a, b_seq))
print(total)  # 20
print(cumulative)  # [0 2 4 6 8]

# Strategy 2: put b in the carry
def f(carry, xs):
    carry, b = carry
    a = xs
    result = a * b
    return (carry + result, b), result

(total, _), cumulative = jax.lax.scan(f, (0, b), a)
print(total)  # 20
print(cumulative)  # [0 2 4 6 8]

# Strategy 3: close over b
from functools import partial

def f(carry, xs, b):
    a = xs
    result = a * b
    return carry + result, result

total, cumulative = jax.lax.scan(partial(f, b=b), 0, a)
print(total)  # 20
print(cumulative)  # [0 2 4 6 8]
Which one you use depends on the context, but I personally think closure (option 3) is the cleanest approach.
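For the case in the question, where the fixed input is a collection of model parameters that changes between calls to scan, the closure strategy might look like the sketch below. The names params, run_scan and step are hypothetical, and a plain dict stands in for the train state. A flax TrainState is also a pytree, so it should be usable in the same way, but that is an assumption not tested here.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a train state's parameters: a plain dict pytree.
params = {"w": jnp.float32(2.0), "bias": jnp.float32(1.0)}

@jax.jit
def run_scan(params, xs):
    # step closes over params, which stay fixed for this call to scan.
    def step(carry, x):
        y = params["w"] * x + params["bias"]
        return carry + y, y

    return jax.lax.scan(step, jnp.float32(0.0), xs)

xs = jnp.arange(5, dtype=jnp.float32)
total, ys = run_scan(params, xs)

# Updated parameters with the same structure and dtypes can be passed
# to the same function on the next call.
new_params = {"w": jnp.float32(3.0), "bias": jnp.float32(0.0)}
new_total, new_ys = run_scan(new_params, xs)

print(total, ys)
print(new_total, new_ys)
```

Because params is an argument of the outer function rather than a global, each call to run_scan sees the current values, and the inner step function never needs them as scan inputs.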