I am trying to use vmap for batching with the following code:
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from jax import grad, hessian, jacobian, jacfwd
from jax import vmap
from jax.lax import scan
def predict(params, input):
    y = jnp.sin(params*input)
    return y
def loss(params, batch):
    inputs, targets = batch
    predictions = predict(params, inputs)
    return jnp.sum((predictions - targets)**2)
gt_inputs = np.random.random((1,20))
gt_targets = jnp.sin(2.3 * gt_inputs)
batch = np.transpose([gt_inputs, gt_targets]).squeeze()
param_init = 0.2
grads = vmap(grad(loss), (None, 0), 0)(batch)
I am getting the following error from vmap, and I don't understand what I am doing wrong:
File ~/anaconda3/envs/jaxenv/lib/python3.10/site-packages/jax/_src/api_util.py:418, in flatten_axes(name, treedef, axis_tree, kws, tupled_args)
414 else:
415 hint += (f" In particular, you're passing in a single argument which "
416 f"means that {name} might need to be wrapped in "
417 f"a singleton tuple.")
--> 418 raise ValueError(f"{name} specification must be a tree prefix of the "
419 f"corresponding value, got specification {axis_tree} "
420 f"for value tree {treedef}.{hint}") from None
421 axes = [None if a is proxy else a for a in axes]
422 assert len(axes) == treedef.num_leaves
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (None, 0) for value tree PyTreeDef((*,)).
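The "value tree" in this message is the tuple of positional arguments passed to the vmapped function, and in_axes must match it as a tree prefix. The short sketch below uses jax.tree_util.tree_structure to print both argument structures; the random (20, 2) array is an assumed stand-in for the question's batch.

```python
import numpy as np
import jax

# Stand-in for the question's (20, 2) batch array (random data, an assumption).
batch = np.random.random((20, 2))

# Called as f(batch), the positional arguments form a 1-tuple with one array leaf,
# which cannot be matched against the 2-element spec (None, 0).
args_one = jax.tree_util.tree_structure((batch,))
print(args_one)  # PyTreeDef((*,))

# Called as f(param_init, batch), there are two leaves, one per entry of (None, 0).
args_two = jax.tree_util.tree_structure((0.2, batch))
print(args_two)  # PyTreeDef((*, *))
```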
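The error message is misleading, but the problem is that your loss function takes two arguments, params and batch, while you call the vmapped function with only one. The spec in_axes=(None, 0) gives one entry per positional argument, but the positional arguments JAX receives form the 1-tuple (batch,), which is why the value tree is reported as PyTreeDef((*,)). Pass param_init as the first argument: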
grads = vmap(grad(loss), (None, 0), 0)(param_init, batch)
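For reference, here is a self-contained sketch of the corrected script. The sanity check at the end is an addition: because loss is a sum of squared errors, calling it on batch.T (inputs and targets as two length-20 vectors) gives the full-batch loss, so the per-example gradients should sum to its gradient. The parameter was renamed from input to inputs only to avoid shadowing Python's built-in.

```python
import numpy as np
import jax.numpy as jnp
from jax import grad, vmap

def predict(params, inputs):
    return jnp.sin(params * inputs)

def loss(params, batch):
    inputs, targets = batch
    predictions = predict(params, inputs)
    return jnp.sum((predictions - targets) ** 2)

gt_inputs = np.random.random((1, 20))
gt_targets = jnp.sin(2.3 * gt_inputs)
# Shape (20, 2): each row is one (input, target) pair.
batch = np.transpose([gt_inputs, gt_targets]).squeeze()
param_init = 0.2

# in_axes=(None, 0): reuse params for every example, map over the rows of batch.
per_example_grads = vmap(grad(loss), in_axes=(None, 0), out_axes=0)(param_init, batch)
print(per_example_grads.shape)  # (20,)

# Sanity check: batch.T has shape (2, 20), so loss sees all 20 examples at once
# and returns the summed loss; its gradient equals the sum of per-example gradients.
full_grad = grad(loss)(param_init, batch.T)
print(jnp.allclose(per_example_grads.sum(), full_grad, atol=1e-5))
```

vmap(grad(loss), ...) returns one gradient per example; if you only need the gradient of the total loss, grad(loss)(param_init, batch.T) gives it directly without vmap.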