
Does JAX save the jaxpr of jit-compiled functions?


Consider the following example:

import jax
import jax.numpy as jnp

@jax.jit
def test(x):
    if x.shape[0] > 4:
        return 1
    else:
        return -1
    
print(test(jnp.ones(8,)))
print(test(jnp.ones(3,)))

The output is

1
-1

However, I thought that JAX compiles the function on the first call and reuses that compiled version for subsequent calls. Shouldn't the output then be 1 and 1, since JAX traces through the Python if statement and does not emit a conditional here? The jaxpr of the first call contains no conditional:

{ lambda ; a:f32[8]. let
    b:i32[] = pjit[name=test jaxpr={ lambda ; c:f32[8]. let  in (1,) }] a
  in (b,) }

So how exactly does this work under the hood? Is the jaxpr unique for every call? Does JAX only reuse a jaxpr if the shape matches, and does it recompile the function when the shape is different?


Solution

  • JAX does cache the jaxpr and the compiled artifact for each compatible call of the function. Compatibility is determined by the cache key, which contains the shape and dtype of the array arguments, the hash of any static arguments, and some additional information, such as global flags that may affect the computation. Whenever something in the cache key changes, the function is traced and compiled again. In your example, x.shape is part of the cache key, so the call with a length-3 array is traced separately, and in that trace the if resolves to -1. You can observe this by printing the _cache_size() of the jit-compiled function (an underscore-prefixed, i.e. private, method, so it may change between JAX versions). For example:

    import jax
    import jax.numpy as jnp
    
    @jax.jit
    def test(x):
        if x.shape[0] > 4:
            return 1
        else:
            return -1
    
    x8 = jnp.ones(8)
    x3 = jnp.ones(3)
    
    print(test._cache_size())  # no calls yet, so the cache is empty
    # 0
    
    test(x8)
    print(test._cache_size())  # first call traces and compiles for f32[8]
    # 1
    
    test(x8)
    print(test._cache_size())  # same shape and dtype: cached version is reused
    # 1
    
    test(x3)
    print(test._cache_size())  # new shape f32[3]: traced and compiled again
    # 2
    
    test(x8)
    print(test._cache_size())  # f32[8] is still cached, so the size doesn't change
    # 2
    

    Because these static attributes (shapes, dtypes, and static arguments) are part of the cache key, a jit-compiled function can change its behavior based on them while still avoiding recompilation for compatible inputs.
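The cache key can also be observed through tracing side effects. The sketch below is illustrative and uses hypothetical names (scaled, trace_count): a Python-level counter inside the function runs only while JAX traces it, not on cached calls, so it counts how often the function is re-traced when the shape, the dtype, or a static argument (declared via static_argnums) changes. Relying on Python side effects inside jit like this is only suitable for debugging.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical counter, used only to observe when JAX re-traces the function.
trace_count = 0


@partial(jax.jit, static_argnums=1)
def scaled(x, factor):
    global trace_count
    # This Python statement executes only during tracing, so the counter
    # counts traces (cache misses), not calls.
    trace_count += 1
    return x * factor


x8 = jnp.ones(8)
x3 = jnp.ones(3)

scaled(x8, 2)                    # first call: trace + compile
scaled(x8, 2)                    # same shape, dtype and static value: cache hit
scaled(x3, 2)                    # new shape: re-trace
scaled(x8.astype(jnp.int32), 2)  # new dtype: re-trace
scaled(x8, 3)                    # new static argument value: re-trace

print(trace_count)           # 4
print(scaled._cache_size())  # 4
```

Each distinct combination of shape, dtype, and static value produces one trace and one cache entry; repeating an earlier combination does neither.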