I am working through a great set of JAX notebooks with exercises on Eric Ma's GitHub.
The task I am interested in is replicating the following function using multiple `vmap` applications:
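```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(1234)
data = random.normal(key, shape=(11, 31, 7))

def ex2_numpy_equivalent(data):
    result = []
    for d in data:                      # d has shape (31, 7)
        cp = jnp.cumprod(d, axis=-1)    # cumulative product along each row
        s = jnp.sum(cp, axis=1)         # sum of each row -> shape (31,)
        result.append(s)
    return jnp.stack(result)            # shape (11, 31)
```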
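For reference, the loop applies `jnp.cumprod` along the last axis of each (31, 7) slice and then sums each row, so the whole computation should collapse to a single vectorized expression over the full array. A minimal sketch for comparison; the name `ex2_vectorized` and the key value are illustrative choices, not part of the exercise:

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(1234)  # illustrative key
data = random.normal(key, shape=(11, 31, 7))

def ex2_numpy_equivalent(data):
    result = []
    for d in data:
        cp = jnp.cumprod(d, axis=-1)
        s = jnp.sum(cp, axis=1)
        result.append(s)
    return jnp.stack(result)

def ex2_vectorized(data):
    # Hypothetical one-call equivalent: cumprod, then sum, both along the last axis.
    return jnp.cumprod(data, axis=-1).sum(axis=-1)

assert ex2_vectorized(data).shape == (11, 31)
assert jnp.allclose(ex2_vectorized(data), ex2_numpy_equivalent(data), atol=1e-5)
```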
The solution provided by Eric was:
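```python
import jax.numpy as jnp
from jax import vmap

def loopless_loop_ex2(data):
    """Data is three-dimensional of shape (n_datasets, n_rows, n_columns)"""
    def inner(dataset):
        """dataset is two-dimensional of shape (n_rows, n_columns)"""
        cp = vmap(jnp.cumprod)(dataset)  # cumprod of each row
        s = vmap(jnp.sum)(cp)            # sum of each row -> (n_rows,)
        return s
    return vmap(inner)(data)             # apply inner to each dataset
```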
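Since `inner` is itself built from `vmap` calls, Eric's solution is already a `vmap` of vmapped functions. The sketch below is one equivalent formulation, not the notebook's: write the per-row computation once, then wrap it in one `vmap` per leading axis. The name `row_fn` is illustrative:

```python
import jax.numpy as jnp
from jax import random, vmap

def row_fn(row):
    # row is 1-D with shape (n_columns,): cumulative product, then its total.
    return jnp.sum(jnp.cumprod(row))

# Inner vmap maps over rows, outer vmap maps over datasets.
nested = vmap(vmap(row_fn))

key = random.PRNGKey(1234)
data = random.normal(key, shape=(11, 31, 7))
out = nested(data)

assert out.shape == (11, 31)
assert jnp.allclose(out, jnp.cumprod(data, axis=-1).sum(axis=-1), atol=1e-5)
```

Each `vmap` peels off exactly one leading axis, so `row_fn` only ever sees a single length-7 row.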
My initial "guess solution" was:

```python
func1 = vmap(jnp.cumprod)
func2 = vmap(jnp.sum)
func3 = vmap(func2)
func3(data)
```
Alarm bells were ringing because it makes no sense for `func3` to be a vmap'ped version of `func2`, considering that `func2` is already a vmap'ped function. Although I got an incorrect answer (value-wise), the dimensions of the result matched the expected shape. What is going on here?
After looking at Eric's solution, it seems apparent that you can (or should) only chain `vmap` functions after you call the nested vmapped function on its input. But when you don't do this, as in my incorrect first attempt, what is actually happening?
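The reason your result differs is that you never compute the cumprod of the input. You create the function `func1 = vmap(jnp.cumprod)` but never call it. Instead, you are effectively computing `vmap(vmap(jnp.sum))(data)`. Nesting `vmap` like this is perfectly valid: each `vmap` maps over one more leading axis, so `jnp.sum` is applied to each length-7 vector `data[i, j, :]`. That is why the (11, 31) shape matched your expectation, and why, for 3-dimensional input, the result is equivalent to a sum along the last axis: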
```python
import jax.numpy as jnp
from jax import random, vmap

key = random.PRNGKey(1234)
data = random.normal(key, shape=(11, 31, 7))

# func1 = vmap(jnp.cumprod)  # unused
func2 = vmap(jnp.sum)
func3 = vmap(func2)
assert jnp.allclose(func3(data), data.sum(-1))
```
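To make the attempt compute the intended result, the cumulative product has to be applied before the sums. One caveat: `func1` on its own maps over only the leading axis, so on 3-D input each `jnp.cumprod` call receives a whole (31, 7) slice and, with no `axis` argument, flattens it. A sketch reusing the names from the question, with the cumprod vmapped a second time:

```python
import jax.numpy as jnp
from jax import random, vmap

key = random.PRNGKey(1234)
data = random.normal(key, shape=(11, 31, 7))

func1 = vmap(jnp.cumprod)   # maps over axis 0 only
func2 = vmap(jnp.sum)
func3 = vmap(func2)

# Pitfall: each (31, 7) slice is flattened by jnp.cumprod (axis=None),
# giving 31 * 7 = 217 values per dataset.
assert func1(data).shape == (11, 217)

# Fix: vmap the cumprod twice so it runs on each length-7 row, then sum each row.
cp = vmap(func1)(data)      # shape (11, 31, 7)
result = func3(cp)          # shape (11, 31)

assert result.shape == (11, 31)
assert jnp.allclose(result, jnp.cumprod(data, axis=-1).sum(axis=-1), atol=1e-5)
```

This mirrors Eric's structure: every per-row operation is wrapped in as many `vmap`s as there are leading axes to map over.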