I want to replicate this behaviour in a jitted function (the function is an example):
```python
def function(x, y):
    if y == 0:
        return x
    return x + 1
```
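The plain Python version cannot simply be wrapped in `jax.jit`. Under `jit`, `y` is a tracer, and `if y == 0` needs a concrete Python bool. The sketch below shows this. It assumes a recent JAX, where the error is `TracerBoolConversionError`, a subclass of `jax.errors.ConcretizationTypeError`, so catching the base class also covers older versions.

```python
import jax

def function(x, y):
    # Plain Python control flow: `y == 0` must evaluate to a concrete bool.
    if y == 0:
        return x
    return x + 1

# Works eagerly, because y is an ordinary Python int here.
print(function(3, 0))

jitted = jax.jit(function)
try:
    jitted(3, 0)
except jax.errors.ConcretizationTypeError as err:
    # Under jit, y is a tracer, so `if y == 0` cannot be turned into a Python bool.
    print("tracing failed:", type(err).__name__)
```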
Using `jax.lax.cond`, it can be obtained with:
```python
import jax

@jax.jit
def function(x, y):
    # Return x unchanged when y == 0, otherwise x + 1.
    return jax.lax.cond(y == 0, lambda x: x, lambda x: x + 1, x)
```
This is simple as long as whatever needs to be done when `y != 0` is simple (in this case, just adding 1 to `x`). However, if that logic is complex, or there are more conditions of this sort, the code gets more convoluted.
Is there a way to get the behavior "if `y == 0`, return `x`; if not, just keep running the function"? `jax.lax.cond` requires a new function for every condition that is applied.
For example, consider a function with more branches:
```python
def function(x, y):
    if y == 0:
        return x
    if y > 0:
        return x - y
    if y < 0:
        return x + y
```
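Its jitted version starts to get messy:

```python
import jax

@jax.jit
def function(x, y):
    # Outer cond handles y == 0; inner cond distinguishes y > 0 from y < 0.
    return jax.lax.cond(
        y == 0,
        lambda x, y: x,
        lambda x, y: jax.lax.cond(y > 0, lambda x, y: x - y, lambda x, y: x + y, x, y),
        x, y)
```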
Is there a better way?
In short, no, there's no way to return early from a Python function conditioned on traced values. The pattern I typically see to avoid messy nesting is to encapsulate the logic in helper functions and call them via `lax.cond`.
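Here is a minimal sketch of the helper-function pattern applied to the three-branch example. The helper names (`_when_zero`, `_when_nonzero`, `_when_positive`, `_when_negative`) are hypothetical and chosen only for illustration. Each condition becomes one flat `lax.cond` call that dispatches to named functions. It assumes all branches return the same shape and dtype, which `lax.cond` requires.

```python
import jax
from jax import lax

# Hypothetical helper names, for illustration only.
def _when_zero(x, y):
    # Branch for y == 0: return x unchanged.
    return x

def _when_positive(x, y):
    return x - y

def _when_negative(x, y):
    return x + y

def _when_nonzero(x, y):
    # The second condition lives in its own helper instead of a nested lambda.
    return lax.cond(y > 0, _when_positive, _when_negative, x, y)

@jax.jit
def function(x, y):
    # Top level reads like the first `if` of the original Python function.
    return lax.cond(y == 0, _when_zero, _when_nonzero, x, y)

print(function(5, 0), function(5, 2), function(5, -2))
```

Each additional condition adds one named helper rather than another level of inline lambdas.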
Alternatively, if you are branching based on multiple conditions, you may be able to express the logic better in terms of `lax.switch`; for example:
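```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def function(x, y):
    branches = [lambda: x, lambda: x - y, lambda: x + y]
    # argmax picks the index of the first True condition; the final True acts as the fallback.
    conditions = jnp.array([y == 0, y > 0, True])
    return lax.switch(jnp.argmax(conditions), branches)
```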
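When the conditions split the input cleanly, as they do here (`y` is zero, positive, or negative), the branch index can also be computed arithmetically instead of through `argmax`. The following sketch assumes a scalar numeric `y`. `jnp.sign(y)` yields -1, 0, or 1, and shifting it by one gives a valid index into three branches. Unlike the closure version above, these branches receive `x` and `y` as explicit operands to `lax.switch`.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def function(x, y):
    # Map sign(y) in {-1, 0, 1} to a branch index in {0, 1, 2}.
    index = jnp.sign(y).astype(jnp.int32) + 1
    branches = [
        lambda x, y: x + y,  # index 0: y < 0
        lambda x, y: x,      # index 1: y == 0
        lambda x, y: x - y,  # index 2: y > 0
    ]
    return lax.switch(index, branches, x, y)

print(function(5, 0), function(5, 2), function(5, -2))
```

This only works when the conditions map neatly onto an integer. The `argmax` version is more general. Note that under `vmap`, both `lax.cond` and `lax.switch` are lowered to a select, so every branch is evaluated.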