I have found that `vmap` in JAX does not behave as expected when applied to multiple arguments. For example, consider the function below:
```python
def f1(x, y, z):
    f = x[:, None, None] * z[None, None, :] + y[None, :, None]
    return f
```
For `x = jnp.arange(7)`, `y = jnp.arange(5)`, `z = jnp.arange(3)`, the output of this function has shape `(7, 5, 3)`. However, for the `vmap` version below:
```python
from functools import partial
from jax import vmap

@partial(vmap, in_axes=(None, 0, 0), out_axes=(1, 2))
def f2(x, y, z):
    f = x*z + y
    return f
```
calling it with the same inputs raises this error:
```
ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 5: axis 0 of argument y of type int32[5];
  * one axis had size 3: axis 0 of argument z of type int32[3]
```
Could someone kindly explain what's behind this error?
The semantics of `vmap` are that it performs a single batching operation along one or more arrays. When you specify `in_axes=(None, 0, 0)`, the meaning is "map simultaneously along the leading dimension of `y` and `z`": at each step `vmap` takes one slice from every mapped argument, so all mapped axes must have the same length. The error you're seeing is telling you that the leading dimensions of `y` and `z` have different sizes (5 and 3), and so they are not compatible for batching.
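A minimal sketch of that size rule, using a hypothetical function `g` with the same body as the question's `f2` but without `out_axes` (so the single output is stacked along axis 0). Mapping `y` and `z` together works when their leading sizes agree and fails when they don't:

```python
from functools import partial

import jax.numpy as jnp
from jax import vmap


# Hypothetical helper: same body as the question's f2, but out_axes is left
# at its default of 0 so the one output array is stacked along axis 0.
@partial(vmap, in_axes=(None, 0, 0))
def g(x, y, z):
    return x * z + y


x = jnp.arange(7)

# y and z both have leading size 5, so they are mapped in lockstep:
# row i of the result is x * z[i] + y[i], giving shape (5, 7).
ok = g(x, jnp.arange(5), jnp.arange(5))
print(ok.shape)  # (5, 7)

# Leading sizes 5 and 3 disagree, so vmap refuses to batch them together.
try:
    g(x, jnp.arange(5), jnp.arange(3))
except ValueError as err:
    print("ValueError:", str(err).splitlines()[0])
```

Note that the mapped pair produces only one new axis (of length 5), not a 5-by-3 grid: mapping several arguments with the same `vmap` pairs their slices up rather than forming every combination.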
Your function `f1` essentially uses broadcasting to encode three independent batching operations (one each over `x`, `y` and `z`), so to replicate that logic with `vmap` you'll need three nested applications of `vmap`. You can express that as follows:
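```python
from functools import partial
from jax import vmap

# Outermost vmap maps x (output axis 0), the next maps y (axis 1),
# and the innermost maps z (axis 2), giving an output of shape (7, 5, 3).
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    f = x*z + y
    return f
```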
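To check that the nested version really reproduces the broadcasting in `f1`, here is a sketch that compares the two on the question's inputs. It also shows the same composition written with explicit `jax.vmap` calls instead of decorators; the name `f3` is hypothetical and only used for this comparison:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import vmap


def f1(x, y, z):
    # Broadcasting version from the question: f[i, j, k] = x[i] * z[k] + y[j]
    return x[:, None, None] * z[None, None, :] + y[None, :, None]


# Decorators apply bottom-up: the innermost vmap maps z, the next maps y,
# and the outermost maps x, so the output axes come out ordered (x, y, z).
@partial(vmap, in_axes=(0, None, None))
@partial(vmap, in_axes=(None, 0, None))
@partial(vmap, in_axes=(None, None, 0))
def f2(x, y, z):
    # Inside all three maps, x, y and z are scalars.
    return x * z + y


# Hypothetical f3: the same three-level composition without decorators.
f3 = jax.vmap(
    jax.vmap(
        jax.vmap(lambda x, y, z: x * z + y, in_axes=(None, None, 0)),
        in_axes=(None, 0, None),
    ),
    in_axes=(0, None, None),
)

x, y, z = jnp.arange(7), jnp.arange(5), jnp.arange(3)
print(f2(x, y, z).shape)  # (7, 5, 3)
print(jnp.array_equal(f1(x, y, z), f2(x, y, z)))  # True
print(jnp.array_equal(f1(x, y, z), f3(x, y, z)))  # True
```

Whether you stack decorators or nest calls is a matter of taste; in both forms the order of the `vmap` layers determines the order of the output axes.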