jax

Struggling to understand nested vmaps in JAX


I just about understand unnested vmaps, but try as I might, and I have tried my darnedest, nested vmaps continue to elude me. Take the snippet from this text, for example:

I don't understand what the axes are in this case. Is the nested vmap(kernel, (0, None)) some sort of partial function application? Why is the function mapped twice? In other words, can someone please explain what is going on behind the scenes? What does a nested vmap desugar to? All the answers I have found are variants of the same curt explanation, "mapping over both axes", which I am struggling with.


Solution

  • Each time vmap is applied, it maps over a single axis. For simplicity, say you have a function that takes two scalars and returns a scalar:

    import jax
    import jax.numpy as jnp
    
    def f(x, y):
      assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
      return x + y
    
    print(f(1, 2))
    # 3
    

    If you want to apply this function to an array of x values and a single y value, you can do this with vmap:

    f_mapped_over_x = jax.vmap(f, in_axes=(0, None))
    
    x = jnp.arange(5)
    print(f_mapped_over_x(x, 1))
    # [1 2 3 4 5]
    

    in_axes=(0, None) means the function is mapped along the leading axis of the first argument, x, while the second argument, y, is not mapped: the same y is passed to every call.

    Likewise, if you want to apply this function to a single x value and an array of y values, you can specify this via in_axes:

    f_mapped_over_y = jax.vmap(f, in_axes=(None, 0))
    
    y = jnp.arange(5, 10)
    print(f_mapped_over_y(1, y))
    # [ 6  7  8  9 10]
    

    If you wish to map the function over both arrays at once, you can do this by specifying in_axes=(0, 0), or equivalently in_axes=0:

    f_mapped_over_x_and_y = jax.vmap(f, in_axes=(0, 0))
    
    print(f_mapped_over_x_and_y(x, y))
    # [ 5  7  9 11 13]
    

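    But suppose you want to map over x and y separately, evaluating f on every (x, y) pair to get a sort of "outer-product" version of the function. You can do this with a nested vmap: the inner vmap maps over just y (with x held fixed), and the outer vmap maps that inner function over just x (with y held fixed):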

    f_mapped_over_x_then_y = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))
    
    print(f_mapped_over_x_then_y(x, y))
    # [[ 5  6  7  8  9]
    #  [ 6  7  8  9 10]
    #  [ 7  8  9 10 11]
    #  [ 8  9 10 11 12]
    #  [ 9 10 11 12 13]]
    

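    The nesting of vmaps is what lets you map over two axes separately: each vmap contributes one axis to the output, with the outer vmap (over x) producing the rows and the inner vmap (over y) producing the columns, so that result[i, j] equals f(x[i], y[j]).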
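On the question of what "the axis" is: the integer in in_axes selects which dimension of an argument vmap slices, and each slice is handed to one call of the function. Here is a small sketch with a 2-D array. jnp.sum is used only as a convenient function that accepts a 1-D slice; it is an illustrative choice, not part of the original snippet.

```python
import jax
import jax.numpy as jnp

m = jnp.arange(6).reshape(2, 3)
# m = [[0 1 2]
#      [3 4 5]]

# in_axes=0: slice along axis 0, so each call of jnp.sum sees one row, shape (3,).
row_sums = jax.vmap(jnp.sum, in_axes=0)(m)

# in_axes=1: slice along axis 1, so each call of jnp.sum sees one column, shape (2,).
col_sums = jax.vmap(jnp.sum, in_axes=1)(m)

print(row_sums)  # [ 3 12]
print(col_sums)  # [3 5 7]
```

The function body never sees the mapped axis. It only sees one slice at a time, which is why f in this answer can assert that its inputs are scalars.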
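Roughly speaking, a single vmap desugars to a loop over the mapped axis followed by a stack of the results. The sketch below compares the single-level vmap calls from this answer with explicit Python loops. The loops only describe *what* is computed, not *how*: vmap does not run a Python loop, it traces f once and rewrites its operations to act on the whole batch.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
    return x + y

x = jnp.arange(5)
y = jnp.arange(5, 10)

# in_axes=(0, None): loop over x, passing the same y (here 1) to every call.
over_x = jax.vmap(f, in_axes=(0, None))(x, 1)
over_x_loop = jnp.stack([f(xi, 1) for xi in x])

# in_axes=(0, 0), or just in_axes=0 (also vmap's default): walk x and y in
# lockstep, like zip. Both mapped axes must have the same length.
over_both = jax.vmap(f, in_axes=0)(x, y)
over_both_loop = jnp.stack([f(xi, yi) for xi, yi in zip(x, y)])

print(over_x, over_x_loop)        # [1 2 3 4 5] [1 2 3 4 5]
print(over_both, over_both_loop)  # [ 5  7  9 11 13] [ 5  7  9 11 13]
```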
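A nested vmap desugars the same way, with one loop per vmap: the outer vmap becomes the outer loop and the inner vmap becomes the inner loop. The sketch below checks this against the nested vmap from this answer. The broadcasting line is an extra comparison that works only because this particular f is plain elementwise addition; it is not a general replacement for nested vmap.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
    return x + y

x = jnp.arange(5)
y = jnp.arange(5, 10)

nested = jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))(x, y)

# Outer vmap -> outer loop over x (rows); inner vmap -> inner loop over y (columns).
nested_loop = jnp.stack([
    jnp.stack([f(xi, yj) for yj in y])  # one row: x fixed at xi, y varies
    for xi in x
])

# Only valid because f is elementwise addition: x runs down, y runs across.
broadcast = x[:, None] + y[None, :]

print(nested.shape)                         # (5, 5)
print(bool((nested == nested_loop).all()))  # True
print(bool((nested == broadcast).all()))    # True
```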
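On the partial-application question: the inner vmap does not fix any argument of f. It returns an ordinary new function with the same two parameters, which expects a scalar x and a 1-D y and returns a 1-D array; the outer vmap then maps that new function over x. Note also that jax.vmap's second positional parameter is in_axes, so a call written as vmap(kernel, (0, None)) is the same as vmap(kernel, in_axes=(0, None)). The sketch below uses the hypothetical name row_fn for the inner function, calls it on its own, and shows that swapping which level maps which argument transposes the result.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    assert jnp.ndim(x) == jnp.ndim(y) == 0  # x and y are scalars
    return x + y

x = jnp.arange(5)
y = jnp.arange(5, 10)

# Inner vmap: a new function taking a scalar x and a 1-D y, returning a 1-D array.
row_fn = jax.vmap(f, in_axes=(None, 0))
row0 = row_fn(x[0], y)
print(row0)  # [5 6 7 8 9], i.e. row 0 of the grid

# Passing in_axes positionally builds the same function.
row_fn_positional = jax.vmap(f, (None, 0))
print(bool((row_fn_positional(x[0], y) == row0).all()))  # True

# Outer vmap: map row_fn over the leading axis of x, holding y fixed.
grid = jax.vmap(row_fn, in_axes=(0, None))(x, y)

# Swapped nesting: the outer vmap maps over y and the inner over x, so the
# first output axis is indexed by y and the result is the transpose of grid.
grid_swapped = jax.vmap(jax.vmap(f, in_axes=(0, None)), in_axes=(None, 0))(x, y)
print(bool((grid_swapped == grid.T).all()))  # True
```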