While learning JAX, I tried to do something like this:
```python
import jax.numpy as jnp

def f(x):
    return [x + 1, [1, 2, 3], "Hello"]

x = 1
new_x, a_list, str_ = jnp.where(
    x > 0,
    f(x),
    f(x + 1)
)
```
JAX clearly does not support this. I searched online and went through quite a few docs, but couldn't find a good answer.
How can I achieve this in JAX?
In general, JAX functions like `jnp.where` only accept array arguments, not list or string arguments. Since your function returns values that are not compatible with JAX in the first place, it is probably better to avoid JAX conditionals and use a standard Python conditional instead:
```python
import jax.numpy as jnp

def f(x):
    return [x + 1, [1, 2, 3], "Hello"]

x = 1
new_x, a_list, str_ = f(x) if x > 0 else f(x + 1)
```
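The Python conditional works here because `x` is a concrete Python value. If the selection has to happen inside JAX (for example under `jax.jit`, where `x` is a traced value and a Python `if x > 0` raises an error), one option is to keep non-array values such as the string in plain Python and pass only arrays through JAX. The sketch below assumes a hypothetical helper `f_arrays` that returns just the array-valued parts of `f`. It shows two variants: mapping `jnp.where` leaf by leaf over the two result lists with `jax.tree_util.tree_map`, and using `jax.lax.cond`, which requires both branches to return the same structure, shapes, and dtypes.

```python
import jax
import jax.numpy as jnp

# Hypothetical variant of f that returns only array-valued outputs.
# A list of arrays is a valid JAX pytree; a Python str is not.
def f_arrays(x):
    return [x + 1, jnp.array([1, 2, 3])]

@jax.jit
def select(x):
    pred = x > 0
    # jnp.where only accepts arrays, so apply it to each pair of
    # matching leaves in the two result lists.
    return jax.tree_util.tree_map(
        lambda a, b: jnp.where(pred, a, b), f_arrays(x), f_arrays(x + 1)
    )

@jax.jit
def select_cond(x):
    # lax.cond picks one branch; both branches must return pytrees
    # with the same structure, shapes, and dtypes.
    return jax.lax.cond(x > 0, f_arrays, lambda y: f_arrays(y + 1), x)

# Non-array values such as the string stay outside JAX entirely.
greeting = "Hello"
new_x, a_list = select(jnp.asarray(1))
```

One design difference: `jnp.where` computes both candidate results and then selects between them, whereas `lax.cond` is meant to run only the chosen branch, which matters if the branches are expensive.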