Idiomatic ways to handle errors in JAX jitted functions


As the title states, I'd like to know what idiomatic methods are available to raise exceptions or handle errors in JAX jitted functions. The functional nature of JAX makes it unclear how to accomplish this.

The closest official documentation I could find is the jax.experimental.checkify module, but this wasn't very clear and seemed incomplete.

This GitHub comment claims that Python exceptions can be raised by combining jax.debug.callback() and jax.lax.cond(). I attempted to do this, but an error is thrown during compilation. A minimal reproducible example is below:

import jax
from jax import jit

def _raise(ex):
    raise ex


@jit
def error_if_positive(x):
    jax.lax.cond(
        x > 0,
        lambda : jax.debug.callback(_raise, ValueError("x is positive")),
        lambda : None,
    )

if __name__ == "__main__":

    error_if_positive(-1)

The abbreviated error statement:

TypeError: Value ValueError('x is positive') with type <class 'ValueError'> is not a valid JAX type

Solution

  • You can use callbacks to raise errors, for example:

    import jax
    from jax import jit
    
    def _raise_if_positive(x):
      if x > 0:
        raise ValueError("x is positive")
    
    @jit
    def error_if_positive(x):
      jax.debug.callback(_raise_if_positive, x)
    
    if __name__ == "__main__":
      error_if_positive(-1)  # no error
      error_if_positive(1)
      # XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: x is positive
    

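    The reason your approach didn't work is that the failure happens at trace time rather than at runtime: both branches of the cond are always traced, whatever the value of x. Tracing the first branch passes a ValueError object as an argument to jax.debug.callback, and callback arguments must be valid JAX types, hence the TypeError. Raising the exception inside the callback, based on the array value it receives, avoids this.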
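
Since the question mentions jax.experimental.checkify, here is a minimal sketch of the same check written with it, as a possible alternative to the callback. Instead of raising from inside the compiled function, the checkify-transformed function returns an error value alongside its normal output, and you decide on the Python side whether to throw. The name `checked_error_if_positive` is made up for this illustration, and the exact wording of the error message may differ between JAX versions.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def error_if_positive(x):
    # Record a runtime check; this does not raise during tracing.
    checkify.check(x <= 0, "x is positive")
    return x


# Hypothetical name: the checkify-transformed, jitted version.
# It returns a pair (error, output) instead of just the output.
checked_error_if_positive = jax.jit(checkify.checkify(error_if_positive))

if __name__ == "__main__":
    err, out = checked_error_if_positive(jnp.float32(-1.0))
    err.throw()  # check passed, so nothing is raised

    err, out = checked_error_if_positive(jnp.float32(1.0))
    try:
        err.throw()  # check failed, so this raises on the Python side
    except ValueError as e:
        # The checkify error type is a ValueError subclass.
        print("caught:", e)
```

If you only want to inspect the failure without raising, `err.get()` returns the message as a string, or None when every check passed.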