Suppose I have a function (for simplicity, the covariance between two series, though the question is more general):
def cov(x, y):
    return jnp.dot((x - jnp.mean(x)), (y - jnp.mean(y)))
Now I have a "dataframe" D (a 2-dimensional array whose columns are my series), and I want to vectorize cov in such a way that applying the vectorized function produces the covariance matrix. There is an obvious way of doing it:
cov1 = jax.vmap(cov, in_axes=(None, 1))
cov2 = jax.vmap(cov1, in_axes=(1, None))
but that seems a little clunky. Is there a "canonical" way of doing this?
If you want to express logic equivalent to nested for loops with vmap, then yes, it requires nested vmaps. I think what you've written is probably as canonical as you can get for an operation like this, although it might be slightly clearer if written using decorators:
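from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.vmap, in_axes=(1, None))  # outer map: over the columns of x
@partial(jax.vmap, in_axes=(None, 1))  # inner map: over the columns of y
def cov(x, y):
    return jnp.dot((x - jnp.mean(x)), (y - jnp.mean(y)))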
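To make the effect of the two in_axes settings concrete, here is a minimal sketch that checks the decorated function against explicit nested Python loops. The array D, its shape, and the helper name cov_single are assumptions made up for illustration; they are not from the question.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.vmap, in_axes=(1, None))  # outer map: columns of x become rows of the result
@partial(jax.vmap, in_axes=(None, 1))  # inner map: columns of y become columns of the result
def cov(x, y):
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))


def cov_single(x, y):
    # Hypothetical helper: the same body without vmap, used only as a loop reference.
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))


# Assumed example data: 50 observations of 4 series, one series per column.
D = jax.random.normal(jax.random.PRNGKey(0), (50, 4))

# One entry per pair of columns.
C = cov(D, D)

# The same matrix built with explicit nested loops over column pairs.
C_loops = jnp.array(
    [[cov_single(D[:, i], D[:, j]) for j in range(D.shape[1])]
     for i in range(D.shape[1])]
)
```

Here C[i, j] corresponds to the pair (D[:, i], D[:, j]), which is exactly what the nested loops compute; the vmapped version just avoids the Python-level iteration.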
For this particular function, though, note that you can express the same thing using a single dot product if you wish. Here x and y are the full 2D arrays (for example, both equal to D), not individual columns:
result = jnp.dot((x - x.mean(0)).T, (y - y.mean(0)))
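As a quick sanity check, the sketch below compares the nested-vmap result with the single dot product on some made-up data (the array D and its shape are assumptions for illustration). Note that cov as written sums the products of deviations without dividing by n - 1, so to compare against jnp.cov the sketch rescales the latter by n - 1.

```python
import jax
import jax.numpy as jnp


def cov(x, y):
    # Unnormalized: sum of products of deviations from the mean.
    return jnp.dot(x - jnp.mean(x), y - jnp.mean(y))


# Assumed example data: 50 observations of 4 series, one series per column.
D = jax.random.normal(jax.random.PRNGKey(0), (50, 4))

# Nested vmap over column pairs, as in the question.
nested = jax.vmap(jax.vmap(cov, in_axes=(None, 1)), in_axes=(1, None))(D, D)

# Single dot product on the column-centered array.
centered = D - D.mean(0)
single = jnp.dot(centered.T, centered)

# jnp.cov normalizes by n - 1 by default, so undo that for the comparison.
n = D.shape[0]
reference = jnp.cov(D, rowvar=False) * (n - 1)

print(jnp.allclose(nested, single, rtol=1e-4, atol=1e-4))
print(jnp.allclose(single, reference, rtol=1e-4, atol=1e-4))
```

The tolerances are loosened because JAX uses float32 by default and the two formulations accumulate rounding error in a different order.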