Tags: python, boolean, gradient, jax

JAX `grad` error for function with `jax.lax.switch` and compound boolean conditions


I have encountered a case where applying `jax.grad` to a function that uses `jax.lax.switch` with compound boolean conditions raises `jax.errors.TracerBoolConversionError`. A minimal program that reproduces the behavior:

from jax.lax import switch
import jax.numpy as jnp
from jax import grad

func_0 = lambda x: jnp.where(0. < x < 1., x, 0.)
func_1 = lambda x: jnp.where(0. < x < 1., x, 1.)

func_list = [func_0, func_1]

func = lambda index, x: switch(index, func_list, x)

df = grad(func, argnums=1)(1, 2.)
print(df)

The error is the following:

Traceback (most recent call last):
  File "***/grad_test.py", line 12, in <module>
    df = grad(func, argnums=1)(1, 0.5)
  File "***/grad_test.py", line 10, in <lambda>
    func = lambda index, x: switch(index, func_list, x)
  File "***/grad_test.py", line 5, in <lambda>
    func_0 = lambda x: jnp.where(0 < x < 1., x, 0.)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function <lambda> at ***/grad_test.py:5 for switch. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

However, if the boolean condition is changed to a single condition (for example, x < 1), then no error occurs. I'm wondering if this could be a bug, or otherwise, how the original program should be changed.


Solution

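  • This is not a bug. You cannot use chained inequalities with traced JAX arrays (or with multi-element NumPy arrays). Python evaluates 0 < x < 1 as (0 < x) and (x < 1), and the `and` forces the result of the first comparison to be converted to a Python bool, which is not possible for a traced value. The single condition x < 1 works because no such conversion happens. Instead of 0 < x < 1, write (0 < x) & (x < 1). The parentheses are not optional here, because & binds more tightly than the comparison operators.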
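
For reference, here is a sketch of the original program with only the condition rewritten using the elementwise `&`. The names `cond`, `df_inside` and `df_outside` are introduced here for illustration and are not part of the question. The second call uses x = 0.5 (the value that appears in the traceback) to show a point inside the interval.

```python
import jax.numpy as jnp
from jax import grad
from jax.lax import switch

# Elementwise compound condition: no Python bool conversion is needed,
# so it can be traced by switch and grad.
cond = lambda x: (0. < x) & (x < 1.)  # hypothetical helper name

func_0 = lambda x: jnp.where(cond(x), x, 0.)
func_1 = lambda x: jnp.where(cond(x), x, 1.)

func_list = [func_0, func_1]

func = lambda index, x: switch(index, func_list, x)

# x = 2. lies outside (0, 1): the constant branch of where is selected.
df_outside = grad(func, argnums=1)(1, 2.)
print(df_outside)  # 0.0

# x = 0.5 lies inside (0, 1): the output equals x there.
df_inside = grad(func, argnums=1)(1, 0.5)
print(df_inside)  # 1.0
```

`jnp.logical_and(0. < x, x < 1.)` is an equivalent spelling of the condition if you prefer to avoid relying on operator precedence.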