How to use jit with control flow and arrays without slowing down


import jax
from jax import jit
import jax.numpy as jnp
import numpy as np


array1 = np.random.normal(size=(1000,1000))
def first():
  for i in range(1000):
    for j in range(1000):
      if array1[i,j] >= 0:
        array1[i,j] = 1
      else:
        array1[i,j] = -1
# %timeit first()

from functools import partial
key = jax.random.PRNGKey(seed=0)
array2 = jax.random.normal(key, shape=(1000,1000))

@partial(jit, static_argnums=(0,1,2))
def f( i,j):
  r = jax.lax.cond(array2[i,j] >= 0, lambda x: 1, lambda x: -1, None)
  # if array2[i,j] >= 0:
  # # if i == j:
  #   array2.at[i,j].set(1)
  # else: array2.at[i,j].set(-1)
  array2.at[i,j].set(r)

# f_jit = jit(f, static_argnums=(0,1))
def second():
  for i in range(1000):
    for j in range(1000):
      # jax.lax.cond(array2[i,j]>=0, lambda x: True, lambda x: False, None)
      f(i,j)
%timeit second()

I have two functions, first and second. I want second to run as fast as (or faster than) first. The first function uses NumPy; the second uses JAX. What is the best way to implement first using JAX in this case? I think jax.lax.cond is what slows the process down significantly.

I left the comments on purpose to show what I've tried.


Solution

  • The function first runs relatively quickly because it performs 1,000,000 NumPy array operations, and NumPy is optimized for fast per-operation dispatch.

    The function second runs relatively slowly because it performs 1,000,000 JAX array operations, and JAX is not optimized for fast per-operation dispatch.

    For general background on this, see the JAX FAQ entry "Is JAX faster than NumPy?"


    But if you're asking about the fastest way to accomplish this, the answer in both NumPy and JAX is to avoid writing loops. Here are the equivalents, written as pure functions rather than in-place updates for ease of comparison (your original second function actually does nothing, because array.at[i, j].set() returns a new array instead of modifying the original in place, and that new array is discarded):

    def first_fast(array):
      # 1 where array >= 0, -1 elsewhere, matching first
      return np.where(array >= 0, 1, -1)

    def second_fast(array):
      # the same computation with jax.numpy
      return jnp.where(array >= 0, 1, -1)

    In general, if you find yourself writing loops over array values in NumPy or in JAX, expect the resulting code to be slow. In both libraries there is almost always a faster way to compute the result using built-in vectorized operations.

    If you're interested in further benchmarks between JAX and NumPy, be sure to read the JAX FAQ entry "Benchmarking JAX code" so that you're comparing the right things.
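A minimal sketch of two points from the answer, assuming a recent JAX release with default (32-bit) settings: .at[...].set() returns a new array rather than modifying the original, and the vectorized jnp.where rule matches what an element-by-element loop computes. The loop version, second_loop, is a hypothetical helper added only for comparison. It uses jax.lax.fori_loop under jax.jit, a technique the answer itself does not discuss, so that the whole loop compiles into one computation and the updated array is carried from one iteration to the next instead of being thrown away.

```python
import jax
import jax.numpy as jnp

# .at[...].set(...) never mutates its input: it returns an updated copy.
x = jnp.zeros(3)
y = x.at[0].set(5.0)  # x is still [0., 0., 0.]; y is [5., 0., 0.]


def second_fast(array):
    # Vectorized rule from the answer: 1 where array >= 0, -1 elsewhere.
    return jnp.where(array >= 0, 1, -1)


@jax.jit
def second_loop(array):
    # Hypothetical loop version, for comparison only. The whole loop is
    # compiled once, and the array is threaded through fori_loop as the
    # loop carry so that every update is kept.
    n_rows, n_cols = array.shape  # shapes are static under jit

    def body(k, arr):
        i, j = k // n_cols, k % n_cols  # flat index -> (row, column)
        new_val = jnp.where(arr[i, j] >= 0, 1.0, -1.0)
        return arr.at[i, j].set(new_val)  # keep the returned copy

    return jax.lax.fori_loop(0, n_rows * n_cols, body, array)


key = jax.random.PRNGKey(0)
a = jax.random.normal(key, shape=(100, 100))
loop_result = second_loop(a)  # float32 array of 1.0 / -1.0
vec_result = second_fast(a)  # integer array of 1 / -1
print(jnp.array_equal(loop_result, vec_result))
```

Even compiled this way, the loop is likely to be slower than the single jnp.where call, which is why the answer recommends the vectorized form.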
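If you want to time the vectorized versions yourself, here is a rough sketch of the kind of measurement the benchmarking FAQ recommends: one warm-up call so that compilation is not counted, and block_until_ready() on the JAX result because JAX dispatches work asynchronously. The array size, the use of timeit, the repetition count, and wrapping second_fast in jax.jit are all arbitrary choices for this sketch. The input is converted to a JAX array before timing so that host-to-device transfer is excluded; under JAX's default configuration that conversion produces float32 rather than float64.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np


def first_fast(array):
    return np.where(array >= 0, 1, -1)


@jax.jit  # wrapped in jit for this sketch
def second_fast(array):
    return jnp.where(array >= 0, 1, -1)


array_np = np.random.normal(size=(1000, 1000))
array_jax = jnp.asarray(array_np)  # converted up front: transfer is not timed

# Warm-up call: triggers compilation so it is not counted in the timing.
second_fast(array_jax).block_until_ready()

# block_until_ready() makes the timer wait for the computation to finish,
# instead of measuring only the asynchronous dispatch.
n = 20
t_np = timeit.timeit(lambda: first_fast(array_np), number=n) / n
t_jax = timeit.timeit(lambda: second_fast(array_jax).block_until_ready(), number=n) / n
print(f"NumPy: {t_np * 1e3:.3f} ms per call, JAX: {t_jax * 1e3:.3f} ms per call")
```

The numbers depend heavily on hardware and backend, so treat any single run as indicative only.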