
Efficient use of JAX for conditional function evaluation based on an array of integers


I want to efficiently evaluate functions selected by an array of integers, where other arrays of real numbers serve as the inputs to those functions. I hope to find a JAX-based solution that is significantly faster than the for-loop approach described below:

import jax
from jax import vmap
import jax.numpy as jnp
import jax.random as random

def g_0(x, y, z, u):
    return x + y + z + u

def g_1(x, y, z, u):
    return x * y * z * u

def g_2(x, y, z, u):
    return x - y + z - u

def g_3(x, y, z, u):
    return x / y / z / u

g_i = [g_0, g_1, g_2, g_3]
g_i_jit = [jax.jit(func) for func in g_i]

def g_git(i, x, y, z, u):
    return g_i_jit[i](x=x, y=y, z=z, u=u)

def g(i, x, y, z, u):
    return g_i[i](x=x, y=y, z=z, u=u)


len_xyz = 3000
x_ar = random.uniform(random.PRNGKey(0), shape=(len_xyz,))
y_ar = random.uniform(random.PRNGKey(1), shape=(len_xyz,))
z_ar = random.uniform(random.PRNGKey(2), shape=(len_xyz,))

len_u = 1000
u_0 = random.uniform(random.PRNGKey(3), shape=(len_u,))
u_1 = jnp.repeat(u_0, len_xyz)
u_ar = u_1.reshape(len_u, len_xyz)
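The following sketch reuses the question's seeds to check the shapes involved. The repeat-and-reshape construction makes row k of u_ar equal to the constant u_0[k]. That is the same array jnp.broadcast_to produces from the column vector u_0[:, None]. Because x_ar has shape (len_xyz,), broadcasting it against a (len_u, len_xyz) or (len_u, 1) array gives a (len_u, len_xyz) result, so each g_k returns an array of that shape. Passing u_0[:, None] directly is an alternative that avoids building u_ar up front. Whether that fits the real use case is an assumption.

```python
import jax.numpy as jnp
import jax.random as random

len_xyz, len_u = 3000, 1000
x_ar = random.uniform(random.PRNGKey(0), shape=(len_xyz,))
u_0 = random.uniform(random.PRNGKey(3), shape=(len_u,))

# Construction from the question: repeat each u_0[k] len_xyz times, then reshape,
# so row k of u_ar is the constant u_0[k].
u_ar = jnp.repeat(u_0, len_xyz).reshape(len_u, len_xyz)

# Same array via broadcasting a (len_u, 1) column across len_xyz columns.
u_bc = jnp.broadcast_to(u_0[:, None], (len_u, len_xyz))
rows_match = bool(jnp.array_equal(u_ar, u_bc))

# x_ar (len_xyz,) broadcasts along the last axis of u_ar (len_u, len_xyz).
full_shape = (x_ar + u_ar).shape

# The (len_u, 1) column yields the same (len_u, len_xyz) result without building u_ar.
lazy = x_ar + u_0[:, None]
lazy_matches = bool(jnp.allclose(lazy, x_ar + u_ar))
```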


len_i = 50
i_ar = random.randint(random.PRNGKey(5), shape=(len_i,), minval=0, maxval=len(g_i))  # maxval is exclusive: indices lie in [0, len(g_i) - 1]


total = jnp.zeros((len_u, len_xyz))

for i in range(len_i):
    total = total + g_git(i_ar[i], x_ar, y_ar, z_ar, u_ar)
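None of x_ar, y_ar, z_ar or u_ar depends on the loop counter, so the loop evaluates the same four functions repeatedly. Only the number of times each index occurs in i_ar matters: total = sum over j of g_{i_ar[j]}(x, y, z, u) = sum over k of n_k * g_k(x, y, z, u), where n_k counts occurrences of k. The sketch below uses this idea, which is a different technique from the one in the answer. It assumes the inputs really are the same for every iteration and would break otherwise. It counts indices with jnp.bincount and evaluates each g_k once. The helper name total_by_counts and the small demo sizes are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as random

def g_0(x, y, z, u):
    return x + y + z + u

def g_1(x, y, z, u):
    return x * y * z * u

def g_2(x, y, z, u):
    return x - y + z - u

def g_3(x, y, z, u):
    return x / y / z / u

g_i = [g_0, g_1, g_2, g_3]

@jax.jit
def total_by_counts(i, x, y, z, u):
    # n[k] = number of times index k occurs in i (length fixed so it works under jit)
    n = jnp.bincount(i, length=len(g_i))
    total = jnp.zeros(jnp.broadcast_shapes(x.shape, y.shape, z.shape, u.shape), x.dtype)
    # Python loop over the 4 functions; it is unrolled once at trace time
    for k, g in enumerate(g_i):
        total = total + n[k] * g(x, y, z, u)
    return total

# Small illustrative inputs (hypothetical sizes, not the question's 3000 / 1000 / 50)
keys = random.split(random.PRNGKey(0), 5)
x, y, z = (random.uniform(k, (30,), minval=0.5, maxval=1.5) for k in keys[:3])
u = random.uniform(keys[3], (20, 1), minval=0.5, maxval=1.5)  # broadcasts to (20, 30)
i = random.randint(keys[4], (50,), 0, len(g_i))

out = total_by_counts(i, x, y, z, u)

# Reference: the question's sequential loop over the indices
ref = jnp.zeros((20, 30))
for ii in i.tolist():
    ref = ref + g_i[ii](x, y, z, u)
```

Only four function evaluations happen regardless of len_i, at the cost of relying on the inputs being loop-invariant.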

"i_ar" acts as an array of indices: each integer in it selects one of the four functions in the list g_i. By contrast, x_ar, y_ar, z_ar, and u_ar are arrays of real numbers that serve as inputs to the functions selected by i_ar.
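The following sketch makes the index-to-function mapping concrete. With a plain Python int, g_i[k] is ordinary list indexing. Inside jax.jit or jax.vmap, however, the index is a traced value and cannot index a Python list. jax.lax.switch(index, branches, *operands) is JAX's primitive for this kind of dispatch, and its documentation states that the index is clamped to [0, len(branches) - 1]. The wrapper name dispatch and the scalar inputs 1, 2, 3, 4 are illustrative only.

```python
import jax

def g_0(x, y, z, u):
    return x + y + z + u

def g_1(x, y, z, u):
    return x * y * z * u

def g_2(x, y, z, u):
    return x - y + z - u

def g_3(x, y, z, u):
    return x / y / z / u

g_i = [g_0, g_1, g_2, g_3]

@jax.jit
def dispatch(i, x, y, z, u):
    # i is traced under jit; lax.switch applies g_i[i] to the operands
    return jax.lax.switch(i, g_i, x, y, z, u)

x, y, z, u = 1.0, 2.0, 3.0, 4.0  # illustrative scalar inputs
results = [float(dispatch(k, x, y, z, u)) for k in range(len(g_i))]
# results ≈ [10.0, 24.0, -2.0, 1/24]

# Out-of-range indices are clamped instead of raising: 7 -> 3, -1 -> 0
high = float(dispatch(7, x, y, z, u))
low = float(dispatch(-1, x, y, z, u))
```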

I suspect that this difference in nature between i_ar and x_ar, y_ar, z_ar, and u_ar is what makes it difficult to find a more efficient JAX replacement for the for loop above. Any ideas how to use JAX (or something else) to replace the for loop and compute 'total' more efficiently?

I have tried naively using vmap:

g_git_vmap = jax.vmap(g_git)
total = jnp.zeros((len_u, len_xyz))
total = jnp.sum(g_git_vmap(i_ar, x_ar, y_ar, z_ar, u_ar), axis=0)

but this resulted in error messages and led nowhere.
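The naive vmap attempt runs into two separate problems, reproduced here on small stand-in arrays. The sizes are hypothetical, and g_list is a simplified, non-jitted version of g_git. First, with the default in_axes=0, vmap maps axis 0 of every argument. The arguments have different leading sizes (in the question 50, 3000, 3000, 3000 and 1000), so vmap raises a ValueError. Second, even when only i is mapped, the function receives a tracer for i, and a tracer cannot index a Python list, so JAX raises a TypeError subclass. Mapping only over i and selecting with lax.switch avoids both problems.

```python
import jax
import jax.numpy as jnp

def g_0(x, y, z, u):
    return x + y + z + u

def g_1(x, y, z, u):
    return x * y * z * u

def g_2(x, y, z, u):
    return x - y + z - u

def g_3(x, y, z, u):
    return x / y / z / u

g_i = [g_0, g_1, g_2, g_3]

def g_list(i, x, y, z, u):
    # Python list indexing needs a concrete int; under vmap, i is a tracer
    return g_i[i](x, y, z, u)

i_ar = jnp.array([0, 1, 3, 2, 0])               # 5 indices (stand-in for len_i = 50)
x_ar = y_ar = z_ar = jnp.linspace(1.0, 2.0, 4)  # shape (4,)
u_ar = jnp.ones((3, 4))                         # shape (3, 4)

# Problem 1: default in_axes=0 maps axis 0 of all arguments, whose sizes differ (5, 4, 4, 4, 3)
try:
    jax.vmap(g_list)(i_ar, x_ar, y_ar, z_ar, u_ar)
    size_error = None
except ValueError as e:
    size_error = e

# Problem 2: mapping only over i fixes the sizes, but g_i[i] still receives a tracer
try:
    jax.vmap(g_list, in_axes=(0, None, None, None, None))(i_ar, x_ar, y_ar, z_ar, u_ar)
    index_error = None
except TypeError as e:
    index_error = e

# Both fixes: map only over i, and select the function with lax.switch
def g_switch(i, x, y, z, u):
    return jax.lax.switch(i, g_i, x, y, z, u)

out = jax.vmap(g_switch, in_axes=(0, None, None, None, None))(i_ar, x_ar, y_ar, z_ar, u_ar)
# out has shape (5, 3, 4): one (3, 4) result per index
```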


Solution

  • Probably the best way to do this is with lax.switch, which dynamically selects one of several functions based on an integer index; wrapping it in vmap applies it across the whole index array.

    Here's a comparison of your original function with an approach based on lax.switch, with timings on a Colab GPU runtime:

    import numpy as np  # used below for np.testing
    
    def f_original(i, x, y, z, u):
      total = jnp.zeros((len(u), len(x)))
      for ii in range(len(i)):
        total += g_git(i[ii], x, y, z, u)
      return total
    
    @jax.jit
    def f_switch(i, x, y, z, u):
      g = lambda i: jax.lax.switch(i, g_i, x, y, z, u)
      return jax.vmap(g)(i).sum(0)
    
    out1 = f_original(i_ar, x_ar, y_ar, z_ar, u_ar)
    out2 = f_switch(i_ar, x_ar, y_ar, z_ar, u_ar)
    np.testing.assert_allclose(out1, out2, rtol=5E-3)
    
    %timeit f_original(i_ar, x_ar, y_ar, z_ar, u_ar).block_until_ready()
    # 71 ms ± 23.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
    
    %timeit f_switch(i_ar, x_ar, y_ar, z_ar, u_ar).block_until_ready()
    # 4.69 ms ± 37.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
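One trade-off of the vmap form is that when lax.switch receives a batched index, JAX lowers it to a select. That means all len(g_i) branches are computed for every index, and a (len_i, len_u, len_xyz) intermediate may be held, depending on how XLA fuses the final sum. The sketch below shows a sequential alternative, assuming memory rather than latency is the concern. It uses jax.lax.fori_loop with an unbatched switch inside, so each step runs only the selected branch and the loop keeps a single accumulator. Whether it is faster depends on the hardware; the snippet only checks that both versions agree. The name f_fori and the demo sizes are hypothetical.

```python
import jax
import jax.numpy as jnp
import jax.random as random

def g_0(x, y, z, u):
    return x + y + z + u

def g_1(x, y, z, u):
    return x * y * z * u

def g_2(x, y, z, u):
    return x - y + z - u

def g_3(x, y, z, u):
    return x / y / z / u

g_i = [g_0, g_1, g_2, g_3]

@jax.jit
def f_switch(i, x, y, z, u):
    # The answer's version: vmap over the index array, then sum over it
    g = lambda i: jax.lax.switch(i, g_i, x, y, z, u)
    return jax.vmap(g)(i).sum(0)

@jax.jit
def f_fori(i, x, y, z, u):
    # Accumulate into a single array of the broadcast output shape
    shape = jnp.broadcast_shapes(x.shape, y.shape, z.shape, u.shape)

    def body(k, total):
        # i[k] is a traced scalar (not batched), so switch runs only the selected branch
        return total + jax.lax.switch(i[k], g_i, x, y, z, u)

    return jax.lax.fori_loop(0, i.shape[0], body, jnp.zeros(shape, x.dtype))

# Small illustrative sizes (hypothetical; the question uses 50 indices and a 1000 x 3000 grid)
keys = random.split(random.PRNGKey(0), 5)
x, y, z = (random.uniform(k, (30,), minval=0.5, maxval=1.5) for k in keys[:3])
u = random.uniform(keys[3], (20, 30), minval=0.5, maxval=1.5)
i = random.randint(keys[4], (50,), 0, len(g_i))

out_switch = f_switch(i, x, y, z, u)
out_fori = f_fori(i, x, y, z, u)
```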