In JAX I find myself needing a PRNGKey that changes on each iteration of a loop, and I'm not sure of the best pattern. I've considered three options:
a) split
for i in range(N):
    rng, _ = jax.random.split(rng)
    # Alternatively:
    # rng = jax.random.split(rng, 1)[0]
b) fold_in
for i in range(N):
    rng = jax.random.fold_in(rng, i)
c) use the iteration index? This seems bad, since the key no longer depends on a prior key.
for i in range(N):
    rng = jax.random.PRNGKey(i)
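The concern with (c) can be made concrete with a small sketch (the helper name `draws_with_index_keys` is hypothetical): because each step's key is built from the loop index alone, the key passed in has no effect on the draws.

```python
import jax

def draws_with_index_keys(rng, n):
    # Option (c): the key for step i depends only on i,
    # so the incoming `rng` is never used.
    out = []
    for i in range(n):
        step_key = jax.random.PRNGKey(i)
        out.append(float(jax.random.uniform(step_key)))
    return out

a = draws_with_index_keys(jax.random.PRNGKey(0), 3)
b = draws_with_index_keys(jax.random.PRNGKey(123), 3)
print(a == b)  # True: the result ignores the seed passed in
```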
Which of these is the best pattern, and why? I'm leaning towards (b), as it maintains a dependency on the previous key (e.g. one passed in as an argument), but I'm not sure this is really the intended use case for jax.random.fold_in.
The JAX docs (including the PRNG design doc) recommend something similar to (a):
for i in range(N):
    key, subkey = jax.random.split(key)
    values = jax.random.uniform(subkey, shape)
    # key carries over to the next iteration
This is better than splitting and throwing away the subkey because the key consumed for sampling in each iteration (the subkey) is never also used to derive later keys, which keeps the streams in each iteration independent.
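A minimal runnable sketch of this split-and-carry pattern, assuming an arbitrary `shape` of `(2,)` and a hypothetical helper name `draw_per_iteration`: only `subkey` is spent on sampling, while `key` is carried forward and returned so the caller can keep using it.

```python
import jax

def draw_per_iteration(key, n, shape):
    """Carry `key` across iterations; spend only `subkey` on sampling."""
    results = []
    for _ in range(n):
        # A new carried key plus a fresh subkey for this iteration.
        key, subkey = jax.random.split(key)
        results.append(jax.random.uniform(subkey, shape))
    return results, key

values, final_key = draw_per_iteration(jax.random.PRNGKey(0), 4, (2,))
print(len(values), values[0].shape)  # 4 (2,)
```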
Your option (b) is also safe, and is in fact the pattern the developers had in mind when creating fold_in (see e.g. https://github.com/google/jax/discussions/12395).
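A minimal sketch of `fold_in` usage, assuming the common variant where each loop index is folded into a fixed base key (the question's version instead reassigns `rng` on every step); the helper name `draw_with_fold_in` is hypothetical. A side benefit of this variant is that the key for step `i` can be recomputed directly without replaying earlier steps.

```python
import jax

def draw_with_fold_in(base_key, n, shape):
    results = []
    for i in range(n):
        # Deterministically derive a distinct key for step i from base_key.
        step_key = jax.random.fold_in(base_key, i)
        results.append(jax.random.uniform(step_key, shape))
    return results

base = jax.random.PRNGKey(0)
first = draw_with_fold_in(base, 3, (2,))
again = draw_with_fold_in(base, 3, (2,))
# The same base key and indices reproduce the same draws.
print(all(bool((x == y).all()) for x, y in zip(first, again)))  # True
```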
If you have a fixed number of iterations, it may be better to do all the splits at once, for example:
for i, key in enumerate(jax.random.split(key, N)):
    values = jax.random.uniform(key, shape)
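A self-contained version of the split-once approach, assuming `N = 5` and `shape = (3,)` purely for illustration: a single call to `jax.random.split` produces all `N` keys, and the loop simply iterates over them.

```python
import jax

N, shape = 5, (3,)
key = jax.random.PRNGKey(0)

# Split once up front into N keys; with the legacy PRNGKey this is
# a uint32 array of shape (N, 2), one row per key.
keys = jax.random.split(key, N)

values = []
for k in keys:
    values.append(jax.random.uniform(k, shape))
```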
Or, if your iterations have no sequential dependence, it's better to use vmap to vectorize the operation:
def f(key):
    return jax.random.uniform(key, shape)

jax.vmap(f)(jax.random.split(key, N))
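A runnable sketch of the vmap version, assuming `shape = (3,)` and `N = 4` for illustration: `jax.vmap` maps `f` over the leading axis of the split keys, so the result stacks one sample per key along a new first axis.

```python
import jax

shape = (3,)

def f(key):
    # Draw one sample of `shape` from a single key.
    return jax.random.uniform(key, shape)

N = 4
keys = jax.random.split(jax.random.PRNGKey(0), N)
batched = jax.vmap(f)(keys)
print(batched.shape)  # (4, 3)
```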