I am setting up a Deep MARL framework and I need to assess my actor policies. Ideally, this would entail using jax.vmap over a tuple of actor flax TrainStates. I have tried the following:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
import optax
import distrax
class PGActor_1(nn.Module):
    @nn.compact
    def __call__(self, x):
        action_dim = 4
        activation = nn.tanh
        actor_mean = nn.Dense(128, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0))(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0))(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)
        return pi
class PGActor_2(nn.Module):
    @nn.compact
    def __call__(self, x):
        action_dim = 2
        activation = nn.tanh
        # The first layer takes the input x (actor_mean is not defined yet at this point).
        actor_mean = nn.Dense(64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0))(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)
        return pi
state = jnp.zeros((1, 5))
network_1 = PGActor_1()
network_1_init_rng = jax.random.PRNGKey(42)
params_1 = network_1.init(network_1_init_rng, state)
network_2 = PGActor_2()
network_2_init_rng = jax.random.PRNGKey(42)
params_2 = network_2.init(network_2_init_rng, state)
tx = optax.chain(
    optax.clip_by_global_norm(1),
    optax.adam(learning_rate=1e-3)
)
actor_trainstates = (
    TrainState.create(apply_fn=network_1.apply, tx=tx, params=params_1),
    TrainState.create(apply_fn=network_2.apply, tx=tx, params=params_2)
)
pis = jax.vmap(lambda x: x.apply_fn(x.params, state))(actor_trainstates)
but I receive the following error:
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
Does anybody have any idea how to make this work?
Thank you in advance.
This is quite similar to other questions (e.g. Jax - vmap over batch of dataclasses). The key point is that JAX transformations like vmap require data in a struct-of-arrays pattern, whereas you are using an array-of-structs pattern.
To work directly with an array-of-structs pattern in JAX, you can use Python's built-in map function; due to JAX's asynchronous dispatch, the resulting operations will be executed in parallel where possible:
pis = list(map(lambda x: x.apply_fn(x.params, state), actor_trainstates))
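As a rough, self-contained illustration of this array-of-structs approach, here is a sketch that uses plain JAX parameter lists instead of Flax TrainStates. The helpers init_mlp and apply_mlp are hypothetical stand-ins for the two actors, not part of any library. Since Python's map returns a lazy iterator, the sketch wraps it in list(...) so the computations actually run:

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical helper: build a list of (weight, bias) pairs for an MLP."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [
        (jax.random.normal(k, (n_in, n_out)) * 0.1, jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def apply_mlp(params, x):
    """Hypothetical helper: tanh hidden layers followed by a linear output layer."""
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


state = jnp.zeros((1, 5))
key_1, key_2 = jax.random.split(jax.random.PRNGKey(0))

# Two actors with different depth and output size, so they cannot be stacked.
actors = (
    init_mlp(key_1, [5, 128, 64, 4]),
    init_mlp(key_2, [5, 64, 2]),
)

# One Python-level call per actor; list() forces the lazy map to run.
logits = list(map(lambda p: apply_mlp(p, state), actors))
shapes = [out.shape for out in logits]
```

Each actor is evaluated by its own call, so the outputs are free to have different shapes.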
However, this doesn't take advantage of the automatic vectorization done by vmap. In order to do this, you can convert your data from an array of structs to a struct of arrays, although this requires that all entries have the same structure.
For compatible cases, the solution would look something like this; however, it errors for your data:
train_states_soa = jax.tree.map(lambda *args: jnp.stack(args), *actor_trainstates)
pis = jax.vmap(lambda x: x.apply_fn(x.params, state))(train_states_soa)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-36-da904fa40b9c> in <cell line: 0>()
----> 1 train_states_soa = jax.tree.map(lambda *args: jnp.stack(args), *actor_trainstates)
ValueError: Dict key mismatch; expected keys: ['Dense_0', 'Dense_1', 'Dense_2']
The problem is that your two train states do not have matching structure, and so they cannot be transformed into a single struct of arrays. You can see the difference in structure by inspecting the params:
print(actor_trainstates[0].params['params'].keys()) # dict_keys(['Dense_0', 'Dense_1', 'Dense_2'])
print(actor_trainstates[1].params['params'].keys()) # dict_keys(['Dense_0', 'Dense_1'])
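If you would rather check this programmatically than by eye, something like the following sketch might help. It compares pytree structure and leaf shapes using jax.tree.structure and jax.tree.leaves. The can_stack helper and the two parameter dictionaries are hypothetical stand-ins that only mimic the layer names and sizes from the question:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins that mimic the layer names of the two actors.
params_a = {
    "Dense_0": jnp.zeros((5, 128)),
    "Dense_1": jnp.zeros((128, 64)),
    "Dense_2": jnp.zeros((64, 4)),
}
params_b = {
    "Dense_0": jnp.zeros((5, 64)),
    "Dense_1": jnp.zeros((64, 2)),
}


def can_stack(tree_a, tree_b):
    """Return True if both pytrees share the same structure and leaf shapes."""
    if jax.tree.structure(tree_a) != jax.tree.structure(tree_b):
        return False
    shapes_a = [leaf.shape for leaf in jax.tree.leaves(tree_a)]
    shapes_b = [leaf.shape for leaf in jax.tree.leaves(tree_b)]
    return shapes_a == shapes_b


mismatched = can_stack(params_a, params_b)  # different keys, so not stackable
matched = can_stack(params_a, params_a)     # identical trees are stackable
```

Stacking with jnp.stack needs both conditions: the same tree structure and the same shape for every corresponding leaf.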
There is no way to use vmap in a context where your inputs have different structure, so you'll either have to change the problem to ensure the same structure, or stick with the map approach.
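For the case where you can give every actor the same architecture, here is a minimal sketch of the struct-of-arrays route in plain JAX. The init_actor and apply_actor helpers are hypothetical and return raw logits rather than a distrax distribution. The sketch also maps only over the parameters and not over a whole TrainState, and uses an input of ones purely so the outputs are non-trivial:

```python
import jax
import jax.numpy as jnp


def init_actor(key, in_dim=5, hidden=64, action_dim=4):
    """Hypothetical helper: parameters for a two-layer actor."""
    k1, k2 = jax.random.split(key)
    return {
        "Dense_0": {
            "kernel": jax.random.normal(k1, (in_dim, hidden)) * 0.1,
            "bias": jnp.zeros(hidden),
        },
        "Dense_1": {
            "kernel": jax.random.normal(k2, (hidden, action_dim)) * 0.1,
            "bias": jnp.zeros(action_dim),
        },
    }


def apply_actor(params, x):
    """Hypothetical helper: tanh hidden layer, then linear logits."""
    h = jnp.tanh(x @ params["Dense_0"]["kernel"] + params["Dense_0"]["bias"])
    return h @ params["Dense_1"]["kernel"] + params["Dense_1"]["bias"]


state = jnp.ones((1, 5))
keys = jax.random.split(jax.random.PRNGKey(0), 2)

# Array of structs: one parameter tree per actor, all with the same structure.
actor_params = tuple(init_actor(k) for k in keys)

# Struct of arrays: stack matching leaves along a new leading "actor" axis.
params_soa = jax.tree.map(lambda *leaves: jnp.stack(leaves), *actor_params)

# Map over the actor axis of the parameters; the state is shared (in_axes=None).
logits = jax.vmap(apply_actor, in_axes=(0, None))(params_soa, state)
```

Here logits has a leading axis of size 2, one slice per actor. Actors with different action dimensions would not fit this pattern as written.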