Using vmap in jax results in a pytree related error


I am trying to use vmap for batching with the following code:


import jax
import jax.numpy as jnp
from jax import random
import numpy as np
from jax import grad, hessian, jacobian, jacfwd
from jax import vmap
from jax.lax import scan

def predict(params, input):
    y = jnp.sin(params*input)
    return y

def loss(params, batch):
    inputs, targets = batch
    predictions = predict(params, inputs)
    return jnp.sum((predictions - targets)**2)

gt_inputs = np.random.random((1,20))
gt_targets = jnp.sin(2.3 * gt_inputs)
batch = np.transpose([gt_inputs, gt_targets]).squeeze()
param_init = 0.2
grads = vmap(grad(loss), (None, 0), 0)(batch)

I am getting the following error from vmap, and I don't understand what I am doing wrong:

File ~/anaconda3/envs/jaxenv/lib/python3.10/site-packages/jax/_src/api_util.py:418, in flatten_axes(name, treedef, axis_tree, kws, tupled_args)
    414       else:
    415         hint += (f" In particular, you're passing in a single argument which "
    416                  f"means that {name} might need to be wrapped in "
    417                  f"a singleton tuple.")
--> 418   raise ValueError(f"{name} specification must be a tree prefix of the "
    419                    f"corresponding value, got specification {axis_tree} "
    420                    f"for value tree {treedef}.{hint}") from None
    421 axes = [None if a is proxy else a for a in axes]
    422 assert len(axes) == treedef.num_leaves

ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (None, 0) for value tree PyTreeDef((*,)).
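The rule behind this message can be reproduced with a toy function. The sketch below uses a hypothetical two-argument f (not from the question) as a stand-in for loss. It assumes that vmap raises ValueError when the in_axes tuple does not match the positional arguments; the exact wording of the message differs between JAX versions.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-argument function standing in for loss(params, batch).
def f(a, b):
    return a * b

# One in_axes entry per positional argument: leave `a` unmapped, map `b` over axis 0.
vf = jax.vmap(f, in_axes=(None, 0))

xs = jnp.arange(3.0)

# Both arguments passed: the args tuple (a, b) matches the two-entry spec.
mapped = vf(2.0, xs)
print(mapped)  # [0. 2. 4.]

# Only one argument passed: the args tuple is (xs,), a single-leaf tree,
# and (None, 0) cannot be a prefix of it, so vmap raises ValueError.
raised = False
try:
    vf(xs)
except ValueError as e:
    raised = True
    print("ValueError:", e)
```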

Solution

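  • The error message is misleading, but the cause is simple: loss takes two arguments (params and batch), and in_axes=(None, 0) gives one entry for each of them, yet the vmapped function is called with batch alone. vmap therefore sees an argument tree with a single leaf, PyTreeDef((*,)), and a two-entry specification cannot be a prefix of it. Pass the parameters as well: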

    grads = vmap(grad(loss), (None, 0), 0)(param_init, batch)
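For completeness, here is a sketch of the whole corrected script, with a sanity check added. It keeps the question's data setup, renames the predict argument from input to inputs (to avoid shadowing the Python builtin), and drops the unused imports. The check relies on the loss being a plain sum over examples, so the per-example gradients should add up to the gradient of the loss on the full batch; the variable names per_example_grads and full_grad are illustrative, not from the original.

```python
import jax.numpy as jnp
import numpy as np
from jax import grad, vmap

def predict(params, inputs):
    return jnp.sin(params * inputs)

def loss(params, batch):
    inputs, targets = batch
    predictions = predict(params, inputs)
    return jnp.sum((predictions - targets) ** 2)

gt_inputs = np.random.random((1, 20))
gt_targets = jnp.sin(2.3 * gt_inputs)
# Shape (20, 2): one (input, target) row per example.
batch = np.transpose([gt_inputs, gt_targets]).squeeze()
param_init = 0.2

# params is not mapped (None); batch is mapped over its leading axis (0).
per_example_grads = vmap(grad(loss), in_axes=(None, 0), out_axes=0)(param_init, batch)
print(per_example_grads.shape)  # (20,)

# Sanity check: summing per-example gradients should match the gradient
# of the summed loss computed on the whole batch at once.
full_grad = grad(loss)(param_init, (batch[:, 0], batch[:, 1]))
print(jnp.allclose(per_example_grads.sum(), full_grad, rtol=1e-4))  # True
```

An equivalent option is to bind the parameter first, e.g. vmap(functools.partial(grad(loss), param_init), in_axes=0)(batch), so that the mapped function takes a single argument and in_axes needs only one entry.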