```python
import jax
import numpy as np
import jax.numpy as jnp

# Build 10,000 small arrays of 5 random ints, once as NumPy arrays and once as JAX arrays.
a = []
a_jax = []
for i in range(10000):
    a.append(np.random.randint(1, 5, (5,)))
    a_jax.append(jnp.array(a[i]))
# a_jax = jnp.array(a_jax)

@jax.jit
def calc_add_with_jit(a, b):
    return a + b

def calc_add_without_jit(a, b):
    return a + b

def main_function_with_jit():
    # 99 separate calls to the jitted function on JAX arrays.
    for i in range(99):
        calc_add_with_jit(a_jax[i], a_jax[i+1])

def main_function_without_jit():
    # 99 plain additions on NumPy arrays.
    for i in range(99):
        calc_add_without_jit(a[i], a[i+1])

# %time is an IPython/Jupyter magic, so this runs in a notebook cell.
%time calc_add_with_jit(a_jax[1], a_jax[2])
%time main_function_with_jit()
%time main_function_without_jit()
```
The first %time reports 3.33 ms of wall time, the second reports 5.58 ms, and the third reports 156 microseconds.
Can anyone explain why this is happening? Why is the JAX-jitted code slower than the regular code? I am asking about the second and third results.
This question is answered well in the JAX documentation; see the FAQ entry "Is JAX Faster Than NumPy?". In particular, quoting from its summary:
> if you’re doing microbenchmarks of individual array operations on CPU, you can generally expect NumPy to outperform JAX due to its lower per-operation dispatch overhead. If you’re running your code on GPU or TPU, or are benchmarking more complicated JIT-compiled sequences of operations on CPU, you can generally expect JAX to outperform NumPy.
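The per-operation dispatch overhead mentioned in that summary can be observed directly. The following is a minimal sketch, assuming a CPU backend and a recent JAX version in which arrays have a `block_until_ready()` method; the names `add_jit`, `x_np`, `y_np`, etc. are illustrative and not from the question. It times the first call of a jitted function separately (that call traces and compiles it, which is also what a single `%time` on a freshly jitted function mostly measures) and then compares the steady-state cost of one tiny addition in NumPy and in JAX. Because JAX dispatches work asynchronously, the sketch waits for each result before stopping the clock.

```python
import time
import timeit

import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def add_jit(x, y):
    return x + y

# Two 5-element int arrays, similar in size to the rows in the question.
x_np = np.arange(5)
y_np = np.arange(5, 10)
x_jax = jnp.asarray(x_np)
y_jax = jnp.asarray(y_np)

# The first call traces and compiles add_jit, so time it on its own.
start = time.perf_counter()
add_jit(x_jax, y_jax).block_until_ready()
t_first = time.perf_counter() - start

n = 1000
# Steady state: a NumPy ufunc call vs. dispatching the already-compiled JAX function.
# block_until_ready() waits for JAX's asynchronous dispatch to finish.
t_np = timeit.timeit(lambda: x_np + y_np, number=n) / n
t_jax = timeit.timeit(lambda: add_jit(x_jax, y_jax).block_until_ready(), number=n) / n

print(f"first jitted call (includes compilation): {t_first * 1e3:.2f} ms")
print(f"NumPy add:      {t_np * 1e6:.2f} us per call")
print(f"jitted JAX add: {t_jax * 1e6:.2f} us per call")
```

The absolute numbers depend on hardware and JAX version, so no particular ratio should be assumed; the point is that each jitted call carries a fixed dispatch cost that does not shrink with the size of the work.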
You are benchmarking sequences of individually dispatched single operations on CPU, which is precisely the regime that NumPy is designed and optimized for, and so you can expect that NumPy will be faster.
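To contrast this with the other regime the FAQ describes, one could move the whole 99-step loop inside a single jitted function, so the dispatch overhead is paid once per loop instead of once per addition. The sketch below is a hedged illustration, not a definitive benchmark: the function names (`per_op_numpy`, `per_op_jax`, `fused_jax`) are hypothetical, the data are generated with NumPy's `default_rng` rather than the question's `np.random.randint`, and on CPU NumPy may still win for a workload this small.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
# Same shape of data as the question: 10000 rows of 5 random ints in [1, 5).
a = rng.integers(1, 5, size=(10000, 5))
rows_np = list(a[:100])                       # separate NumPy arrays, as in the question
rows_jax = [jnp.asarray(r) for r in a[:100]]  # separate JAX arrays, as in the question
a_jax = jnp.asarray(a[:100])                  # one stacked JAX array

@jax.jit
def add(x, y):
    return x + y

def per_op_numpy():
    # 99 NumPy additions, each a cheap ufunc call.
    return [rows_np[i] + rows_np[i + 1] for i in range(99)]

def per_op_jax():
    # 99 separate dispatches of a jitted function; wait for every result.
    out = [add(rows_jax[i], rows_jax[i + 1]) for i in range(99)]
    for r in out:
        r.block_until_ready()
    return out

@jax.jit
def fused_jax(x):
    # The Python loop runs once at trace time, so all 99 additions are compiled
    # into a single XLA computation that is dispatched with one call.
    return jnp.stack([x[i] + x[i + 1] for i in range(99)])

# Warm up both jitted functions so compilation is excluded from the timings.
per_op_jax()
fused_jax(a_jax).block_until_ready()

n = 100
for name, fn in [
    ("NumPy, one op per call", per_op_numpy),
    ("JAX jit, one op per call", per_op_jax),
    ("JAX jit, whole loop fused", lambda: fused_jax(a_jax).block_until_ready()),
]:
    t = timeit.timeit(fn, number=n) / n
    print(f"{name}: {t * 1e6:.1f} us per loop")
```

Unrolling a Python loop inside `jax.jit` is fine for 99 steps; for much longer loops, `jax.lax.fori_loop` or a vectorized expression such as `x[:-1] + x[1:]` keeps compile times manageable.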