
Migration from haiku: Alternative to Haiku's PRNGSequence?


I am writing a Markov chain Monte Carlo (MCMC) simulation in JAX that involves a long series of sampling steps. I currently rely on Haiku's PRNGSequence for the pseudorandom number generator (PRNG) key bookkeeping:

import haiku as hk

def step(key, context):
  key_seq = hk.PRNGSequence(key)
  x1 = sampler(next(key_seq), context_1)
  ...
  xn = other_sampler(next(key_seq), context_n)

Question:

Since Haiku is no longer actively developed (it is in maintenance mode), I am looking for an alternative to PRNGSequence.

I find the standard JAX approach:

def step(key, context):
  key, subkey = jax.random.split(key)
  x1 = sampler(subkey, context_1)
  ...
  key, subkey = jax.random.split(key)
  xn = other_sampler(subkey, context_n)

unsatisfactory on two counts:

  • It is very error-prone: it is easy to slip up and reuse a key. This is especially problematic in MCMC simulations, which are sensitive to the resulting correlations, and such bugs are very difficult to track down.
  • It is bulky: splitting keys roughly doubles the size of my code.
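
To make the first point concrete: JAX keys are deterministic, so reusing one raises no error and silently returns the same numbers. A minimal sketch (the variable names are illustrative only):

```python
import jax

key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey)
# Bug: forgot to split again, so `subkey` is used a second time.
b = jax.random.normal(subkey)
print(a == b)  # True: the two supposedly independent draws are identical
```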

Any suggestions on how to mitigate these problems?

Thanks!

Hylke


Solution

  • If all you need is a simple class that handles key splitting for you locally, why not define it yourself? A suitable one takes only a few lines, for example:

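    import jax
    
    class PRNGSequence:
      """Minimal stand-in for hk.PRNGSequence: each next() returns a fresh subkey."""
      def __init__(self, key):
        self._key = key
      def __iter__(self):
        # Makes the object a proper Python iterator.
        return self
      def __next__(self):
        # Keep one half of the split as the new internal state, hand out the other.
        self._key, key = jax.random.split(self._key)
        return key
    
    def step(key):
      key_seq = PRNGSequence(key)
      print(jax.random.uniform(next(key_seq)))
      print(jax.random.uniform(next(key_seq)))
    
    step(jax.random.PRNGKey(0))
    # 0.10536897
    # 0.2787192
    # (exact values may differ between JAX versions)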
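
An equivalent, slightly shorter alternative to the class above is a plain Python generator function. This is a sketch, not part of the original answer; the name prng_sequence is hypothetical. It performs the same chain of splits as the class, so it should hand out the same subkeys:

```python
import jax


def prng_sequence(key):
    """Hypothetical helper: yield an endless stream of fresh subkeys from `key`."""
    while True:
        # Same bookkeeping as PRNGSequence.__next__: keep one half, yield the other.
        key, subkey = jax.random.split(key)
        yield subkey


def step(key):
    key_seq = prng_sequence(key)
    print(jax.random.uniform(next(key_seq)))
    print(jax.random.uniform(next(key_seq)))


step(jax.random.PRNGKey(0))
```

The generator keeps its state in a local variable rather than an attribute, but it is still hidden state, so the same caveats about JAX transformations apply.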
    

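    As always, though, you have to be careful about this kind of hidden state when you're using JAX transformations like jit. Creating the sequence inside the function you transform (as step does above) is fine. But if a PRNGSequence object lives outside a jitted function and you call next on it inside, the split only runs once, while tracing: every later call of the compiled function reuses the same keys, and the object is left holding a leaked tracer. See JAX Sharp Bits: Pure Functions for more information.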
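
For an MCMC loop, one jit-friendly way to use such a class is to create it inside the jitted step function and return the last unused key explicitly, so the caller threads it into the next step. The sketch below is an illustration under that assumption, not the original answer's code; the key property and the toy update in step are hypothetical additions:

```python
import jax
import jax.numpy as jnp


class PRNGSequence:
    """Hands out fresh subkeys; the latest unused key is exposed via `.key`."""

    def __init__(self, key):
        self._key = key

    def __iter__(self):
        return self

    def __next__(self):
        self._key, subkey = jax.random.split(self._key)
        return subkey

    @property
    def key(self):
        # Hypothetical addition: the not-yet-consumed key, to be returned
        # from the transformed function and passed to the next call.
        return self._key


@jax.jit
def step(key, x):
    # The sequence is created *inside* the traced function, so its mutable
    # state only holds values local to this trace and nothing leaks out.
    seq = PRNGSequence(key)
    x = x + jax.random.normal(next(seq))   # toy "proposal"
    x = x * jax.random.uniform(next(seq))  # toy "acceptance" factor
    # Thread the updated key out explicitly instead of keeping it in a global.
    return seq.key, x


key = jax.random.PRNGKey(0)
x = jnp.zeros(())
for _ in range(3):
    key, x = step(key, x)
    print(x)
```

Because all state flows through the arguments and return values, step stays a pure function from JAX's point of view, and calling it twice with the same key reproduces the same result.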