I wish to pass on the returned stacked values from the jax.lax.scan back to one of its arguments. Is it possible to do so? For example:
import jax.numpy as jnp
from jax import lax

def cumsum(res, el):
    """
    - `res`: The result from the previous loop.
    - `el`: The current array element.
    """
    v, u = res
    print(u)
    v = v + el
    return (v, u), v  # ("carryover", "accumulated")

a = jnp.arange(5)  # example input (assumed; any 1-D array works)
result_init = 0
result = []
(final, use), result = lax.scan(cumsum, (result_init, result), a)
In the above code, I want to extract the accumulated res values at runtime and pass them back into the loop. That is why I passed result as part of the carry to lax.scan, but it always prints an empty list.
There's no built-in way to access the "current state" of the accumulated values in the course of a scan operation. In particular, the current state would be a dynamically shaped array (size 0 in the first iteration, size 1 in the second, and so on), and scan, like other JAX transformations and higher-order functions, requires static shapes. The list you pass in is simply carried through unchanged, so it stays empty; scan only returns the stacked outputs once the whole loop has finished.
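To see the static-shape restriction in action, here is a minimal sketch that tries to grow the carry by one element per iteration. The body function grow is a hypothetical illustration, not code from the question. Because scan traces the body once and compares the carry's input and output types, the shape change from (0,) to (1,) should be rejected with a TypeError (at least in the JAX versions I'm aware of).

```python
import jax.numpy as jnp
from jax import lax

def grow(carry, el):
    # Hypothetical body: append each element to the carry, so its shape
    # would go (0,) -> (1,) -> (2,) ..., i.e. a dynamically shaped carry.
    return jnp.append(carry, el), el

a = jnp.arange(5)
try:
    lax.scan(grow, jnp.zeros((0,), dtype=a.dtype), a)
    rejected = False
except TypeError as err:
    # scan requires the carry to keep the same shape and dtype every step.
    rejected = True
    print("scan rejected the growing carry:", str(err).splitlines()[0])
```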
You can, however, get something similar by carrying along a preallocated array that you update manually. It might look something like this:
from jax import lax
import jax.numpy as jnp
from jax import debug

def cumsum(carry, el):
    i, v, running_result = carry
    v = v + el
    # Write the current partial sum into slot i of the fixed-size buffer.
    running_result = running_result.at[i].set(v)
    debug.print("iteration {}: running_result={}", i, running_result)
    return (i + 1, v, running_result), v

a = jnp.arange(5)
# Preallocate the buffer so its shape is the same on every iteration.
running_result = jnp.zeros_like(a)
(i, v, running_result), result = lax.scan(cumsum, (0, 0, running_result), a)
print("\nfinal running result:", running_result)
print("final result:", result)
Output:
iteration 0: running_result=[0 0 0 0 0]
iteration 1: running_result=[0 1 0 0 0]
iteration 2: running_result=[0 1 3 0 0]
iteration 3: running_result=[0 1 3 6 0]
iteration 4: running_result=[ 0  1  3  6 10]

final running result: [ 0  1  3  6 10]
final result: [ 0  1  3  6 10]
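Notice that I used jax.debug.print to print the intermediate results. Because the body function is traced and compiled, a regular Python print runs only while the function is being traced, so it shows abstract tracer values (or, in your code, the unchanged empty list) rather than the values from each iteration.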
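A minimal sketch of the difference, assuming a simple running-total body (the names body and trace_calls are hypothetical). The Python-side print and list append run only while scan traces the body, so they fire fewer times than there are iterations, whereas jax.debug.print is staged into the compiled loop and fires on every iteration with concrete values.

```python
import jax
import jax.numpy as jnp
from jax import lax

trace_calls = []  # records how often the Python body itself is executed

def body(carry, el):
    # Plain Python side effects: these run during tracing, not per iteration.
    trace_calls.append(1)
    print("trace-time print:", carry)  # shows an abstract tracer, not a number
    new = carry + el
    # Staged into the compiled computation: prints once per iteration.
    jax.debug.print("running total: {}", new)
    return new, new

a = jnp.arange(5)
# Strongly typed init so the carry dtype matches the body's output dtype.
total, ys = lax.scan(body, jnp.zeros((), dtype=a.dtype), a)
print("Python-side body calls:", len(trace_calls), "for", a.shape[0], "iterations")
```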