Issue with jax.lax.scan


I'm supposed to use jax.lax.scan instead of the 100-iteration for loop at line 22, updating S at each step and appending it to S_list. I'm not sure how to fix my jax.lax.scan call. The error that keeps coming up says the required argument xs is missing. When I pass a value for xs, it says my length argument doesn't match the axis sizes. Here is my code. Can you help me?


Solution

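  • You're not calling scan with the correct signature; see the call signature in the jax.lax.scan docs. In particular, your step function must take two arguments, (carry, x), and return two values, (new_carry, y). The "missing xs" error means xs wasn't passed: if there is nothing to iterate over, pass xs=None and give the number of steps with length=m. If you do pass an array for xs, scan iterates over its leading axis, so length must either be omitted or equal that axis size; otherwise you get the length-mismatch error.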

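    From your code, it looks like you intend to do something like this:

    import jax
    import jax.numpy as jnp

    # dt, r, σ and m are the time step, drift, volatility and number of
    # steps defined in your original code.

    @jax.jit
    def simulate():
      # One PRNG key per step; reusing a single key would give the same dZ every step.
      keys = jax.random.split(jax.random.PRNGKey(0), m)
      def step(S, key):
        dZ = jax.random.normal(key, shape=(S.size,)) * jnp.sqrt(dt)
        dS = r * S * dt + σ * S * dZ
        return S + dS, S  # (new carry, per-step output that scan stacks)
      S0 = jnp.ones(20000)
      _, S_array = jax.lax.scan(step, S0, xs=keys)  # length is taken from keys
      return S_array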
    

    In particular, as the docs describe, scan collects the second value returned by step at every iteration and stacks those values along a new leading axis. That replaces the S_list.append(...) and S_array = jnp.stack(S_list) from your loop, so you don't need to do either yourself after calling scan.

    Hope that helps!
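A minimal, self-contained sketch of how scan relates to the original loop, using a toy step function and values invented purely for illustration (they are not from the question). Each call to step receives (carry, x) and returns (new_carry, y); scan stacks the ys, which is what append plus jnp.stack did by hand. It also shows the xs=None/length form and reproduces the length-mismatch error described in the question.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    """Toy scan body (hypothetical): takes (carry, x), returns (new_carry, y)."""
    new_carry = carry + x
    return new_carry, carry  # y is the value *before* the update, like S_list.append(S)


xs = jnp.arange(5.0)    # per-step inputs; scan iterates over the leading axis
init = jnp.array(0.0)   # initial carry

# scan version: the ys from every step are stacked along a new leading axis.
final, ys = jax.lax.scan(step, init, xs)

# Equivalent Python loop with an explicit list and jnp.stack.
carry, ys_list = init, []
for x in xs:
    carry, y = step(carry, x)
    ys_list.append(y)
ys_loop = jnp.stack(ys_list)


def count(c, _):
    """Toy body that ignores x; used with xs=None."""
    return c + 1, c


# With no per-step input, pass xs=None and give the number of steps via length.
_, counts = jax.lax.scan(count, jnp.array(0), xs=None, length=3)

# Passing a length that disagrees with xs.shape[0] raises a ValueError.
try:
    jax.lax.scan(step, init, xs, length=3)
    length_error = None
except ValueError as err:
    length_error = str(err)

print(ys, final, counts, length_error)
```

Here the loop and the scan produce the same stacked array, and the final carry is returned separately as the first output of scan.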