
JAX best way to iterate RNGKeys?


In JAX I find myself needing a PRNGKey that changes on each iteration of a loop, and I'm not sure of the best pattern. I've considered:

a) split

for i in range(N):
  rng, _ = jax.random.split(rng)

  # Alternatively.
  rng = jax.random.split(rng, 1)[0]

b) fold_in

for i in range(N):
  rng = jax.random.fold_in(rng, i)

c) use the loop index? This seems bad, since the key no longer depends on a prior key.

for i in range(N):
  rng = jax.random.PRNGKey(i)
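A quick check makes the concern with (c) concrete. This is a minimal sketch; `draws_with_index_keys` is a hypothetical helper, not from the question. Because the incoming key is never used, two different seeds produce exactly the same draws.

```python
import jax
import jax.numpy as jnp

def draws_with_index_keys(key, n=3):
  # Pattern (c): the incoming `key` is ignored, so the draws
  # depend only on the loop index i.
  out = []
  for i in range(n):
    rng = jax.random.PRNGKey(i)
    out.append(jax.random.uniform(rng))
  return jnp.stack(out)

a = draws_with_index_keys(jax.random.PRNGKey(0))
b = draws_with_index_keys(jax.random.PRNGKey(123))
# Different seeds passed in, identical results out.
print(bool(jnp.all(a == b)))  # True
```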

Which of these is the best pattern, and why? I'm leaning towards (b), as it maintains a dependency on the previous key (e.g. one passed in as an argument), but I'm not sure whether this is really the intended use case for jax.random.fold_in.


Solution

  • The JAX docs (including the PRNG design doc) recommend a pattern similar to (a):

    import jax

    for i in range(N):
      key, subkey = jax.random.split(key)
      values = jax.random.uniform(subkey, shape)
      # key is only ever split; subkey is only ever used to draw values.
      # key carries over to the next iteration.
    

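    This is better than splitting and throwing away the subkey: in that pattern the kept key is used both to draw the current iteration's values and as the input to the next split, so the same key is consumed twice. Keeping key only for splitting and using subkey only for drawing values ensures that the random streams in each iteration are independent.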
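Inside jit, the same carry-the-key pattern is usually written with jax.lax.scan, with the key as the loop carry. This is a minimal sketch under that assumption (the scan formulation and the name `sample_loop` are illustrative, not from the answer): each step splits the carried key once, keeps one half as the next carry, and consumes the other.

```python
import jax
import jax.numpy as jnp

def sample_loop(key, n, shape):
  def step(carry_key, _unused):
    # Split once: one half continues the chain, the other is consumed here.
    carry_key, subkey = jax.random.split(carry_key)
    values = jax.random.uniform(subkey, shape)
    return carry_key, values
  # xs=None with an explicit length runs the step n times; the per-step
  # outputs are stacked along a new leading axis.
  _, values = jax.lax.scan(step, key, xs=None, length=n)
  return values

values = sample_loop(jax.random.PRNGKey(0), n=4, shape=(3,))
print(values.shape)  # (4, 3)
```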

    Your option (b) is also safe, and is in fact the pattern the JAX developers had in mind when creating fold_in (see e.g. https://github.com/google/jax/discussions/12395).
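One practical upside of fold_in is that a per-step key can be derived from a base key and the step index alone, so nothing has to be carried between iterations, and the index may be a traced value inside a compiled loop. A minimal sketch, assuming a jax.lax.fori_loop body (`running_sum` is a hypothetical name); note it folds into a fixed base key rather than chaining as in (b), which is an illustrative choice:

```python
import jax
import jax.numpy as jnp

def running_sum(key, n):
  def body(i, total):
    # Derive this step's key from the fixed base key and the step index i.
    step_key = jax.random.fold_in(key, i)
    return total + jax.random.uniform(step_key)
  return jax.lax.fori_loop(0, n, body, jnp.float32(0.0))

key = jax.random.PRNGKey(0)
k0 = jax.random.fold_in(key, 0)
k1 = jax.random.fold_in(key, 1)
print(bool(jnp.all(k0 == k1)))  # False: different index, different key
print(bool(jnp.all(k0 == jax.random.fold_in(key, 0))))  # True: deterministic
print(running_sum(key, 5))
```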

    If you have a fixed number of iterations, it may be better to do all the splits up front in a single call; for example:

    for i, subkey in enumerate(jax.random.split(key, N)):
      values = jax.random.uniform(subkey, shape)
    
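The pre-split keys can also be fed to jax.lax.scan as the scanned input, so each iteration receives its own key without any carry. A minimal sketch; N, shape and the running-total `step` function are illustrative assumptions:

```python
import jax
import jax.numpy as jnp

N, shape = 4, (3,)
key = jax.random.PRNGKey(0)
keys = jax.random.split(key, N)  # all N per-iteration keys, made once
print(keys.shape)  # (4, 2) for the default uint32 keys

def step(total, step_key):
  # Each iteration gets its own pre-made key as the scanned input.
  values = jax.random.uniform(step_key, shape)
  return total + values.sum(), values

total, values = jax.lax.scan(step, jnp.float32(0.0), keys)
```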

    Or if your iterations do not have sequential dependence, it's better to use vmap to vectorize the operation:

    def f(key):
      return jax.random.uniform(key, shape)

    values = jax.vmap(f)(jax.random.split(key, N))  # shape: (N, *shape)
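Vectorizing with vmap does not change which numbers you get: mapping f over the split keys gives the same values as calling f on each key in a Python loop, just in one batched call. A minimal sketch with illustrative N and shape:

```python
import jax
import jax.numpy as jnp

N, shape = 4, (3,)
keys = jax.random.split(jax.random.PRNGKey(0), N)

def f(key):
  return jax.random.uniform(key, shape)

batched = jax.vmap(f)(keys)               # one vectorized call
looped = jnp.stack([f(k) for k in keys])  # same keys, Python loop
print(batched.shape)  # (4, 3)
print(bool(jnp.allclose(batched, looped)))  # True
```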