Consider the following toy example:
import numpy as np
import jax

x = np.arange(3)
# for a 1d x: sum of sin over all pairwise differences, one value per element
cfun = lambda x: np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
cfuns = jax.vmap(cfun)
# for a 2d x:
x = np.arange(6).reshape(3,2)
cfuns(x)
where x - x[:, None] is the broadcasting part and gives a 3x3 array for the length-3 x.
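To make the broadcasting step concrete, here is a small sketch for the 1-D case. It uses jax.numpy, but plain NumPy broadcasts the same way; the name d is just for illustration.

```python
import jax.numpy as jnp

x = jnp.arange(3)  # [0, 1, 2]
# x[:, jnp.newaxis] has shape (3, 1) and x has shape (3,), so the
# subtraction broadcasts both to (3, 3): d[i, j] = x[j] - x[i]
d = x - x[:, jnp.newaxis]
print(d)
# [[ 0  1  2]
#  [-1  0  1]
#  [-2 -1  0]]

# Summing sin over axis=1 leaves one value per i, so cfun maps
# a length-n vector to a length-n vector
print(jnp.sum(jnp.sin(d), axis=1).shape)  # (3,)
```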
I want cfuns to be vectorized over each row of x, but calling it raises:
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int64[2,2])>with<BatchTrace(level=1/0)> with
  val = Array([[[ 0, 1],
                [-1, 0]],
               [[ 0, 1],
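A minimal sketch that reproduces and catches this failure, assuming a reasonably recent JAX release in which this error is exposed as jax.errors.TracerArrayConversionError (older versions may raise a differently named TypeError). The names cfun_np and failed are illustrative only.

```python
import numpy as np
import jax

# The question's function: plain NumPy calls inside a JAX transformation
cfun_np = lambda x: np.sum(np.sin(x - x[:, np.newaxis]), axis=1)

x = np.arange(6).reshape(3, 2)

try:
    jax.vmap(cfun_np)(x)
    failed = False
except jax.errors.TracerArrayConversionError as err:
    # Inside vmap each row is an abstract tracer. The subtraction works
    # (tracers overload operators), but np.sin tries to convert its input
    # to a concrete numpy.ndarray via __array__(), which JAX forbids.
    failed = True
    print(type(err).__name__)

print(failed)
```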
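JAX transformations like vmap, jit, grad, etc. are not compatible with standard numpy operations: inside a transformation your function receives abstract tracer objects rather than concrete arrays, and numpy functions try to convert their inputs into concrete numpy.ndarray values. Instead, you should use jax.numpy, which provides a similar API built on JAX-compatible operations: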
import jax
import jax.numpy as jnp
x = jnp.arange(3)
cfun = lambda x: jnp.sum(jnp.sin(x - x[:, jnp.newaxis]), axis=1)
cfuns = jax.vmap(cfun)
# for a 2d x:
x = jnp.arange(6).reshape(3,2)
print(cfuns(x))
# [[ 0.84147096 -0.84147096]
#  [ 0.84147096 -0.84147096]
#  [ 0.84147096 -0.84147096]]
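Once the function uses jax.numpy, it also composes with other transformations such as jit. The sketch below (the names batched, diffs3 and manual are illustrative, not from any API) compiles the vmapped function and checks it against the same computation written with explicit 3-D broadcasting, which is essentially what vmap does for you by adding a leading batch axis.

```python
import jax
import jax.numpy as jnp

def cfun(row):
    # Pairwise differences for one row: diffs[i, j] = row[j] - row[i]
    diffs = row - row[:, jnp.newaxis]
    return jnp.sum(jnp.sin(diffs), axis=1)

x = jnp.arange(6).reshape(3, 2)

# vmap maps cfun over axis 0 (the rows); jit compiles the batched function
batched = jax.jit(jax.vmap(cfun))
out = batched(x)

# Same computation without vmap: diffs3[b, i, j] = x[b, j] - x[b, i]
diffs3 = x[:, jnp.newaxis, :] - x[:, :, jnp.newaxis]
manual = jnp.sum(jnp.sin(diffs3), axis=2)

print(out.shape)                  # (3, 2)
print(jnp.allclose(out, manual))  # True
```

Writing the broadcast by hand is fine for a function this small, but vmap keeps the per-row code readable and avoids tracking the extra batch axis manually.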