
Nested vmap in pmap - JAX


I currently can run simulations in parallel on one GPU using vmap. To speed things up, I want to batch the simulations over multiple GPU devices using pmap. However, when pmapping the vmapped function I get a tracing error.

The code I use to get a trajectory state is:

traj_state = vmap(run_trajectory, in_axes=(0, None, 0))(sim_state, timings, lambda_array)
                                                                        

where run_trajectory runs a single simulation and each entry of lambda_array parameterises one simulation. I then try to nest this inside a pmap:

pmap(vmap(run_trajectory, in_axes=(0, None, 0)), in_axes=(0, None, 0))(reshaped_sim_state, timings, reshaped_lambda_array)

In doing so I get the error:

While tracing the function run_trajectory for pmap, this concrete value was not available in Python because it depends on the value of the argument 'timings'.

I'm quite new to JAX, and although there is documentation on errors with traced values, I'm not sure how to navigate this problem.


Solution

  • vmap and pmap treat in_axes=None differently. In vmap, an input with in_axes=None is unmapped and left untraced (static), so the function receives the value you passed in. pmap also compiles the function, as jit does, so an input with in_axes=None is unmapped but still traced:

    from jax import vmap, pmap
    import jax.numpy as jnp
    
    def f(x, condition):
      # requires untraced condition:
      return x if condition else x + 1
    
    x = jnp.arange(4)
    vmap(f, in_axes=(0, None))(x, True)
    # Array([0, 1, 2, 3], dtype=int32)
    
    pmap(f, in_axes=(0, None))(x, True)
    # ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: 
    

    To keep a variable untraced in pmap, bind it to the function ahead of time (for example with functools.partial) instead of passing it as an argument:

    from functools import partial
    
    vmap(partial(f, condition=True), in_axes=0)(x)
    # Array([0, 1, 2, 3], dtype=int32)
    
    # note: pmap needs at least as many devices as the mapped axis size (4 here)
    pmap(partial(f, condition=True), in_axes=0)(x)
    # Array([0, 1, 2, 3], dtype=int32)
    

    In your case, applying this solution might look like this:

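    # timings is bound as a default argument, so it never passes through
    # vmap or pmap as an input and therefore stays untraced
    def run(sim_state, lambda_array, timings=timings):
      return run_trajectory(sim_state, timings, lambda_array)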
    
    vmap(run)(sim_state, lambda_array)
    
    pmap(vmap(run))(reshaped_sim_state, reshaped_lambda_array)
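
The closure approach above can be tried end to end with a small stand-in simulation. This is only a sketch under stated assumptions: the body of run_trajectory here is hypothetical (the real one is not shown in the question), timings is assumed to be a NumPy array of time points whose last entry sets the number of steps, and batch_per_device is an illustrative name. The inputs are sized from jax.local_device_count(), so it should also run on a single device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import pmap, vmap

# Assumed: timings is a small NumPy array of time points (hypothetical data).
timings = np.array([0.0, 1.0, 2.0, 3.0])


def run_trajectory(sim_state, timings, lambda_value):
    # Hypothetical stand-in for the real simulation. int(...) needs a
    # concrete value, so this line fails if `timings` is a traced input.
    n_steps = int(timings[-1])
    for _ in range(n_steps):
        sim_state = sim_state + lambda_value
    return sim_state


def run(sim_state, lambda_value):
    # timings is closed over, so it is never an input to vmap or pmap.
    return run_trajectory(sim_state, timings, lambda_value)


n_devices = jax.local_device_count()
batch_per_device = 2  # illustrative choice
n = n_devices * batch_per_device

sim_state = jnp.zeros((n,))
lambda_array = jnp.arange(n, dtype=jnp.float32)

# Leading axis is split across devices, second axis is vmapped per device.
reshaped_sim_state = sim_state.reshape(n_devices, batch_per_device)
reshaped_lambda_array = lambda_array.reshape(n_devices, batch_per_device)

traj_state = pmap(vmap(run))(reshaped_sim_state, reshaped_lambda_array)
traj_state_flat = traj_state.reshape(n)

# Single-device reference using vmap only.
reference = vmap(run)(sim_state, lambda_array)
```

The reshape puts the device axis first because pmap maps over the leading axis by default, and that axis must not be larger than the number of available devices.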