JAX hook / information / warning when a JIT function is re-compiled


Is it possible in JAX to get a notification whenever a function has to be recompiled by the JAX just-in-time compiler (because the inputs changed so that no cached compiled version can be reused)?

For now, I use a hacky workaround to be informed about recompilation. In the current implementation, JAX executes the Python function once while tracing it for compilation. Python side effects are allowed during that trace, so they run only when the function is (re)compiled:


import jax
recompilation_count: int = 0

@jax.jit
def func(z):
    global recompilation_count
    recompilation_count += 1
    return z * z + 100 / z


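func(1)
print(recompilation_count)  # 1: first call, traced and compiled
func(2)
print(recompilation_count)  # 1: same shape and dtype, cached version reused
func(jax.numpy.arange(10))
print(recompilation_count)  # 2: new shape, recompiled
func(jax.numpy.arange(10, 20))
print(recompilation_count)  # 2: same shape and dtype as the previous call
func(jax.numpy.arange(10) ** 2)
print(recompilation_count)  # 2: same shape and dtype again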

assert recompilation_count == 2
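The same side-effect trick can be packaged so that each jitted function gets its own counter instead of sharing a global. This is a minimal sketch that relies on the same unofficial tracing behavior. `count_traces` and `n_traces` are hypothetical names. It also shows that a dtype change (int to float), not only a shape change, forces a new trace.

```python
import jax
import jax.numpy as jnp


def count_traces(fn):
    """Hypothetical helper: jit `fn` and count how often its Python body runs.

    Relies on unofficial behavior: the Python body of a jitted function runs
    only when JAX has to trace it, i.e. when no cached version matches.
    """
    counter = {"traces": 0}

    def body(*args):
        counter["traces"] += 1  # Python side effect: executes at trace time only
        return fn(*args)

    # Return the jitted function together with a getter for its own counter.
    return jax.jit(body), lambda: counter["traces"]


func, n_traces = count_traces(lambda z: z * z + 100 / z)

func(1)               # first call: traced and compiled
func(2)               # same shape and dtype: cached version reused
func(1.0)             # float instead of int: traced again
func(jnp.arange(10))  # new shape: traced again
print(n_traces())     # 3
```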

However, this depends on an internal implementation detail of JAX, so it cannot be relied upon. Is there another way to be informed about recompilation, and possibly to prevent it if it happens too frequently?


Solution

  • I don't believe there is any built-in API to do what you are asking, but similar functionality is currently under active discussion (see e.g. https://github.com/google/jax/issues/8655).

    Note, however, that there is a built-in (though private, underscore-prefixed) way to track the compilation count, if you wish:

    import jax
    import jax.numpy as jnp
    
    @jax.jit
    def f(x):
      return x
    
    print(f._cache_size())
    # 0
    
    _ = f(jnp.arange(3))
    print(f._cache_size())
    # 1
    
    _ = f(jnp.arange(3))  # should not trigger a recompilation
    print(f._cache_size())
    # 1
    
    _ = f(jnp.arange(100))  # should trigger a recompilation
    print(f._cache_size())
    # 2
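Building on `_cache_size()`, you could write a wrapper that reports recompilations and stops a function that recompiles too often. This addresses the "prevent" part of the question. It is a minimal sketch, not an official API. `limit_recompiles` and `max_compiles` are hypothetical names, and `_cache_size()` is a private method that may change or disappear between JAX versions.

```python
import warnings

import jax
import jax.numpy as jnp


def limit_recompiles(jitted_fn, max_compiles):
    """Hypothetical wrapper around a `jax.jit`-decorated function.

    Warns whenever a call adds a new entry to the function's compilation
    cache after the first one (a recompilation), and raises once the cache
    holds more than `max_compiles` entries. Uses the private `_cache_size()`
    method, which is not a stable API.
    """
    def wrapper(*args, **kwargs):
        before = jitted_fn._cache_size()
        out = jitted_fn(*args, **kwargs)
        after = jitted_fn._cache_size()
        if after > max_compiles:
            # The extra compilation has already happened at this point.
            raise RuntimeError(
                f"function compiled {after} times (limit: {max_compiles})")
        if before > 0 and after > before:
            warnings.warn(f"function recompiled ({after} cached versions)")
        return out

    return wrapper


@jax.jit
def f(x):
    return x


guarded_f = limit_recompiles(f, max_compiles=2)
guarded_f(jnp.arange(3))    # first compilation: silent
guarded_f(jnp.arange(3))    # cache hit: silent
guarded_f(jnp.arange(100))  # second compilation: emits a warning
try:
    guarded_f(jnp.arange(5))  # third compilation: exceeds the limit
except RuntimeError as err:
    print(err)
```

The check runs after the call returns, so the offending compilation has already happened when the error is raised. The wrapper stops repeated recompilation rather than the first excess one.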