Difference between vmapping a vmapped function versus the output of a vmapped function


I am working through a great set of JAX exercise notebooks on Eric Ma's GitHub.

The task that I am interested in is to replicate the following function via multiple vmap function applications:

import jax.numpy as jnp
from jax import random, vmap

key = random.PRNGKey(1234)
data = random.normal(key, shape=(11, 31, 7))

def ex2_numpy_equivalent(data):
    result = []
    for d in data:  # each d has shape (31, 7)
        cp = jnp.cumprod(d, axis=-1)  # cumulative product along each row
        s = jnp.sum(cp, axis=1)  # sum each row, giving shape (31,)
        result.append(s)
    return jnp.stack(result)

The solution provided by Eric was:

def loopless_loop_ex2(data):
    """Data is three-dimensional of shape (n_datasets, n_rows, n_columns)"""

    def inner(dataset):
        """dataset is two-dimensional of shape (n_rows, n_columns)"""
        cp = vmap(jnp.cumprod)(dataset)
        s = vmap(jnp.sum)(cp)
        return s

    return vmap(inner)(data)
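
As a sanity check, the loop version and the vmap version can be compared directly. This is a minimal sketch: the names loop_version and vmap_version are made up for illustration, the seed is arbitrary, and the tolerances are a guess at what float32 round-off needs.

```python
import jax.numpy as jnp
from jax import random, vmap

key = random.PRNGKey(0)  # arbitrary seed, chosen only for illustration
data = random.normal(key, shape=(11, 31, 7))

def loop_version(data):
    # Reference: explicit Python loop over the leading (dataset) axis.
    return jnp.stack([jnp.sum(jnp.cumprod(d, axis=-1), axis=1) for d in data])

def vmap_version(data):
    def inner(dataset):
        # Each vmap maps over rows, so cumprod and sum see one 1-D row at a time.
        cp = vmap(jnp.cumprod)(dataset)
        return vmap(jnp.sum)(cp)
    # The outer vmap maps inner over the datasets.
    return vmap(inner)(data)

expected = loop_version(data)
actual = vmap_version(data)
print(actual.shape)
print(bool(jnp.allclose(expected, actual, rtol=1e-4, atol=1e-5)))
```

Both results should have one entry per (dataset, row) pair, since cumprod keeps the column axis and sum then removes it.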

My initial "guess solution" was to try:

func1 = vmap(jnp.cumprod)
func2 = vmap(jnp.sum)
func3 = vmap(func2)

func3(data)

Alarm bells were ringing because it makes no sense for func3 to be a vmap'ped version of func2, considering that func2 is already a vmap'ped function. While I did get an incorrect answer (value-wise), the dimensions of the answer matched the expected sizes. What is this code actually computing?

After looking at Eric's solution, it seems pretty apparent that you can/should only chain vmap'ped functions by applying each one to the output of the previous one. But when you don't do this, and instead vmap the vmap'ped function itself as in my incorrect first attempt, what is going on?


Solution

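  • The reason your computation is different is that you never compute the cumprod of the input. You create a function func1 = vmap(jnp.cumprod), but never call it. Instead, you are effectively computing vmap(vmap(jnp.sum))(data). Each vmap maps over one leading axis, so jnp.sum only ever sees the last axis; that is why the output has the expected shape (11, 31) even though the values are wrong. For 3-dimensional input this is equivalent to the sum along the last axis: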

    import jax.numpy as jnp
    from jax import random, vmap
    
    key = random.PRNGKey(1234)
    data = random.normal(key, shape=(11, 31, 7))
    
    # func1 = vmap(jnp.cumprod) # unused
    func2 = vmap(jnp.sum)
    func3 = vmap(func2)
    
    assert jnp.allclose(func3(data), data.sum(-1))
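
One way to repair the original guess, sketched here under the assumption that the goal is the per-row sum of cumulative products, is to compose cumprod and sum for a single row first and only then lift that function with vmap. The names per_row, wrong, fixed and reference are hypothetical and not from the notebook.

```python
import jax.numpy as jnp
from jax import random, vmap

key = random.PRNGKey(1234)
data = random.normal(key, shape=(11, 31, 7))

# Each vmap strips one leading axis before calling the wrapped function, so in
# vmap(vmap(jnp.sum)) on a (11, 31, 7) input, jnp.sum sees arrays of shape (7,).
wrong = vmap(vmap(jnp.sum))(data)  # plain sum over the last axis, no cumprod

# Hypothetical repair: define the computation for one row of shape (7,) ...
def per_row(row):
    return jnp.sum(jnp.cumprod(row))

# ... then lift it over rows (inner vmap) and datasets (outer vmap).
fixed = vmap(vmap(per_row))(data)

# Reference computed without vmap, using explicit axis arguments.
reference = jnp.sum(jnp.cumprod(data, axis=-1), axis=-1)

print(wrong.shape, fixed.shape)
print(bool(jnp.allclose(fixed, reference, rtol=1e-4, atol=1e-5)))
```

Both wrong and fixed come out with the same shape because each one reduces the last axis to a scalar per row; only fixed applies the cumulative product before summing.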