I have a Markov chain function implemented in JAX that advances the chain from state s -> s' based on some training data (X_train).
def step(state: dict, key, X_train) -> dict:
    new_state = advance(state, key, X_train)
    return new_state
Here, state is a fairly complicated tree-structured dict of arrays that was generated by Haiku. For example,
state = {
    'layer1': {
        'weights': array(...),
        'bias': array(...),
    },
    'layer2': {
        'weights': array(...),
        'bias': array(...),
    },
}
I would like to run multiple Markov chains, with different states, in parallel. At first glance, the jax.vmap function looks like a good candidate. However, state is not an array but a (tree-structured) dict.
What is the best way to approach this?
Thanks!
Yes, you can use vmap with any pytree. But this is how you should construct it:
states = {'layer1': {'weights': jnp.array([[1, -2, 3],
                                           [4, 5, 6]])},
          'layer2': {'weights': jnp.array([[1, .2, 3],
                                           [.4, 5, 6]])}}
So in your first run, your weights will be [1, -2, 3] and [1, .2, 3] for layer1 and layer2 respectively (in the second run they will be [4, 5, 6] and [.4, 5, 6]).
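For intuition, here is a minimal, self-contained sketch (leaf_sum is an illustrative function, not part of your chain): jax.vmap with the default in_axes=0 maps over the leading axis of every leaf in the dict, so each call of leaf_sum sees one chain's (3,)-shaped weights.
import jax
import jax.numpy as jnp

def leaf_sum(state):
    # Per-chain function: each leaf here has shape (3,).
    return jnp.sum(state['layer1']['weights']) + jnp.sum(state['layer2']['weights'])

states = {'layer1': {'weights': jnp.array([[1., -2., 3.], [4., 5., 6.]])},
          'layer2': {'weights': jnp.array([[1., .2, 3.], [.4, 5., 6.]])}}

print(jax.vmap(leaf_sum)(states))  # one value per chain: approximately [6.2, 26.4]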
The Markov chain itself should be handled by jax.lax.scan, and you can use jit compilation to speed things up. Here is a trivial example; in each step the chain computes the following:
import jax
import jax.numpy as jnp
from functools import partial

@jax.jit
def step(carry, k):
    # This function runs a single step in the chain.
    # Inside vmap, each chain sees one row of the batched weights:
    #   X_train dim: (3, 3)
    #   w1 dim: (3,), used as a row vector in w1 @ X_train
    #   w2 dim: (1, 3) after adding a leading axis
    # X_new = log((relu(w1 @ X_old) + 1)[:, None] @ w2) + e,  e ~ Normal(0, 1)
    state, X_train, rng = carry
    rng, rng_input = jax.random.split(rng)
    e = jax.random.normal(rng_input)  # draw from the fresh subkey; carry rng forward
    w1 = state['layer1']['weights']
    w2 = state['layer2']['weights'][None, :]  # make it a (1, 3) row vector
    X_train = jax.nn.relu(w1 @ X_train)[:, None] + 1  # (3, 1) column
    X_train = jnp.log(X_train @ w2)  # outer product -> (3, 3)
    X_train = X_train + e
    return [state, X_train, rng], e
@partial(jax.jit, static_argnums=3)
def fi(state, X_train, key, number_of_steps):
    rng = jax.random.PRNGKey(key)
    carry = [state, X_train, rng]
    # scan threads the carry through number_of_steps iterations of step
    carry, random_normals = jax.lax.scan(step, carry, xs=jnp.arange(number_of_steps))
    state, X, rng = carry
    return X
X_train = jnp.array([[1., -1., 0.5],
                     [1., 1., 2.],
                     [4., 2., 0.1]])
states = {'layer1': {'weights': jnp.array([[1, -2, 3],
                                           [4, 5, 6]])},
          'layer2': {'weights': jnp.array([[1, .2, 3],
                                           [.4, 5, 6]])}}
vmap_fi = jax.vmap(fi, (0, None, None, None))  # map only over axis 0 of the first argument
key = 42  # random seed
number_of_steps = 100  # chain runs 100 steps
last_X = vmap_fi(states, X_train, key, number_of_steps)
print(last_X)
Output:
[[[ 1.8478627 0.23842478 2.946475 ]
[ 1.3278859 -0.28155205 2.4264982 ]
[ 2.0921988 0.48276085 3.1908112 ]]
[[ 2.9374144 5.4631433 5.645465 ]
[ 3.4333894 5.959118 6.1414394 ]
[ 3.4612248 5.9869533 6.169275 ]]]
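As an aside, jax.lax.scan also returns the per-step outputs stacked along a new leading axis (the second value returned by step, here the noise draws e), which fi discards. If you want them too, a small variant could return both; fi_with_noise below is a hypothetical name, not part of the original code:
@partial(jax.jit, static_argnums=3)
def fi_with_noise(state, X_train, key, number_of_steps):
    # Same as fi, but also return the noise sequence stacked by scan.
    rng = jax.random.PRNGKey(key)
    carry, random_normals = jax.lax.scan(step, [state, X_train, rng],
                                         xs=jnp.arange(number_of_steps))
    state, X, rng = carry
    return X, random_normals  # shape (number_of_steps,)
Mapped with jax.vmap(fi_with_noise, (0, None, None, None)), this yields X of shape (2, 3, 3) and noise of shape (2, 100) for the two chains.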
In this example, you could make the states dictionaries more complicated; you just need to parallelize over their 0th axis.
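For example, if you start from individual chain states (chain_a and chain_b below are hypothetical, using the question's layout with biases), you can build the batched pytree by stacking matching leaves so that every leaf gains a leading chain axis:
import jax
import jax.numpy as jnp

chain_a = {'layer1': {'weights': jnp.array([1., -2., 3.]), 'bias': jnp.array([0.1])},
           'layer2': {'weights': jnp.array([1., .2, 3.]), 'bias': jnp.array([0.2])}}
chain_b = {'layer1': {'weights': jnp.array([4., 5., 6.]), 'bias': jnp.array([0.3])},
           'layer2': {'weights': jnp.array([.4, 5., 6.]), 'bias': jnp.array([0.4])}}

# Stack matching leaves so each gains a leading axis of size 2 (one per chain).
states = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), chain_a, chain_b)
print(jax.tree_util.tree_map(jnp.shape, states))
# {'layer1': {'bias': (2, 1), 'weights': (2, 3)}, 'layer2': {'bias': (2, 1), 'weights': (2, 3)}}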