I have the following example code that works with Python's built-in map:
import jax
import jax.numpy as jnp

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]
list(map(f, zip(xs, ys)))
# returns:
[DeviceArray(0., dtype=float32),
DeviceArray(0., dtype=float32),
DeviceArray(0., dtype=float32),
DeviceArray(0., dtype=float32)]
How can I use jax.vmap instead? The naive attempt is:
jax.vmap(f)(zip(xs, ys))
but this gives:
ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
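A short sketch of why the call fails: jax.vmap flattens its argument as a pytree, and a zip object is not a registered pytree container, so JAX treats the whole iterator as one opaque leaf with no array axis to map over (hence "rank 0, shape ()"). The variable names below are only for illustration.

```python
import jax
import jax.numpy as jnp

xs = [jnp.zeros(3) for _ in range(4)]
ys = [jnp.zeros(2) for _ in range(4)]

# A zip object is a lazy iterator, not a pytree node JAX knows about,
# so tree flattening returns the zip object itself as a single leaf.
leaves = jax.tree_util.tree_leaves(zip(xs, ys))
print(len(leaves))  # 1

# Materializing it gives a list of tuples, which IS a pytree, but its leaves
# are the individual per-example arrays. vmap would then try to map over
# their own first axes (sizes 3 and 2), not over the 4 examples.
materialized = jax.tree_util.tree_leaves(list(zip(xs, ys)))
print(len(materialized))  # 8
```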
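To use jax.vmap, you do not need to zip your variables. jax.vmap maps over an axis of arrays (axis 0 by default), not over Python iterables, so store each input as a single array whose first dimension is the batch and pass them together as a tuple. vmap then maps over axis 0 of every array inside the tuple: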
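import jax.numpy as jnp
from jax import vmap

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

# One array per input, with the batch of 4 examples along axis 0.
xs = jnp.zeros((4, 3))
ys = jnp.zeros((4, 2))
# vmap maps over axis 0 of both arrays inside the tuple.
vmap(f)((xs, ys))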
Output:
DeviceArray([0., 0., 0., 0.], dtype=float32)
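If the inputs already exist as Python lists of equally shaped arrays, as in the question, a minimal sketch is to stack each list with jnp.stack and compare against the plain map result. The non-zero example data and the two-argument variant g are assumptions added for illustration; in_axes=(0, None) is the standard vmap option for reusing one unbatched argument across every example.

```python
import jax
import jax.numpy as jnp

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

# Hypothetical non-zero data so the comparison is meaningful.
xs = [jnp.arange(3.0) + i for i in range(4)]
ys = [jnp.arange(2.0) * i for i in range(4)]

# Stack each list along a new leading axis: shapes (4, 3) and (4, 2).
xs_batched = jnp.stack(xs)
ys_batched = jnp.stack(ys)

batched = jax.vmap(f)((xs_batched, ys_batched))
looped = jnp.array(list(map(f, zip(xs, ys))))

# Hypothetical two-argument version: vmap maps axis 0 of each positional arg.
def g(x, y):
    return x.sum() + y.sum()

two_arg = jax.vmap(g)(xs_batched, ys_batched)

# in_axes=(0, None): map over xs, but pass the same unbatched y to every call.
shared_y = jax.vmap(g, in_axes=(0, None))(xs_batched, ys[1])
```

Passing a tuple and using separate positional arguments are equivalent here; the two-argument form mainly makes it easier to batch some inputs and broadcast others via in_axes.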