python numpy jit control-flow jax

How to re-write this control-flow Python function to be JAX-compatible?


I'm rewriting some code from pure Python to JAX. I have a function whose if/else statements depend on the value of its input. I know this kind of Python control flow is not compatible with jit. However, this function is called many times by a larger function that will be jitted, so that larger function currently raises an error when it calls the function below. Is there a way to rewrite this control-flow function, OR to tell the larger function not to try to jit-compile this one?

import jax.numpy as jnp
from jax import jit
from functools import partial 

@partial(jit,static_argnums=(1,2,3,))
def test(z,z1=1.,z2=5.,z3=10.):
    
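    # Python `if` on the traced array `z` is what fails under jit
    if z < z1:
        fz = 3.*jnp.sqrt(z) / (z+z1)
    elif z >= z1 and z <= z2:
        fz = (3./z) * z1**2 - z2
    elif z > z3:
        fz = (3./z) * z3
    # (no branch assigns fz when z2 < z <= z3)

    return fz

test(jnp.array(2.0))  # raises the ConcretizationTypeError below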

"""
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
The error occurred while tracing the function test at /tmp/ipykernel_2459178/2188894837.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'z'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
"""

Following https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow, I also tried making z a static argument with static_argnums=(0,1,2,3), but that leads to another error:

ValueError: Non-hashable static arguments are not supported. An error occurred during a call to 'test' while trying to hash an object of type <class 'jaxlib.xla_extension.Array'>, 2.0. The error was: TypeError: unhashable type: 'Array'

Again, the function works fine if I comment out the partial(jit, ...) decorator:

test(jnp.array(2.0))
# Array(-3.5, dtype=float64, weak_type=True) 

But the problem is that this function, jitted or not, is called from within a larger function that is jitted. Is there a way for the outer jit to skip this problematic control-flow function? Or could I use plain numpy (rather than jax.numpy) inside it, so that jit has no effect on pure-Python functions called from within larger jitted functions?


Solution

  • One way to do this is to nest calls to jax.lax.cond. There's no way to exactly duplicate the behavior of your original function, which raises an UnboundLocalError (a subclass of NameError) when z > z2 and z <= z3 because no branch assigns fz, so instead I opted to return NaN in that case:

    import jax.numpy as jnp
    from jax import lax, jit
    
    @jit
    def test(z,z1=1.,z2=5.,z3=10.):
      return lax.cond(
          z < z1,
          lambda: 3.*jnp.sqrt(z) / (z+z1),
          lambda: lax.cond(
              (z >= z1) & (z <= z2),
              lambda: (3./z) * z1**2 - z2,
              lambda: lax.cond(
                  z > z3,
                  lambda: (3./z) * z3,
                  lambda: jnp.nan
              )
          )
      )
    
    print(test(jnp.array(2.0)))
    # -3.5
    

    A slightly more compact approach is to use jax.numpy.select:

    @jit
    def test(z,z1=1.,z2=5.,z3=10.):
      return jnp.select(
          [z < z1, (z >= z1) & (z <= z2), z > z3],
          [3.*jnp.sqrt(z) / (z+z1), (3./z) * z1**2 - z2, (3./z) * z3],
          default=jnp.nan)
      
    print(test(jnp.array(2.0)))
    # -3.5
    

    The difference is that in the lax.cond version, only one of the expressions is actually computed, while in the jnp.select version, all expressions are computed but only one is returned.