
Using jax.lax.scan with inputs that don't change across iterations within a scan but differ each time scan is called


jax.lax.scan operates on a function that takes two arguments: a carry and the current element of a sequence of inputs. I am wondering how scan should be called if some inputs don't change across iterations of the scan. Naively, I could create a sequence of identical inputs, but this seems wasteful and redundant. More importantly, it isn't always possible, since scan can only scan over arrays. For example, one of the inputs I want to pass to my function is a train state (e.g. from flax.training import train_state) that contains my model and its parameters, which cannot be put into an array. As I say in the title, these inputs may also change each time I call scan (e.g. the model parameters will change).

Any ideas on how best to do this?

Thanks.


Solution

  • In general, you have three possible approaches for this:

    1. Convert the single value into a sequence of identical inputs to scan over
    2. Put the single value in the carry to carry it along to each step of the scan
    3. Close over the single value

    Here are three examples of a computation using these strategies:

    import jax
    import jax.numpy as jnp
    
    a = jnp.arange(5)
    b = 2
    
    # Strategy 1: duplicate b across sequence
    def f(carry, xs):
      a, b = xs
      result = a * b
      return carry + result, result
    
    b_seq = jnp.full_like(a, b)
    
    total, cumulative = jax.lax.scan(f, 0, (a, b_seq))
    print(total) # 20
    print(cumulative) # [0 2 4 6 8]
    
    # Strategy 2: put b in the carry
    def f(carry, xs):
      carry, b = carry
      a = xs
      result = a * b
      return (carry + result, b), result
    
    (total, _), cumulative = jax.lax.scan(f, (0, b), a)
    print(total) # 20
    print(cumulative) # [0 2 4 6 8]
    
    # Strategy 3: close over b (here by binding it with functools.partial)
    from functools import partial
    
    def f(carry, xs, b):
      a = xs
      result = a * b
      return carry + result, result
    
    total, cumulative = jax.lax.scan(partial(f, b=b), 0, a)
    print(total) # 20
    print(cumulative) # [0 2 4 6 8]
    

    Which one you use depends on the context in which you are using it, but I personally think closure (option 3) is the cleanest approach.
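To connect the closure approach to the case in the question, where the fixed input is a container of parameters that changes between calls to scan, here is a minimal sketch. The `params` dictionary, the `run_scan` wrapper, and the `step` function are hypothetical names made up for illustration, with a plain dict standing in for a real train state. I have not tested this with an actual flax TrainState, so treat the claim that it carries over as an assumption.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a train state: a dict (pytree) of parameters.
params = {"w": jnp.array(2.0), "bias": jnp.array(1.0)}
xs = jnp.arange(5.0)

@jax.jit
def run_scan(params, xs):
    # step closes over params, which is an ordinary argument of run_scan,
    # so it stays fixed within one scan but can differ between calls.
    def step(carry, x):
        y = params["w"] * x + params["bias"]
        return carry + y, y

    return jax.lax.scan(step, jnp.array(0.0), xs)

total, ys = run_scan(params, xs)
print(total)  # 25.0
print(ys)     # [1. 3. 5. 7. 9.]

# Call again with different parameter values of the same structure.
new_params = {"w": jnp.array(3.0), "bias": jnp.array(0.0)}
total2, ys2 = run_scan(new_params, xs)
print(total2)  # 30.0
print(ys2)     # [ 0.  3.  6.  9. 12.]
```

Because `params` is passed to the outer function rather than captured as a global, new values with the same shapes and dtypes should be able to reuse the compiled function instead of triggering a retrace.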