Tags: reinforcement-learning, jax, flax, multi-agent-reinforcement-learning

How to use jax.vmap with a tuple of flax TrainStates as input?


I am setting up a deep multi-agent reinforcement learning (MARL) framework and need to evaluate my actor policies. Ideally, this would entail using jax.vmap over a tuple of actor Flax TrainStates. I have tried the following:

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
import optax
import distrax

class PGActor_1(nn.Module):
    # Categorical policy: two tanh hidden layers (128, 64) and 4 discrete actions.

    @nn.compact
    def __call__(self, x):
        action_dim = 4
        activation = nn.tanh

        actor_mean = nn.Dense(128, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0))(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0))(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        return pi

class PGActor_2(nn.Module):
    # Categorical policy: one tanh hidden layer (64) and 2 discrete actions.

    @nn.compact
    def __call__(self, x):
        action_dim = 2
        activation = nn.tanh

        # The first layer must consume the input x (actor_mean is not defined yet).
        actor_mean = nn.Dense(64, kernel_init=orthogonal(jnp.sqrt(2)), bias_init=constant(0.0))(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(actor_mean)
        pi = distrax.Categorical(logits=actor_mean)

        return pi

state = jnp.zeros((1, 5))  # dummy batch containing one 5-dimensional observation

network_1 = PGActor_1()
network_1_init_rng = jax.random.PRNGKey(42)
params_1 = network_1.init(network_1_init_rng, state)

network_2 = PGActor_2()
network_2_init_rng = jax.random.PRNGKey(42)
params_2 = network_2.init(network_2_init_rng, state)

tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(learning_rate=1e-3),  # optax.adam takes learning_rate, not lr
)
actor_trainstates = (
    TrainState.create(apply_fn=network_1.apply, tx=tx, params=params_1),
    TrainState.create(apply_fn=network_2.apply, tx=tx, params=params_2),  # params_2 belong to network_2
)
pis = jax.vmap(lambda x: x.apply_fn(x.params, state))(actor_trainstates)

but I receive the following error:

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

Does anybody have any idea how to make this work?

Thank you in advance.


Solution

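  • This is quite similar to other questions (e.g. Jax - vmap over batch of dataclasses). The key point is that JAX transformations like vmap require data in a struct-of-arrays pattern, whereas you are using an array-of-structs pattern. vmap maps over the leading axis of every array leaf in its input, so your tuple of TrainStates is treated as one pytree, and scalar leaves such as each state's step counter have no axis to map over. That is exactly the rank-0 error you see.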
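To make the array-of-structs versus struct-of-arrays distinction concrete, here is a minimal pure-JAX sketch. The dicts in `aos` are hypothetical stand-ins for TrainStates (the scalar `"step"` leaf plays the role of TrainState's step counter), and `total` is an illustrative function; no Flax is needed to reproduce the behaviour.

```python
import jax
import jax.numpy as jnp

# Array of structs: a tuple of two dicts with identical structure.
# The scalar "step" leaf plays the role of TrainState's step counter.
aos = (
    {"w": jnp.ones(3), "step": jnp.array(0)},
    {"w": 2.0 * jnp.ones(3), "step": jnp.array(0)},
)


def total(s):
    """Sum of the "w" entries of one struct."""
    return s["w"].sum()


# vmap flattens the whole tuple and tries to map axis 0 of every leaf;
# the rank-0 "step" leaves make it raise a ValueError.
try:
    jax.vmap(total)(aos)
    aos_error = None
except ValueError as err:
    aos_error = str(err)
print(aos_error)

# Struct of arrays: one dict whose leaves carry a leading axis of size 2.
soa = jax.tree.map(lambda *leaves: jnp.stack(leaves), *aos)
print(soa["w"].shape, soa["step"].shape)  # (2, 3) (2,)

# Now vmap has a leading axis on every leaf and maps total over both structs.
sums = jax.vmap(total)(soa)
print(sums)  # [3. 6.]
```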

    To work directly with an array of structs pattern in JAX, you can use Python's built-in map function (wrapped in list, because map returns a lazy iterator in Python 3). Due to JAX's asynchronous dispatch, Python does not block on each call, so the resulting device computations can overlap where possible:

    pis = list(map(lambda x: x.apply_fn(x.params, state), actor_trainstates))
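A self-contained sketch of the map route without a Flax dependency: `Actor`, `init_mlp` and `apply_mlp` are hypothetical stand-ins for TrainState and the two Flax modules (a 3-layer and a 2-layer tanh MLP that return logits rather than distrax distributions). Because each actor is evaluated separately from Python, the actors are free to have different parameter structures and output sizes.

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class Actor(NamedTuple):
    """Hypothetical stand-in for a TrainState: an apply function plus its params."""
    apply_fn: Callable
    params: dict


def init_mlp(key, sizes):
    """Build {"Dense_i": {"kernel", "bias"}} params for consecutive layer sizes."""
    params = {}
    keys = jax.random.split(key, len(sizes) - 1)
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:])):
        params[f"Dense_{i}"] = {
            "kernel": 0.1 * jax.random.normal(k, (n_in, n_out)),
            "bias": jnp.zeros(n_out),
        }
    return params


def apply_mlp(params, x):
    """Tanh MLP that returns action logits (no activation on the last layer)."""
    n_layers = len(params)
    for i in range(n_layers):
        layer = params[f"Dense_{i}"]
        x = x @ layer["kernel"] + layer["bias"]
        if i < n_layers - 1:
            x = jnp.tanh(x)
    return x


state = jnp.zeros((1, 5))
key_1, key_2 = jax.random.split(jax.random.PRNGKey(42))

# Differently structured actors, mirroring PGActor_1 (3 layers) and PGActor_2 (2 layers).
actors = (
    Actor(apply_mlp, init_mlp(key_1, [5, 128, 64, 4])),
    Actor(apply_mlp, init_mlp(key_2, [5, 64, 2])),
)

# map is lazy in Python 3; list() forces every actor to be evaluated.
logits = list(map(lambda a: a.apply_fn(a.params, state), actors))
print([out.shape for out in logits])  # [(1, 4), (1, 2)]
```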
    

    However, this doesn't take advantage of the automatic vectorization done by vmap. To get that, you can convert your data from an array of structs to a struct of arrays, although this requires that all entries have the same pytree structure and matching leaf shapes.

    For compatible cases, the solution would look something like this; however, it errors for your data:

    train_states_soa = jax.tree.map(lambda *args: jnp.stack(args), *actor_trainstates)
    pis = jax.vmap(lambda x: x.apply_fn(x.params, state))(train_states_soa)
    
    ---------------------------------------------------------------------------
    ValueError                                Traceback (most recent call last)
    <ipython-input-36-da904fa40b9c> in <cell line: 0>()
    ----> 1 train_states_soa = jax.tree.map(lambda *args: jnp.stack(args), *actor_trainstates)
    
    ValueError: Dict key mismatch; expected keys: ['Dense_0', 'Dense_1', 'Dense_2']
    

    The problem is that your two train states do not have matching structure, and so they cannot be transformed into a single struct of arrays. You can see the difference in structure by inspecting the params:

    print(actor_trainstates[0].params['params'].keys())  # dict_keys(['Dense_0', 'Dense_1', 'Dense_2'])
    print(actor_trainstates[1].params['params'].keys())  # dict_keys(['Dense_0', 'Dense_1'])
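Rather than waiting for `jax.tree.map` to fail, compatibility can also be checked up front. A small sketch, assuming a hypothetical `stackable` helper built on `jax.tree.structure` and `jax.tree.leaves`; the zero-filled params (from an illustrative `dense` helper) only mimic the layout of the Flax params above. A matching tree structure alone is not enough for `jnp.stack`, because the leaf shapes must agree as well.

```python
import jax
import jax.numpy as jnp


def dense(n_in, n_out):
    """Zero-filled Dense layer params with Flax-style names (illustration only)."""
    return {"kernel": jnp.zeros((n_in, n_out)), "bias": jnp.zeros(n_out)}


def stackable(*trees):
    """True if all trees share one pytree structure and identical leaf shapes."""
    first_struct = jax.tree.structure(trees[0])
    if any(jax.tree.structure(t) != first_struct for t in trees[1:]):
        return False
    first_shapes = [jnp.shape(leaf) for leaf in jax.tree.leaves(trees[0])]
    return all(
        [jnp.shape(leaf) for leaf in jax.tree.leaves(t)] == first_shapes
        for t in trees[1:]
    )


params_1 = {"params": {"Dense_0": dense(5, 128), "Dense_1": dense(128, 64), "Dense_2": dense(64, 4)}}
params_2 = {"params": {"Dense_0": dense(5, 64), "Dense_1": dense(64, 2)}}
params_3 = {"params": {"Dense_0": dense(5, 32), "Dense_1": dense(32, 2)}}

print(stackable(params_1, params_2))  # False: different keys, as in the question
print(stackable(params_2, params_3))  # False: same keys, different shapes
print(stackable(params_1, params_1))  # True
```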
    

    There is no way to use vmap when your inputs have different structures, so you'll either have to change the problem so that all actors share the same parameter structure, or stick with the map approach.
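When every actor does share one architecture, the stacking approach works. A minimal pure-JAX sketch, assuming hypothetical `init_mlp`/`apply_mlp` helpers in place of the Flax modules: per-agent params are stacked with the same `jax.tree.map(lambda *args: jnp.stack(args), ...)` idiom, and a single `jax.vmap` call evaluates all agents, with `in_axes=(0, None)` mapping over the params while sharing one observation.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Build {"Dense_i": {"kernel", "bias"}} params for consecutive layer sizes."""
    params = {}
    keys = jax.random.split(key, len(sizes) - 1)
    for i, (k, n_in, n_out) in enumerate(zip(keys, sizes[:-1], sizes[1:])):
        params[f"Dense_{i}"] = {
            "kernel": 0.1 * jax.random.normal(k, (n_in, n_out)),
            "bias": jnp.zeros(n_out),
        }
    return params


def apply_mlp(params, x):
    """Tanh MLP that returns action logits (no activation on the last layer)."""
    n_layers = len(params)
    for i in range(n_layers):
        layer = params[f"Dense_{i}"]
        x = x @ layer["kernel"] + layer["bias"]
        if i < n_layers - 1:
            x = jnp.tanh(x)
    return x


# A random observation (rather than zeros) so the comparison below is meaningful.
obs = jax.random.normal(jax.random.PRNGKey(1), (1, 5))
n_agents = 3
agent_keys = jax.random.split(jax.random.PRNGKey(0), n_agents)

# Array of structs: one params dict per agent, all with the same structure and shapes.
per_agent = [init_mlp(k, [5, 64, 4]) for k in agent_keys]

# Struct of arrays: each leaf gains a leading agent axis, e.g. kernel (5, 64) -> (3, 5, 64).
stacked = jax.tree.map(lambda *leaves: jnp.stack(leaves), *per_agent)

# Map over axis 0 of the params only; the observation is broadcast to every agent.
batched_logits = jax.vmap(apply_mlp, in_axes=(0, None))(stacked, obs)
print(batched_logits.shape)  # (3, 1, 4)

# Sanity check: the vmapped result agrees with evaluating each agent in a Python loop.
looped = jnp.stack([apply_mlp(p, obs) for p in per_agent])
matches = jnp.allclose(batched_logits, looped, atol=1e-5)
```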