Tags: numpy, multidimensional-array, indexing, array-broadcasting, jax

Defining the correct vectorization axes for JAX vmap with arrays of different shapes and sizes


Following the answer to this post, I defined the function f_switch below, which dynamically switches between multiple functions based on an index array (using jax.lax.switch):

import jax
import jax.numpy as jnp
import jax.random as random

def g_0(x, y, z, u): return x + y + z + u
def g_1(x, y, z, u): return x * y * z * u
def g_2(x, y, z, u): return x - y + z - u
def g_3(x, y, z, u): return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]


@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i: jax.lax.switch(i, g_i, x, y, z, u)
  return jax.vmap(g)(i)
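To make the switching behavior concrete, here is a small sketch (the input values and the name reference are illustrative choices, not from the post) that compares f_switch with a plain Python loop picking g_i[i[n]] for each index. With a batched index under vmap, JAX evaluates every branch and selects the matching result per element (the same behavior documented for lax.cond), so the two should agree.

```python
import jax
import jax.numpy as jnp
import numpy as np

def g_0(x, y, z, u): return x + y + z + u
def g_1(x, y, z, u): return x * y * z * u
def g_2(x, y, z, u): return x - y + z - u
def g_3(x, y, z, u): return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]

@jax.jit
def f_switch(i, x, y, z, u):
  # Only the index array is mapped; x, y, z, u are closed over unchanged
  g = lambda i: jax.lax.switch(i, g_i, x, y, z, u)
  return jax.vmap(g)(i)

# Small hand-picked inputs (hypothetical values, for illustration only)
i = jnp.array([0, 1, 2, 3, 1])
x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
z = jnp.array([5.0, 6.0])
u = jnp.array([7.0, 8.0])

out = f_switch(i, x, y, z, u)

# Reference: select the branch for each index in a Python loop, then stack
reference = jnp.stack([g_i[int(k)](x, y, z, u) for k in i])

print(out.shape)  # (5, 2)
np.testing.assert_allclose(out, reference, rtol=1e-5)
```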

With input arrays i_ar of shape (len_i,), x_ar, y_ar and z_ar of shape (len_xyz,), and u_ar of shape (len_u, len_xyz), out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar) yields out of shape (len_i, len_u, len_xyz), since each branch broadcasts the (len_xyz,) inputs against the (len_u, len_xyz) array u_ar:

len_i = 50
i_ar = random.randint(random.PRNGKey(5), shape=(len_i,), minval=0, maxval=len(g_i))  # related to

len_xyz = 3000
x_ar = random.uniform(random.PRNGKey(0), shape=(len_xyz,))
y_ar = random.uniform(random.PRNGKey(1), shape=(len_xyz,))
z_ar = random.uniform(random.PRNGKey(2), shape=(len_xyz,))

len_u = 1000
u_0 = random.uniform(random.PRNGKey(3), shape=(len_u,))
u_1 = jnp.repeat(u_0, len_xyz)
u_ar = u_1.reshape(len_u, len_xyz)

out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
print('The shape of out is', out.shape)
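A scaled-down version of this setup (the smaller sizes are my own choice, only to keep it fast) shows where each axis of out comes from: the vmap over i contributes the leading axis, and broadcasting x_ar, y_ar, z_ar of shape (len_xyz,) against u_ar of shape (len_u, len_xyz) inside each branch contributes the other two.

```python
import jax
import jax.numpy as jnp
import jax.random as random

def g_0(x, y, z, u): return x + y + z + u
def g_1(x, y, z, u): return x * y * z * u
def g_2(x, y, z, u): return x - y + z - u
def g_3(x, y, z, u): return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]

@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i: jax.lax.switch(i, g_i, x, y, z, u)
  return jax.vmap(g)(i)

# Reduced sizes (the question uses 50, 3000 and 1000)
len_i, len_xyz, len_u = 5, 3, 4
i_ar = random.randint(random.PRNGKey(5), shape=(len_i,), minval=0, maxval=len(g_i))
x_ar = random.uniform(random.PRNGKey(0), shape=(len_xyz,))
y_ar = random.uniform(random.PRNGKey(1), shape=(len_xyz,))
z_ar = random.uniform(random.PRNGKey(2), shape=(len_xyz,))
u_0 = random.uniform(random.PRNGKey(3), shape=(len_u,))
u_ar = jnp.repeat(u_0, len_xyz).reshape(len_u, len_xyz)

# Inside each branch: (len_xyz,) broadcasts with (len_u, len_xyz) -> (len_u, len_xyz)
print(jnp.broadcast_shapes(x_ar.shape, u_ar.shape))  # (4, 3)

out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
print(out.shape)  # (5, 4, 3), i.e. (len_i, len_u, len_xyz)
```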

This worked. **But how can f_switch be defined so that out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar) has shape (j_len, k_len, l_len) when the function is applied along the following axes: i_ar[j], x_ar[j], y_ar[j, k], z_ar[j, k], u_ar[l]?** I am not sure how to set up the mapping. Example input arrays:

j_len = 82
k_len = 20
l_len = 100
i_ar = random.randint(random.PRNGKey(0), shape=(j_len,), minval=0, maxval=len(g_i))
x_ar = random.uniform(random.PRNGKey(1), shape=(j_len,))
y_ar = random.uniform(random.PRNGKey(2), shape=(j_len, k_len))
z_ar = random.uniform(random.PRNGKey(3), shape=(j_len, k_len))
u_ar = random.uniform(random.PRNGKey(4), shape=(l_len,))
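Stated element by element, the requested output is out[j, k, l] = g_{i[j]}(x_ar[j], y_ar[j, k], z_ar[j, k], u_ar[l]). A deliberately slow plain-NumPy reference pins down that specification and can be used to check any vectorized version. The helper name reference_loop and the reduced sizes are hypothetical choices for this sketch.

```python
import numpy as np

def g_0(x, y, z, u): return x + y + z + u
def g_1(x, y, z, u): return x * y * z * u
def g_2(x, y, z, u): return x - y + z - u
def g_3(x, y, z, u): return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]

def reference_loop(i, x, y, z, u):
  """out[j, k, l] = g_i[i[j]](x[j], y[j, k], z[j, k], u[l])"""
  j_len, k_len = y.shape
  l_len = u.shape[0]
  out = np.empty((j_len, k_len, l_len))
  for j in range(j_len):
    g = g_i[int(i[j])]          # branch chosen once per j
    for k in range(k_len):
      for l in range(l_len):
        out[j, k, l] = g(x[j], y[j, k], z[j, k], u[l])
  return out

rng = np.random.default_rng(0)
j_len, k_len, l_len = 6, 4, 5   # reduced sizes so the loop stays fast
i_ar = rng.integers(0, len(g_i), size=j_len)
x_ar = rng.uniform(size=j_len)
y_ar = rng.uniform(size=(j_len, k_len))
z_ar = rng.uniform(size=(j_len, k_len))
u_ar = rng.uniform(size=l_len)

ref = reference_loop(i_ar, x_ar, y_ar, z_ar, u_ar)
print(ref.shape)  # (6, 4, 5)
```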

I tried to solve this (i.e., to get an output of shape (j_len, k_len, l_len) from the given input arrays) with a nested vmap:

@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i, x, y, z, u: jax.lax.switch(i, g_i, x, y, z, u)
  g_map = jax.vmap(g, in_axes=(None, 0, 0, 0, 0))
  wrapper = lambda x, y, z, u: g_map(i, x, y, z, u)
  return jax.vmap(wrapper, in_axes=(0, None, None, None, 0))(x, y, z, u)

I also tried broadcasting u_ar with u_ar_broadcast = jnp.broadcast_to(u_ar, (j_len, k_len, l_len)) and passing the result to the original f_switch. Both attempts failed.
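The broadcasting attempt most likely fails for a reason unrelated to u_ar: the original f_switch maps only over i, so every branch receives the full x_ar of shape (j_len,) together with y_ar of shape (j_len, k_len). NumPy-style broadcasting aligns trailing axes, so j_len (82) is compared with k_len (20) and the shapes are incompatible. The sketch below checks this shape arithmetic with jnp.broadcast_shapes; the helper try_broadcast is hypothetical and only turns the error into None.

```python
import jax.numpy as jnp

j_len, k_len, l_len = 82, 20, 100

def try_broadcast(*shapes):
  """Return the broadcast shape as a tuple, or None if the shapes are incompatible."""
  try:
    return tuple(jnp.broadcast_shapes(*shapes))
  except (ValueError, TypeError):
    return None

# x_ar (j,) vs y_ar (j, k): trailing axes 82 and 20 do not match
print(try_broadcast((j_len,), (j_len, k_len)))  # None

# x_ar (j,) vs u_ar_broadcast (j, k, l): trailing axes 82 and 100 do not match either
print(try_broadcast((j_len,), (j_len, k_len, l_len)))  # None

# Explicit unit axes line j up with j and leave room for k and l
print(try_broadcast((j_len, 1, 1), (j_len, k_len, 1), (l_len,)))  # (82, 20, 100)
```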


Solution

  • It looks like you may want something like this:

    @jax.jit
    def f_switch(i, x, y, z, u):
      g = lambda i, x, y, z, u: jax.lax.switch(i, g_i, x, y, z, u)
      g = jax.vmap(g, (None, None, None, None, 0))
      g = jax.vmap(g, (None, None, 0, 0, None))
      g = jax.vmap(g, (0, 0, 0, 0, None))
      return g(i, x, y, z, u)
    
    out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
    print(out.shape)
    # (82, 20, 100)
    

    You should read the in_axes from bottom to top (because the bottom vmap is the outer one, and is therefore applied to the inputs first). Schematically, you can think of the effect of the maps on the shapes as something like this:

                                   (i[82], x[82], y[82,20], z[82,20], u[100])
    (0, 0, 0, 0, None)          -> (i,     x,     y[20],    z[20],    u[100])
    (None, None, 0, 0, None)    -> (i,     x,     y,        z,        u[100])
    (None, None, None, None, 0) -> (i,     x,     y,        z,        u)
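To check that the nested maps compute the intended element-wise function, and not just the right shape, one can compare them against an explicit loop on reduced sizes. This is a sketch; the reduced sizes and the name reference are my own choices.

```python
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np

def g_0(x, y, z, u): return x + y + z + u
def g_1(x, y, z, u): return x * y * z * u
def g_2(x, y, z, u): return x - y + z - u
def g_3(x, y, z, u): return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]

@jax.jit
def f_switch(i, x, y, z, u):
  g = lambda i, x, y, z, u: jax.lax.switch(i, g_i, x, y, z, u)
  g = jax.vmap(g, (None, None, None, None, 0))  # innermost: map over u -> l axis
  g = jax.vmap(g, (None, None, 0, 0, None))     # middle: map over y, z -> k axis
  g = jax.vmap(g, (0, 0, 0, 0, None))           # outermost: map over i, x, y, z -> j axis
  return g(i, x, y, z, u)

j_len, k_len, l_len = 6, 4, 5  # reduced sizes
i_ar = random.randint(random.PRNGKey(0), shape=(j_len,), minval=0, maxval=len(g_i))
x_ar = random.uniform(random.PRNGKey(1), shape=(j_len,))
y_ar = random.uniform(random.PRNGKey(2), shape=(j_len, k_len))
z_ar = random.uniform(random.PRNGKey(3), shape=(j_len, k_len))
u_ar = random.uniform(random.PRNGKey(4), shape=(l_len,))

out = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)

# Explicit loop: out[j, k, l] = g_i[i[j]](x[j], y[j, k], z[j, k], u[l])
reference = np.array([[[float(g_i[int(i_ar[j])](x_ar[j], y_ar[j, k], z_ar[j, k], u_ar[l]))
                        for l in range(l_len)]
                       for k in range(k_len)]
                      for j in range(j_len)])

print(out.shape)  # (6, 4, 5)
np.testing.assert_allclose(out, reference, rtol=1e-5)
```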
    

    That said, often it is easier to rely on numpy-style broadcasting rather than on multiple nested vmaps. For example, you could also do something like this:

    @jax.jit
    def f_switch(i, x, y, z, u):
      g = lambda i, x, y, z, u: jax.lax.switch(i, g_i, x, y, z, u)
      return jax.vmap(g, in_axes=(0, 0, 0, 0, None))(i, x, y, z, u)
    
    out = f_switch(i_ar, x_ar[:, None, None], y_ar[:, :, None], z_ar[:, :, None], u_ar)
    print(out.shape)
    # (82, 20, 100)
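A related option, not part of the answer above, is jnp.vectorize: with its default signature it broadcasts all inputs NumPy-style and maps a scalar function over the result using vmap, so the same unit axes give the (82, 20, 100) output without writing any in_axes. A sketch under that assumption; the names g_scalar and f_vec are hypothetical. Since i is broadcast to the full output shape, the batched switch runs over every element.

```python
import jax
import jax.numpy as jnp
import jax.random as random

def g_0(x, y, z, u): return x + y + z + u
def g_1(x, y, z, u): return x * y * z * u
def g_2(x, y, z, u): return x - y + z - u
def g_3(x, y, z, u): return x / y / z / u
g_i = [g_0, g_1, g_2, g_3]

def g_scalar(i, x, y, z, u):
  # Scalar version: choose branch i and apply it to scalar inputs
  return jax.lax.switch(i, g_i, x, y, z, u)

# jnp.vectorize broadcasts the inputs and maps g_scalar over the broadcast shape
f_vec = jax.jit(jnp.vectorize(g_scalar))

j_len, k_len, l_len = 82, 20, 100
i_ar = random.randint(random.PRNGKey(0), shape=(j_len,), minval=0, maxval=len(g_i))
x_ar = random.uniform(random.PRNGKey(1), shape=(j_len,))
y_ar = random.uniform(random.PRNGKey(2), shape=(j_len, k_len))
z_ar = random.uniform(random.PRNGKey(3), shape=(j_len, k_len))
u_ar = random.uniform(random.PRNGKey(4), shape=(l_len,))

out = f_vec(i_ar[:, None, None], x_ar[:, None, None],
            y_ar[:, :, None], z_ar[:, :, None], u_ar)
print(out.shape)  # (82, 20, 100)
```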