Is it possible in JAX to get a notification whenever a function has to be recompiled by the JAX just-in-time compiler (because the input changed and the cached compiled version cannot be reused)?
For now, I use a hacky workaround to be notified of recompilation. In the current implementation, the tracer executes the Python function once whenever it needs to be compiled. Side effects are allowed, so they run only when the function is (re)compiled:
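import jax

recompilation_count: int = 0

@jax.jit
def func(z):
    # This Python side effect runs only while JAX traces the function,
    # i.e. once per compilation, not on every call.
    global recompilation_count
    recompilation_count += 1
    return z * z + 100 / z

func(1)
print(recompilation_count)  # 1: first call, compiled for an int scalar
func(2)
print(recompilation_count)  # 1: same shape/dtype, cached version reused
func(jax.numpy.arange(10))
print(recompilation_count)  # 2: new shape (10,), recompiled
func(jax.numpy.arange(10, 20))
print(recompilation_count)  # 2: same shape/dtype as the previous call
func(jax.numpy.arange(10) ** 2)
print(recompilation_count)  # 2: same shape/dtype as the previous call
assert recompilation_count == 2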
However, this relies on an internal implementation detail of JAX and hence cannot be relied upon. Is there another way to be notified of, and potentially prevent, recompilation of a function if it happens too frequently?
I don't believe there is any built-in API to do what you are asking, but similar functionality is currently under active discussion (see, e.g., https://github.com/google/jax/issues/8655).
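Note, though, that there is a built-in way to track the compilation count, if you wish, via the jitted function's _cache_size() method (the leading underscore marks it as private, so it may change between JAX versions):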
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return x

print(f._cache_size())
# 0
_ = f(jnp.arange(3))
print(f._cache_size())
# 1
_ = f(jnp.arange(3))  # should not trigger a recompilation
print(f._cache_size())
# 1
_ = f(jnp.arange(100))  # should trigger a recompilation
print(f._cache_size())
# 2
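Building on _cache_size(), the following minimal sketch shows one way to get a per-call notification and to fail loudly when a function recompiles too often. It assumes _cache_size() keeps its current behavior; the helper jit_with_recompile_limit, the RecompilationWarning class, and the max_compiles parameter are hypothetical names made up for this illustration, not part of JAX. Note that the check runs after the call, so it detects a recompilation rather than preventing the one that just happened.

```python
import functools
import warnings

import jax
import jax.numpy as jnp


class RecompilationWarning(UserWarning):
    """Hypothetical warning category emitted on each new compilation."""


def jit_with_recompile_limit(fun, max_compiles):
    """Hypothetical helper: jit `fun`, warn on every new compilation,
    and raise once more than `max_compiles` versions have been compiled.

    Relies on the private `_cache_size()` method of jitted functions,
    which may change or disappear between JAX versions.
    """
    jitted = jax.jit(fun)

    @functools.wraps(fun)
    def wrapper(*args, **kwargs):
        before = jitted._cache_size()
        out = jitted(*args, **kwargs)
        after = jitted._cache_size()
        if after > before:  # this call missed the cache and was compiled
            warnings.warn(
                f"{fun.__name__} was compiled ({after} cached versions)",
                RecompilationWarning,
            )
            if after > max_compiles:
                raise RuntimeError(
                    f"{fun.__name__} compiled {after} times, "
                    f"exceeding the limit of {max_compiles}"
                )
        return out

    return wrapper


def square(x):
    return x * x


sq = jit_with_recompile_limit(square, max_compiles=2)
sq(jnp.arange(3))  # first compilation: warns
sq(jnp.arange(3))  # same shape/dtype: cached, no warning
sq(jnp.arange(4))  # new shape: second compilation, warns
```

A further call such as sq(jnp.arange(5)) would compile a third version and raise RuntimeError. Raising instead of silently continuing is a design choice: it surfaces shape-unstable call sites early, at the cost of one extra compilation before the error.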