
Is there a way to return a tuple of mixed-type values (array, list, string) from a JAX helper function?


While learning JAX, I tried to do something like this:

import jax.numpy as jnp

def f(x):
    return [x + 1, [1,2,3], "Hello"]

x = 1
new_x, a_list, str_ = jnp.where(
    x > 0,
    f(x),
    f(x + 1)
)

Clearly, JAX does not support this. I searched online and read through quite a few docs, but I couldn't find a good answer.

How can I achieve this in JAX?


Solution

  • In general, JAX functions such as jnp.where accept only array arguments, not lists or strings. Since your function already returns values that JAX cannot handle, it is simpler to avoid JAX conditionals altogether and use a standard Python conditional instead. This works because x is a concrete Python value here rather than a traced one:

    import jax.numpy as jnp
    
    def f(x):
        return [x + 1, [1,2,3], "Hello"]
    
    x = 1
    
    # x is a concrete Python int, so an ordinary conditional expression picks the branch
    new_x, a_list, str_ = f(x) if x > 0 else f(x + 1)
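If x is instead a traced value, for example inside a function decorated with jax.jit, the Python conditional above raises an error because the tracer's truth value is not known at trace time. The following is a minimal sketch that assumes you can separate the array-valued outputs from the static list and string: the static parts stay in plain Python, and only arrays go through jnp.where or jax.lax.cond (which can return a tuple of arrays, as long as both branches return the same structure, shapes, and dtypes). The names f_array, f_arrays, select, select_tuple, A_LIST, and GREETING are hypothetical.

```python
import jax
import jax.numpy as jnp

# Static, non-array data kept outside JAX control flow (hypothetical names).
A_LIST = [1, 2, 3]
GREETING = "Hello"

def f_array(x):
    # Hypothetical array-only part of the original f.
    return x + 1

@jax.jit
def select(x):
    # x is a tracer here; jnp.where computes both branches and selects element-wise.
    return jnp.where(x > 0, f_array(x), f_array(x + 1))

def f_arrays(x):
    # Hypothetical variant returning a tuple of arrays (no string allowed).
    return x + 1, jnp.array([1, 2, 3])

@jax.jit
def select_tuple(x):
    # lax.cond applies exactly one branch to the operand x; both branches
    # must return pytrees with identical structure, shapes, and dtypes.
    return jax.lax.cond(x > 0, f_arrays, lambda v: f_arrays(v + 1), x)

# Reassemble the mixed-type result outside of JAX.
result = [select(1), A_LIST, GREETING]
```

Keeping the list and string out of the traced computation is a design choice here: JAX transformations operate on arrays (and pytrees of arrays), so non-array values are best handled as ordinary Python data around the jitted function.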