Vectorization guidelines for JAX


Suppose I have a function (for simplicity, the covariance between two series, though the question is more general):

import jax
import jax.numpy as jnp
def cov(x, y):
   return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))

Now I have a "dataframe" D (a 2-dimensional array whose columns are my series), and I want to vectorize cov so that applying the vectorized function to D produces the covariance matrix. There is an obvious way of doing it:

cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))

but that seems a little clunky. Is there a "canonical" way of doing this?


Solution

  • If you want to express logic equivalent to nested for loops with vmap, then yes, it requires nested vmaps. I think what you've written is probably as canonical as you can get for an operation like this, although it might be slightly clearer if written using decorators:

    from functools import partial
    
    @partial(jax.vmap, in_axes=(1, None))
    @partial(jax.vmap, in_axes=(None, 1))
    def cov(x, y):
       return jnp.dot((x-jnp.mean(x)), (y-jnp.mean(y)))
    

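    For this particular function, though, note that you can express the same thing using a single dot product if you wish. Here x and y are the full 2-dimensional arrays (both D in your case) rather than individual columns: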

    result = jnp.dot((x - x.mean(0)).T, (y - y.mean(0)))
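To check that the two approaches above agree, here is a minimal sketch that applies the nested vmap and the single dot product to the same array. The array D below is made-up illustration data (5 observations of 3 series), and the names cov_matrix, nested, centered, and single are placeholders of mine, not anything from the question.

```python
import jax
import jax.numpy as jnp

def cov(x, y):
    # Unnormalized covariance of two 1-D series, as defined in the question.
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))

# Inner vmap maps over the columns of y, outer vmap over the columns of x.
cov_matrix = jax.vmap(jax.vmap(cov, in_axes=(None, 1)), in_axes=(1, None))

# Hypothetical data: 5 rows (observations), 3 columns (series).
D = jnp.arange(15.0).reshape(5, 3) ** 2

# Entry [i, j] is cov(D[:, i], D[:, j]).
nested = cov_matrix(D, D)

# Same matrix from a single dot product on the column-centered array.
centered = D - D.mean(0)
single = jnp.dot(centered.T, centered)

print(nested.shape)
print(jnp.allclose(nested, single, rtol=1e-4))
```

Note that cov as written in the question omits the usual normalization, so both results are the scatter matrix; dividing by the number of rows minus one should give the sample covariance.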