
using jax.vmap to vectorize along with broadcasting


Consider the following toy example:

    import jax
    import numpy as np

    x = np.arange(3)
    # np.sum(np.sin(x - x[:, np.newaxis]), axis=1)

    cfun = lambda x: np.sum(np.sin(x - x[:, np.newaxis]), axis=1)
    cfuns = jax.vmap(cfun)

    # for a 2d x:
    x = np.arange(6).reshape(3, 2)
    cfuns(x)

where x - x[:, None] is the broadcasting part; for the 1-D x it gives a 3x3 array. I want cfuns to be vectorized over each row of the 2-D x.
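To make the shapes concrete, here is a small sketch of the broadcasting step on its own for the 1-D x, using jax.numpy (outside any transformation it behaves like NumPy here). The names diffs and out are just illustrative. Each row of the 3x3 result holds the differences relative to one element, and summing over axis=1 brings it back to length 3:

```python
import jax.numpy as jnp

x = jnp.arange(3)  # [0, 1, 2]

# x[:, None] has shape (3, 1) and x has shape (3,), so the subtraction
# broadcasts to shape (3, 3) with diffs[i, j] == x[j] - x[i]
diffs = x - x[:, None]
print(diffs.shape)  # (3, 3)
print(diffs)
# [[ 0  1  2]
#  [-1  0  1]
#  [-2 -1  0]]

# Summing sin over axis=1 reduces each row to one value,
# so this computation maps a length-n vector to a length-n vector
out = jnp.sum(jnp.sin(diffs), axis=1)
print(out.shape)  # (3,)
```

Because the per-row computation maps shape (n,) to (n,), mapping it over the 3 rows of a (3, 2) array should give a (3, 2) result.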

Calling cfuns(x) fails with the following error (truncated):

    The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int64[2,2])>with<BatchTrace(level=1/0)> with
      val = Array([[[ 0,  1],
            [-1,  0]],

           [[ 0,  1],
    ...

Solution

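  • JAX transformations such as vmap, jit, and grad are not compatible with standard NumPy functions. Inside a transformed function the inputs are JAX tracers rather than concrete arrays, and a NumPy function such as np.sin tries to convert its input to a numpy.ndarray via __array__(), which raises the error above. Instead, use jax.numpy, which provides a similar API built on JAX-compatible operations: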

    import jax
    import jax.numpy as jnp
    
    x = jnp.arange(3)
    
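    # cfun is written for a single 1-D row; the broadcast gives an (n, n) array
    cfun = lambda x: jnp.sum(jnp.sin(x - x[:, jnp.newaxis]), axis=1)
    # vmap maps cfun over the leading axis (the rows) of its input
    cfuns = jax.vmap(cfun)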
    
    # for a 2d x:
    x = jnp.arange(6).reshape(3,2)
    
    print(cfuns(x))
    # [[ 0.84147096 -0.84147096]
    #  [ 0.84147096 -0.84147096]
    #  [ 0.84147096 -0.84147096]]
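For comparison, the same computation can be batched by hand by giving the broadcast an explicit leading batch axis. The sketch below does that with a hypothetical helper, cfun_batched (not part of JAX), and checks it against the vmapped version. It also wraps the vmapped function in jax.jit, since vmap composes with the other transformations:

```python
import jax
import jax.numpy as jnp

# Per-row function, same as cfun above: maps shape (n,) -> (n,)
def cfun(row):
    return jnp.sum(jnp.sin(row - row[:, jnp.newaxis]), axis=1)

# Hypothetical hand-batched equivalent: x[:, None, :] has shape (b, 1, n) and
# x[:, :, None] has shape (b, n, 1), so diffs[k, i, j] == x[k, j] - x[k, i];
# the per-row sum over j now happens on axis=2
def cfun_batched(x):
    diffs = x[:, jnp.newaxis, :] - x[:, :, jnp.newaxis]
    return jnp.sum(jnp.sin(diffs), axis=2)

x = jnp.arange(6).reshape(3, 2)

via_vmap = jax.vmap(cfun)(x)           # in_axes=0 by default: map over rows
via_broadcast = cfun_batched(x)        # explicit batch axis, no vmap
via_jit = jax.jit(jax.vmap(cfun))(x)   # vmap composed with jit

print(jnp.allclose(via_vmap, via_broadcast))  # True
print(jnp.allclose(via_jit, via_vmap))        # True
```

The hand-batched version needs its indexing and reduction axis adjusted for the extra dimension; with vmap you can skip that bookkeeping and write cfun for a single row.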