Tags: python, time, google-colaboratory, jit, jax

I have two functions that do the same thing, one decorated with jax.jit and one without. When I call them in a loop (~100 iterations), the un-jitted function takes less time than the jitted one.


import jax
import numpy as np
import jax.numpy as jnp

# Build 10000 small random integer arrays, as NumPy arrays and as JAX arrays.
a = []
a_jax = []
for i in range(10000):
    a.append(np.random.randint(1, 5, (5,)))
    a_jax.append(jnp.array(a[i]))

# a_jax = jnp.array(a_jax)

@jax.jit
def calc_add_with_jit(a, b):
    return a + b

def calc_add_without_jit(a, b):
    return a + b

def main_function_with_jit():
    for i in range(99):
        calc_add_with_jit(a_jax[i], a_jax[i + 1])

def main_function_without_jit():
    for i in range(99):
        calc_add_without_jit(a[i], a[i + 1])

%time calc_add_with_jit(a_jax[1], a_jax[2])  # first call includes compilation
%time main_function_with_jit()
%time main_function_without_jit()

The first %time reports a wall time of 3.33 ms, the second 5.58 ms, and the third 156 µs.

Can anyone explain why this is happening? Why is the JAX-jitted code slower than the regular NumPy code? I am asking about the second and third %time results.


Solution

  • This question is pretty well answered in the JAX documentation; see FAQ: Is JAX Faster Than NumPy? In particular, quoting from the summary:

    if you’re doing microbenchmarks of individual array operations on CPU, you can generally expect NumPy to outperform JAX due to its lower per-operation dispatch overhead. If you’re running your code on GPU or TPU, or are benchmarking more complicated JIT-compiled sequences of operations on CPU, you can generally expect JAX to outperform NumPy.

    You are benchmarking sequences of individually-dispatched single operations on CPU, which is precisely the regime that NumPy is designed and optimized for, and so you can expect that NumPy will be faster.
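    As a hedged sketch (the stacking approach and the a_np / a_stacked / add_all_pairs names below are my additions, not part of the question or the FAQ), the same work lands in the regime where JIT helps if the 10000 small arrays are stacked into one array and the whole sequence of pairwise additions is compiled as a single vectorized operation. block_until_ready() is called so that JAX's asynchronous dispatch does not make the timing look artificially fast:

    import numpy as np
    import jax
    import jax.numpy as jnp

    # Stack everything into one (10000, 5) array instead of 10000 tiny arrays.
    a_np = np.random.randint(1, 5, (10000, 5))
    a_stacked = jnp.array(a_np)

    @jax.jit
    def add_all_pairs(x):
        # Adds every row to its successor in one compiled, vectorized operation,
        # rather than dispatching many separate tiny additions from Python.
        return x[:-1] + x[1:]

    # First call triggers compilation; exclude it from the timing.
    add_all_pairs(a_stacked).block_until_ready()

    %timeit add_all_pairs(a_stacked).block_until_ready()
    %timeit a_np[:-1] + a_np[1:]

    Measured this way, the comparison is roughly between one compiled XLA kernel and one NumPy vectorized call over the same data, rather than between 99 Python-level dispatches on each side, which is the situation the FAQ describes.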