python, numpy, jax

Vectorise nested vmap


Here's some data I have:

import jax.numpy as jnp
import numpyro.distributions as dist
import jax

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-3, 3, 5)

I'd like to run the function

def func(x, y):
    return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))

over each pair of values from xaxis and yaxis.

Here's a "slow" way to do it:

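import numpy as np

# A NumPy array is used here because JAX arrays do not support in-place item assignment
results = np.zeros((len(xaxis), len(yaxis)))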

for i in range(len(xaxis)):
    for j in range(len(yaxis)):
        results[i, j] = func(xaxis[i], yaxis[j])

Works, but it's slow.

So here's a vectorised way of doing it:

jax.vmap(lambda axis: jax.vmap(func, (None, 0))(axis, yaxis))(xaxis)

Much faster, but it's hard to read.

Is there a clean way of writing the vectorised version? Can I do it with a single vmap, rather than having to nest one within another one?

EDIT

Another way would be

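# xmesh and ymesh are assumed to come from meshgrid with its default "xy" indexing,
# which is what the trailing .T suggests
xmesh, ymesh = jnp.meshgrid(xaxis, yaxis)
jax.vmap(func)(xmesh.flatten(), ymesh.flatten()).reshape(len(xaxis), len(yaxis)).T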

but it's still messy.
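
A minimal sketch of the same meshgrid idea without the transpose, assuming `indexing="ij"` is acceptable. To keep it dependency-light, `log_density` is a stand-in for `func` built on `jax.scipy.stats.multivariate_normal.logpdf` rather than numpyro, and the axes have different lengths on purpose so the output shape is unambiguous:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[0.5, 0.2], [0.2, 0.1]])

def log_density(x, y):
    # Stand-in for func: log-density of a 2-D normal at the point (x, y)
    return multivariate_normal.logpdf(jnp.asarray([x, y]), MEAN, COV)

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-2, 2, 4)  # different length on purpose

# With indexing="ij": xmesh[i, j] == xaxis[i] and ymesh[i, j] == yaxis[j]
xmesh, ymesh = jnp.meshgrid(xaxis, yaxis, indexing="ij")

# A single vmap over the flattened pairs, then restore the grid shape
flat = jax.vmap(log_density)(xmesh.ravel(), ymesh.ravel())
grid = flat.reshape(xmesh.shape)  # shape (len(xaxis), len(yaxis)), no .T needed
```

This is still a single `vmap`, but it materialises both meshes first, so it trades memory for the flatter call.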


Solution

  • I believe "Vectorization guidelines for jax" is quite similar to your question: replicating the logic of nested for-loops with vmap requires nested vmaps.

    The cleanest approach using jax.vmap is probably something like this:

    from functools import partial
    
    @partial(jax.vmap, in_axes=(0, None))
    @partial(jax.vmap, in_axes=(None, 0))
    def func(x, y):
        return dist.MultivariateNormal(jnp.zeros(2), jnp.array([[.5, .2], [.2, .1]])).log_prob(jnp.asarray([x, y]))
    
    func(xaxis, yaxis)
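
If you would rather keep the scalar function usable on its own, the same two vmaps can be wrapped in a small helper instead of decorators. This is only a sketch: `outer_vmap` is a hypothetical name, and `log_density` is a stand-in for `func` that uses `jax.scipy.stats.multivariate_normal.logpdf` instead of numpyro:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[0.5, 0.2], [0.2, 0.1]])

def log_density(x, y):
    # Stand-in for func: log-density of a 2-D normal at the point (x, y)
    return multivariate_normal.logpdf(jnp.asarray([x, y]), MEAN, COV)

def outer_vmap(f):
    # Hypothetical helper. The inner vmap maps over y with x held fixed,
    # the outer vmap maps over x with y passed through whole.
    inner = jax.vmap(f, in_axes=(None, 0))
    return jax.vmap(inner, in_axes=(0, None))

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-2, 2, 4)  # different length on purpose

# grid[i, j] == log_density(xaxis[i], yaxis[j])
grid = outer_vmap(log_density)(xaxis, yaxis)  # shape (5, 4)
```

The order matters: the outermost vmap determines the leading output axis, which is why `in_axes=(0, None)` sits on the outside here, matching the top decorator in the stacked version.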
    

    Another option here is to use the jnp.vectorize API (which is implemented via multiple vmaps) on the original, undecorated func, in which case you can do something like this:

    print(jnp.vectorize(func)(xaxis[:, None], yaxis))
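
To make the broadcasting in that call explicit, here is a small sketch. It assumes a stand-in `log_density` built on `jax.scipy.stats.multivariate_normal.logpdf` instead of the numpyro-based `func`, and it spells out the scalar signature, which is optional:

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

MEAN = jnp.zeros(2)
COV = jnp.array([[0.5, 0.2], [0.2, 0.1]])

def log_density(x, y):
    # Stand-in for func: log-density of a 2-D normal at the point (x, y)
    return multivariate_normal.logpdf(jnp.asarray([x, y]), MEAN, COV)

xaxis = jnp.linspace(-3, 3, 5)
yaxis = jnp.linspace(-2, 2, 4)  # different length on purpose

# "(),()->()" declares two scalar inputs and one scalar output;
# every other axis of the arguments is broadcast NumPy-style.
vectorized = jnp.vectorize(log_density, signature="(),()->()")

# (5, 1) broadcast against (1, 4) gives a (5, 4) grid
grid = vectorized(xaxis[:, None], yaxis[None, :])
```

Because the grid shape comes from ordinary broadcasting, adding a third axis is a matter of reshaping the inputs, with no extra vmap to write.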