I can currently run simulations in parallel on a single GPU using vmap. To speed things up, I want to batch the simulations across multiple GPU devices using pmap. However, when pmapping the vmapped function I get a tracing error.
The code I use to get a trajectory state is:
traj_state = vmap(run_trajectory, in_axes=(0, None, 0))(sim_state, timings, lambda_array)
where lambda_array parameterises each simulation and run_trajectory runs a single simulation. I then try to nest this inside a pmap:
pmap(vmap(run_trajectory, in_axes=(0, None, 0)), in_axes=(0, None, 0))(reshaped_sim_state, timings, reshaped_lambda_array)
In doing so I get the error:
While tracing the function run_trajectory for pmap, this concrete value was not available in Python because it depends on the value of the argument 'timings'.
I'm quite new to JAX, and although there is documentation on errors with traced values, I'm not sure how to navigate this problem.
vmap and pmap have slightly different APIs when it comes to in_axes. In vmap, setting in_axes=None causes inputs to be unmapped and static (i.e. un-traced), while in pmap even inputs with in_axes=None will be unmapped but still traced:
from jax import vmap, pmap
import jax.numpy as jnp
def f(x, condition):
    # requires an untraced condition:
    return x if condition else x + 1
x = jnp.arange(4)
vmap(f, in_axes=(0, None))(x, True)
# Array([0, 1, 2, 3], dtype=int32)
pmap(f, in_axes=(0, None))(x, True)
# ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected:
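Note that an unmapped-but-traced argument is only a problem when the function needs its concrete Python value (as in the if statement above); if the argument is used purely in array operations, in_axes=None works fine under pmap. A minimal sketch with a hypothetical function g, sizing the mapped axis to the local device count so it fits the available devices:
import jax

def g(x, offset):
    # offset is traced but only used in array arithmetic,
    # so its concrete Python value is never needed
    return x + offset

xs = jnp.arange(jax.local_device_count())
pmap(g, in_axes=(0, None))(xs, 10.0)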
To ensure that your variable is untraced in pmap, you can partially evaluate the function; for example:
from functools import partial
vmap(partial(f, condition=True), in_axes=0)(x)
# Array([0, 1, 2, 3], dtype=int32)
pmap(partial(f, condition=True), in_axes=0)(x)
# Array([0, 1, 2, 3], dtype=int32)
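If you'd rather not wrap the function, pmap also accepts static_broadcasted_argnums, which marks a positional argument as a compile-time constant (and therefore un-traced). A sketch reusing f from above, with x sized to the local device count:
import jax

x = jnp.arange(jax.local_device_count())
# argument 1 (condition) is treated as a static compile-time constant
pmap(f, in_axes=0, static_broadcasted_argnums=1)(x, True)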
In your case, applying this solution might look like this:
def run(sim_state, lambda_array, timings=timings):
    # timings is bound as a default argument, so it stays a static Python value
    return run_trajectory(sim_state, timings, lambda_array)

vmap(run)(sim_state, lambda_array)
pmap(vmap(run))(reshaped_sim_state, reshaped_lambda_array)
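As for building reshaped_sim_state and reshaped_lambda_array, the usual pattern is to split the leading batch axis into (num_devices, batch_per_device). A sketch, assuming the batch size divides evenly by the device count and that sim_state is a pytree of arrays with a leading batch dimension:
import jax

n_dev = jax.local_device_count()

# (batch, ...) -> (n_dev, batch // n_dev, ...)
reshaped_lambda_array = lambda_array.reshape(n_dev, -1, *lambda_array.shape[1:])
reshaped_sim_state = jax.tree_util.tree_map(
    lambda x: x.reshape(n_dev, -1, *x.shape[1:]), sim_state)

traj_state = pmap(vmap(run))(reshaped_sim_state, reshaped_lambda_array)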