python, jax

How to use Jax vmap over zipped arguments?


I have the following example code, which works with Python's built-in map:

import jax
import jax.numpy as jnp

def f(x_y):
    x, y = x_y
    return x.sum() + y.sum()

xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]

list(map(f, zip(xs, ys)))

# returns:
[DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32),
 DeviceArray(0., dtype=float32)]

How can I use jax.vmap instead? The naive thing is:

jax.vmap(f)(zip(xs, ys))

but this gives:

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
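A quick check illustrates the likely cause. This is a sketch that assumes a recent JAX and NumPy. JAX does not know how to look inside a zip object, so vmap sees the whole iterator as a single leaf. When NumPy is asked for that leaf's shape, it reports `()`, which leaves no axis 0 to map over.

```python
import jax
import jax.numpy as jnp
import numpy as np

xs = [jnp.zeros(3) for i in range(4)]
ys = [jnp.zeros(2) for i in range(4)]

pairs = zip(xs, ys)
# zip is not a registered pytree node, so JAX treats the whole iterator
# as one opaque leaf instead of looking inside it.
leaves = jax.tree_util.tree_leaves(pairs)

# NumPy wraps the zip object in a 0-d object array, so its shape is (),
# i.e. rank 0, with no leading axis for vmap to map over.
zip_shape = np.shape(zip(xs, ys))
```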

Solution

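  • To use jax.vmap, you do not need to zip your variables. Instead, store each sequence as a single array whose leading axis is the batch axis, then pass the arrays together as a tuple. vmap maps over axis 0 of every array in the tuple, so row i of xs is paired with row i of ys: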

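    import jax.numpy as jnp
    from jax import vmap
    
    def f(x_y):
        x, y = x_y  # under vmap, x is one row of xs and y is one row of ys
        return x.sum() + y.sum()
    
    # Stacked inputs: the leading axis (size 4) is the batch axis.
    xs = jnp.zeros((4,3))
    ys = jnp.zeros((4,2))
    # Pass a single tuple; vmap maps over axis 0 of both xs and ys together.
    vmap(f)((xs, ys))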
    

    Output:

    DeviceArray([0., 0., 0., 0.], dtype=float32)
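If your data is already in Python lists, as with xs and ys in the question, you can stack each list with jnp.stack before calling vmap. This is a minimal sketch that assumes every array within a list has the same shape. Passing x and y as separate positional arguments also works, because by default vmap maps over axis 0 of every argument. Here `g` is a hypothetical two-argument variant of `f`, added for illustration. Non-zero data is used so you can check the batched results against the plain map.

```python
import jax
import jax.numpy as jnp

def f(x_y):
    # Unpack the (x, y) pair; under vmap each is a single row.
    x, y = x_y
    return x.sum() + y.sum()

def g(x, y):
    # Hypothetical two-argument variant of f, for illustration only.
    return x.sum() + y.sum()

# Per-example data kept as Python lists, as in the question.
xs = [jnp.arange(3.0) + i for i in range(4)]
ys = [jnp.ones(2) * i for i in range(4)]

# Stack each list into one array with a leading batch axis of size 4.
xs_batched = jnp.stack(xs)  # shape (4, 3)
ys_batched = jnp.stack(ys)  # shape (4, 2)

# Tuple argument: vmap maps over axis 0 of every leaf in the tuple.
out_tuple = jax.vmap(f)((xs_batched, ys_batched))

# Separate positional arguments: each one is mapped over axis 0.
out_args = jax.vmap(g)(xs_batched, ys_batched)

# Reference result from the original map-over-zip approach.
expected = jnp.array(list(map(f, zip(xs, ys))))
```

Both forms give one scalar per (x, y) pair. The tuple form lets you keep `f` unchanged. The positional form is often easier to read, and it lets you set `in_axes` per argument if you need to.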