
Passing the returned stacked output to jax.lax.scan


I want to feed the stacked values returned by jax.lax.scan back into one of its arguments while the scan is running. Is that possible? For example:

import jax.numpy as jnp
from jax import lax


def cumsum(res, el):
    """
    - `res`: The result from the previous loop.
    - `el`: The current array element.
    """
    v, u = res
    print(u)
    v = v + el
    return (v, u), v  # ("carryover", "accumulated")


a = jnp.arange(5)  # example input; `a` was not defined in the original snippet
result_init = 0
result = []

(final, use), result = lax.scan(cumsum, (result_init, result), a)

In the code above, I want to access the accumulated values while the scan is running and pass them back into the loop. So I passed result as part of the carry argument to lax.scan, but it always prints an empty list.


Solution

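  • There's no built-in way to access the "current state" of the accumulated values in the course of a scan operation: in particular, the current state would be a dynamically shaped array (size 0 in the first iteration, size 1 in the second, etc.), and scan, like other JAX transformations and higher-order functions, requires static shapes. That is also why your snippet prints an empty list: the result = [] you pass in is only the initial carry, which your function returns unchanged, while the stacked outputs are assembled into result only after the scan has finished.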

    But you could do something similar by passing along a preallocated, fixed-size array in the carry and updating it manually. It might look something like this:

    from jax import lax
    import jax.numpy as jnp
    from jax import debug
    
    def cumsum(carry, el):
        i, v, running_result = carry
        v = v + el
        running_result = running_result.at[i].set(v)
        debug.print("iteration {}: running_result={}", i, running_result)
        return (i + 1, v, running_result), v
    
    a = jnp.arange(5)
    running_result = jnp.zeros_like(a)
    (i, v, running_result), result = lax.scan(cumsum, (0, 0, running_result), a)
    
    print("\nfinal running result:", running_result)
    print("final result:", result)
    
    Output:
    
    iteration 0: running_result=[0 0 0 0 0]
    iteration 1: running_result=[0 1 0 0 0]
    iteration 2: running_result=[0 1 3 0 0]
    iteration 3: running_result=[0 1 3 6 0]
    iteration 4: running_result=[ 0  1  3  6 10]
    
    final running result: [ 0  1  3  6 10]
    final result: [ 0  1  3  6 10]
    

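    Note that I used jax.debug.print rather than Python's built-in print to show the intermediate results: the scan body is traced and compiled, so a regular print runs only while the function is being traced and shows abstract tracer values rather than the concrete numbers at each iteration.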
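To make the point about jax.debug.print concrete, here is a minimal sketch (the names body and trace_calls are made up for illustration). A plain Python side effect inside the scan body runs only while JAX traces the function, so it fires once or at most a handful of times, whereas jax.debug.print is staged into the compiled loop and prints concrete values on every iteration:

```python
import jax.numpy as jnp
from jax import debug, lax

trace_calls = []  # hypothetical list, used only to observe Python side effects


def body(carry, el):
    # Plain Python side effect: executed while JAX traces `body`,
    # not once per scan iteration. We record the type name, not the tracer itself.
    trace_calls.append(type(carry).__name__)
    # jax.debug.print is part of the compiled computation,
    # so it prints the concrete values on every iteration.
    debug.print("carry={c}, el={e}", c=carry, e=el)
    carry = carry + el
    return carry, carry


xs = jnp.arange(10, dtype=jnp.int32)
final, ys = lax.scan(body, jnp.int32(0), xs)

print("body traced", len(trace_calls), "time(s) for", xs.shape[0], "iterations")
print("stacked outputs:", ys)
```

On a typical run trace_calls contains a single tracer type name while the debug output shows ten lines, which is the same reason print(u) in the question only ever showed the initial empty list.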
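The static-shape restriction also affects how you read the buffer back inside the loop: a slice like history[:i + 1] has a length that depends on the traced counter i, so it cannot be traced. A minimal sketch of one workaround, assuming you want something like the mean of all running totals seen so far (step, history, valid and mean_so_far are illustrative names, not part of the answer): keep the full fixed-size buffer and mask out the slots that have not been written yet.

```python
import jax.numpy as jnp
from jax import lax


def step(carry, el):
    i, v, history = carry
    v = v + el
    history = history.at[i].set(v)  # write the new running total into slot i
    # history[:i + 1] would have a length that depends on the traced counter i,
    # which scan cannot handle. Instead keep the whole fixed-size buffer and
    # mask out the slots that have not been written yet.
    valid = jnp.arange(history.shape[0]) <= i
    mean_so_far = jnp.sum(jnp.where(valid, history, 0.0)) / (i + 1)
    return (i + 1, v, history), mean_so_far


a = jnp.arange(5, dtype=jnp.float32)
init = (jnp.int32(0), jnp.float32(0), jnp.zeros_like(a))
(_, _, history), means = lax.scan(step, init, a)

print("history:", history)  # running totals 0, 1, 3, 6, 10
print("means:  ", means)    # mean of the running totals written so far
```

Here the unwritten slots happen to be zeros already, but the mask keeps the code correct for other fill values or for reductions such as max.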