As the title states, I'd like to know what idiomatic methods are available to raise exceptions or handle errors in JAX jitted functions. The functional nature of JAX makes it unclear how to accomplish this.
The closest official documentation I could find is the jax.experimental.checkify module, but this wasn't very clear and seemed incomplete.
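For comparison, here is a minimal sketch of the same check written with checkify. It assumes a JAX version where `checkify.check`, `checkify.checkify`, and the returned error object's `get()` and `throw()` methods behave as described in the checkify docs. The name `checked` is only illustrative. Unlike a callback, checkify makes the error part of the function's output, so the function stays functional under `jit`.

```python
import jax
from jax.experimental import checkify

def error_if_positive(x):
    # Record a functional error when x > 0 instead of raising a Python exception
    checkify.check(x <= 0, "x is positive")
    return x

# checkify.checkify makes the function return (error, output); jit the result
checked = jax.jit(checkify.checkify(error_if_positive))

err, out = checked(-1)
err.throw()  # no-op here: the check passed

err, out = checked(1)
print(err.get())  # a message string that includes "x is positive"
try:
    err.throw()  # re-raises the recorded failure as a Python exception
except Exception as e:  # checkify uses its own runtime-error type
    print(type(e).__name__)
```

The error is only turned into a Python exception when you call `throw()` outside the jitted function, which is what keeps the traced computation free of side effects.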
This GitHub comment claims that Python exceptions can be raised by using the jax.debug.callback() and jax.lax.cond() functions. I attempted to do this, but an error is thrown during compilation. A minimal working example is below:
import jax
from jax import jit

def _raise(ex):
    raise ex

@jit
def error_if_positive(x):
    jax.lax.cond(
        x > 0,
        lambda: jax.debug.callback(_raise, ValueError("x is positive")),
        lambda: None,
    )

if __name__ == "__main__":
    error_if_positive(-1)
The abbreviated error message:
TypeError: Value ValueError('x is positive') with type <class 'ValueError'> is not a valid JAX type
You can use callbacks to raise errors, for example:
import jax
from jax import jit

def _raise_if_positive(x):
    # Runs on the host at runtime, where x is a concrete value
    if x > 0:
        raise ValueError("x is positive")

@jit
def error_if_positive(x):
    jax.debug.callback(_raise_if_positive, x)

if __name__ == "__main__":
    error_if_positive(-1)  # no error
    error_if_positive(1)
    # XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: ValueError: x is positive
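If you want to keep the lax.cond structure from the question, a sketch of one way to do it is to pass x through cond's operands and build the exception inside the callback, so nothing non-array is handed to jax.debug.callback. The helper name `_raise_value_error` is hypothetical, and the exact exception type JAX wraps the ValueError in may vary by version.

```python
import jax

def _raise_value_error(x):
    # Hypothetical helper: runs on the host at runtime, so x is concrete here
    raise ValueError(f"x is positive: {x}")

@jax.jit
def error_if_positive(x):
    # Pass x as an operand; only the taken branch executes at runtime
    jax.lax.cond(
        x > 0,
        lambda v: jax.debug.callback(_raise_value_error, v),
        lambda v: None,
        x,
    )

error_if_positive(-1.0)  # the callback branch is traced but not executed

try:
    error_if_positive(2.0)
    jax.effects_barrier()  # wait for any pending callbacks to finish
except Exception as e:  # JAX wraps the ValueError in its own runtime error
    print(e)
```

Both branches still get traced, but tracing the callback branch now only stages the call; the ValueError itself is constructed and raised only when that branch actually runs.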
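The reason your approach didn't work is that your error is raised at trace time rather than at runtime, and both branches of the cond are always traced. When the first branch is traced, jax.debug.callback tries to treat its argument, the ValueError instance, as a JAX value, which fails regardless of the value of x.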
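A quick way to see that lax.cond traces both branches under jit is to record a Python side effect in each branch. This is a small illustrative sketch; `traced`, `true_branch`, and `false_branch` are hypothetical names, and the list append only happens while JAX traces the functions, not when the compiled code runs.

```python
import jax

traced = []  # filled in during tracing only

def true_branch(x):
    traced.append("true")  # Python side effect: runs at trace time
    return x + 1

def false_branch(x):
    traced.append("false")
    return x - 1

@jax.jit
def f(x):
    return jax.lax.cond(x > 0, true_branch, false_branch, x)

f(-1)
print(traced)  # both names appear, although only false_branch is executed
```

Any Python code in either branch, including constructing a ValueError or passing it to a callback, therefore runs during tracing, which is why only work staged into the computation (such as a callback that raises) can depend on the runtime value of x.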