Tags: python, markov-chains, jax

Execute Markov chains with tree-structured state in parallel with JAX


I have a Markov chain function implemented in JAX that advances the chain from state s -> s' based on some training data (X_train).

def step(state: dict, key, X_train) -> dict:
    new_state = advance(state, key, X_train)
    return new_state

Here, state is a fairly complicated tree-structured dict of arrays that was generated by Haiku. For example,

state = {
    'layer1': {
        'weights': array(...),
        'bias': array(...),
    },
    'layer2': {
        'weights': array(...),
        'bias': array(...),
    },
}

I would like to run multiple Markov chains, with different states, in parallel. At first glance, the jax.vmap function looks like a good candidate. However, state is not an array but a (tree-structured) dict.

What is the best way to approach this?

Thanks!


Solution

  • Yes, you can use vmap with any pytree, but this is how you should construct it:

    states = {'layer1':{'weights':jnp.array([[1, -2, 3],
                                             [4, 5, 6]])},
              'layer2':{'weights':jnp.array([[1, .2, 3],
                                             [.4, 5, 6]])}}
    

    So the first chain will use weights [1, -2, 3] and [1, .2, 3] for layer1 and layer2 respectively (the second chain will use [4, 5, 6] and [.4, 5, 6]). The Markov chain steps themselves should be handled by jax.lax.scan, and you can use jit compilation to speed things up. Here is a trivial example; in each step the chain computes the following:

    X_new = log(relu(w1 @ X_old) @ w2) + e,  where e ~ Normal(0, 1)

    import jax
    import jax.numpy as jnp 
    from functools import partial
    
    @jax.jit
    def step(carry, k):
        # this function runs a single step in the chain
        # X_train dim: (3, 3)
        # per-chain w1 dim: (3,), per-chain w2 dim: (3,)
        # X_new = log(relu(w1 @ X_old) @ w2) + e
        # e ~ Normal(0, 1)
        
        state, X_train, rng = carry
        rng, rng_input = jax.random.split(rng)
        e = jax.random.normal(rng) # draw a standard-normal noise term for this step
        w1 = state['layer1']['weights'] # per-chain shape (3,), used as a row vector
        w2 = state['layer2']['weights'][None, :] # reshape to a (1, 3) row vector
        
        X_train = jax.nn.relu(w1 @ X_train)[:, None] + 1 # (3, 1) column (note the +1 offset)
        X_train = jnp.log(X_train @ w2) # (3, 1) @ (1, 3) -> (3, 3)
        X_train = X_train + e # add the noise term
        
        return [state, X_train, rng], e
    
    @partial(jax.jit, static_argnums = 3)
    def fi(state, X_train, key, number_of_steps):
        rng = jax.random.PRNGKey(key)
        carry = [state, X_train, rng]
        carry, random_normals = jax.lax.scan(step, carry, xs = jnp.arange(number_of_steps))
        state, X, rng = carry
        return X
    
    X_train = jnp.array([[1., -1., 0.5],
                         [1., 1, 2.],
                         [4, 2, 0.1]])
    
    states = {'layer1':{'weights':jnp.array([[1, -2, 3],
                                             [4, 5, 6]])},
              'layer2':{'weights':jnp.array([[1, .2, 3],
                                             [.4, 5, 6]])}}
    
    vmap_fi = jax.vmap(fi, (0, None, None, None)) # map only over axis 0 of the first argument
    
    key = 42 # random seed
    number_of_steps = 100 # chain runs 100 steps
    
    last_states = vmap_fi(states, X_train, key, number_of_steps)
    print(last_states)
    

    Output:

    [[[ 1.8478627   0.23842478  2.946475  ]
      [ 1.3278859  -0.28155205  2.4264982 ]
      [ 2.0921988   0.48276085  3.1908112 ]]
    
     [[ 2.9374144   5.4631433   5.645465  ]
      [ 3.4333894   5.959118    6.1414394 ]
      [ 3.4612248   5.9869533   6.169275  ]]]
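
    The two (3, 3) blocks above are the final X matrices of the two chains. As a quick sanity check (a sketch, not part of the original answer), you can slice one chain's leaves out of states with jax.tree_util.tree_map and run fi on it directly; the result should match the corresponding slice of the vmapped output, since key is shared across chains:

    # Sanity check: run each chain separately and compare with the vmapped result
    for i in range(2):
        state_i = jax.tree_util.tree_map(lambda leaf: leaf[i], states) # per-chain pytree
        X_i = fi(state_i, X_train, key, number_of_steps)
        print(jnp.allclose(last_states[i], X_i)) # expected: True for both chains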
    

    In this example, you could make the states dictionary as complicated as you like. You just need to stack each chain's leaves along the 0th axis so that vmap can map over them, as sketched below.
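
    If your real states come out of Haiku as one dict per chain, a simple way to build such a batched pytree (a sketch with hypothetical chain_0/chain_1 dicts, not part of the original answer) is to stack the corresponding leaves along a new leading axis with jax.tree_util.tree_map:

    import jax
    import jax.numpy as jnp
    
    # Hypothetical per-chain parameter dicts (e.g. two Haiku init results)
    chain_0 = {'layer1': {'weights': jnp.array([1., -2., 3.])},
               'layer2': {'weights': jnp.array([1., .2, 3.])}}
    chain_1 = {'layer1': {'weights': jnp.array([4., 5., 6.])},
               'layer2': {'weights': jnp.array([.4, 5., 6.])}}
    
    # Stack matching leaves -> every leaf gets shape (num_chains, ...),
    # which is exactly the layout vmap expects with in_axes=0.
    states = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), chain_0, chain_1)
    print(jax.tree_util.tree_map(jnp.shape, states))
    # {'layer1': {'weights': (2, 3)}, 'layer2': {'weights': (2, 3)}}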