I am writing a Markov chain Monte Carlo (MCMC) simulation in JAX that involves a long series of sampling steps. I currently rely on Haiku's `PRNGSequence` to do the pseudorandom number generator (PRNG) key bookkeeping:
```python
import haiku as hk

def step(key, context):
    key_seq = hk.PRNGSequence(key)
    x1 = sampler(next(key_seq), context_1)
    ...
    xn = other_sampler(next(key_seq), context_n)
```
Question:
Since Haiku has been discontinued, I am looking for an alternative to PRNGSequence.
I find the standard JAX approach:
```python
def step(key, context):
    key, subkey = jax.random.split(key)
    x1 = sampler(subkey, context_1)
    ...
    key, subkey = jax.random.split(key)
    xn = other_sampler(subkey, context_n)
```
unsatisfactory on two accounts:
Any suggestions how to mitigate these problems?
Thanks!
Hylke
If all you need is a simple class that locally handles splitting keys for you, why not define it yourself? You could create a suitable one in a few lines – for example:
```python
import jax

class PRNGSequence:
    def __init__(self, key):
        self._key = key

    def __next__(self):
        # Split the stored key: keep one half as the new internal state
        # and hand the other half out as a fresh subkey.
        self._key, key = jax.random.split(self._key)
        return key

def step(key):
    key_seq = PRNGSequence(key)
    print(jax.random.uniform(next(key_seq)))
    print(jax.random.uniform(next(key_seq)))

step(jax.random.PRNGKey(0))
# 0.10536897
# 0.2787192
```
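When the number of draws in a step is fixed and known when you write the code, a stateless alternative is to split the key once into all the subkeys the step needs and unpack them. A minimal sketch (the two-draw `step` here is illustrative, not taken from the question); because the keys are derived differently, it will generally not reproduce the values printed above:

```python
import jax


def step(key):
    # Split once into exactly as many keys as this step needs, then unpack.
    # Iterating over the result of split yields one key per row.
    k1, k2 = jax.random.split(key, 2)
    a = jax.random.uniform(k1)
    b = jax.random.uniform(k2)
    return a, b


a, b = step(jax.random.PRNGKey(0))
```

The trade-off is that adding a draw means updating the count passed to `split`, whereas the class above hands out keys on demand.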
As always, though, you have to be careful about this kind of hidden state when you're using JAX transformations like `jit`; see JAX Sharp Bits: Pure Functions for information on this.
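To make that caveat concrete, here is a sketch of one pattern that keeps the class usable under `jax.jit` and `jax.lax.scan`: build the sequence from the key that is passed *into* the traced function, and hand an unused key back out as part of the carry instead of keeping it in the object. The random-walk Metropolis sampler, the standard-normal target, the names `mh_step` and `run_chain`, and the chain length are all hypothetical choices for illustration, not part of the question:

```python
import jax
import jax.numpy as jnp


class PRNGSequence:
    """Same minimal stand-in for hk.PRNGSequence as above."""

    def __init__(self, key):
        self._key = key

    def __next__(self):
        self._key, key = jax.random.split(self._key)
        return key


def mh_step(carry, _):
    """One random-walk Metropolis step on a toy standard-normal target."""
    key, x = carry
    # The sequence is created from the traced carry key inside the step,
    # so its hidden state never outlives a single trace of this function.
    seq = PRNGSequence(key)
    proposal = x + jax.random.normal(next(seq))
    # log p(proposal) - log p(x) for the toy target p = N(0, 1)
    log_ratio = 0.5 * (x**2 - proposal**2)
    accept = jnp.log(jax.random.uniform(next(seq))) < log_ratio
    x = jnp.where(accept, proposal, x)
    # Return an unused key explicitly so the next step gets fresh randomness.
    return (next(seq), x), x


@jax.jit
def run_chain(key, x0):
    _, chain = jax.lax.scan(mh_step, (key, x0), None, length=5000)
    return chain


chain = run_chain(jax.random.PRNGKey(0), jnp.float32(0.0))
```

Because each `PRNGSequence` lives only for the duration of one trace of `mh_step`, the function stays pure from JAX's point of view. The thing to avoid is creating the sequence outside a jitted function and calling `next` on it inside: the key captured at trace time can then be baked into the compiled function, and the object can end up holding a leaked tracer.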