I have encountered a scenario where applying jax.grad to a function that uses jax.lax.switch with compound boolean conditions raises jax.errors.TracerBoolConversionError. A minimal program that reproduces this behavior is the following:
from jax.lax import switch
import jax.numpy as jnp
from jax import grad
func_0 = lambda x: jnp.where(0. < x < 1., x, 0.)
func_1 = lambda x: jnp.where(0. < x < 1., x, 1.)
func_list = [func_0, func_1]
func = lambda index, x: switch(index, func_list, x)
df = grad(func, argnums=1)(1, 2.)
print(df)
The error is the following:
Traceback (most recent call last):
File "***/grad_test.py", line 12, in <module>
df = grad(func, argnums=1)(1, 0.5)
File "***/grad_test.py", line 10, in <lambda>
func = lambda index, x: switch(index, func_list, x)
File "***/grad_test.py", line 5, in <lambda>
func_0 = lambda x: jnp.where(0 < x < 1., x, 0.)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function <lambda> at ***/grad_test.py:5 for switch. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
However, if the boolean condition is changed to a single condition (for example, x < 1), then no error occurs. I'm wondering whether this could be a bug or, if not, how the original program should be changed.
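You cannot use chained inequalities with JAX or NumPy arrays. Python evaluates 0 < x < 1 as (0 < x) and (x < 1), and the "and" forces the result of the first comparison into a Python bool. Inside switch that result is a traced array with no concrete value, which is why TracerBoolConversionError is raised. A single condition such as x < 1 works because it never triggers that conversion. Instead of 0 < x < 1, you should write (0 < x) & (x < 1) (note that due to operator precedence, the parentheses are not optional here: & binds more tightly than <, so 0 < x & x < 1 would be parsed as the chained comparison 0 < (x & x) < 1).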
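For reference, here is a minimal sketch of the program from the question with that change applied. It keeps the question's names (func_0, func_1, func_list, func) and adds two illustrative variables, df_outside and df_inside, which are not part of the original code.

```python
import jax.numpy as jnp
from jax import grad
from jax.lax import switch

# Element-wise & replaces the chained comparison; the parentheses are required.
func_0 = lambda x: jnp.where((0. < x) & (x < 1.), x, 0.)
func_1 = lambda x: jnp.where((0. < x) & (x < 1.), x, 1.)
func_list = [func_0, func_1]
func = lambda index, x: switch(index, func_list, x)

# Illustrative names: gradient of branch 1 outside and inside the interval (0, 1).
df_outside = grad(func, argnums=1)(1, 2.)   # x = 2.0 selects the constant 1.
df_inside = grad(func, argnums=1)(1, 0.5)   # x = 0.5 selects x itself
print(df_outside, df_inside)  # gradient is 0.0 outside the interval, 1.0 inside
```

If the operator form is hard to read, jnp.logical_and(0. < x, x < 1.) expresses the same condition without relying on parentheses.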