
jax return 0 if condition, continue if not in a jitted function


I want to replicate this behaviour in a jitted function (the function is an example):

def function(x,y):
   if y==0:
      return x
   return x+1

Using jax.lax.cond, it can be written as:

import jax

@jax.jit
def function(x, y):
    return jax.lax.cond(y == 0, lambda x: x, lambda x: x + 1, x)
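To see why the plain version cannot simply be decorated with jax.jit, here is a minimal sketch that compares the two. It assumes a recent JAX release, and the names plain_function, cond_function and jit_plain_fails are hypothetical, chosen only for this illustration.

```python
import jax
import jax.numpy as jnp


def plain_function(x, y):
    # Ordinary Python control flow: works when called eagerly on concrete values.
    if y == 0:
        return x
    return x + 1


@jax.jit
def cond_function(x, y):
    # lax.cond selects a branch at run time from the traced predicate.
    return jax.lax.cond(y == 0, lambda x: x, lambda x: x + 1, x)


def jit_plain_fails(x, y):
    """Return True if jitting plain_function raises a concretization error."""
    try:
        jax.jit(plain_function)(x, y)
    except jax.errors.ConcretizationTypeError:
        # Under jit, y is an abstract tracer, so `if y == 0` has no concrete bool.
        return True
    return False


x = jnp.asarray(5.0)
print(cond_function(x, jnp.asarray(0.0)))  # y == 0: x is returned unchanged
print(cond_function(x, jnp.asarray(3.0)))  # y != 0: x + 1 is returned
print(jit_plain_fails(x, jnp.asarray(0.0)))
```

Calling plain_function eagerly, without jit, still works, because y is then a concrete value rather than a tracer.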

This is simple as long as the work done when y != 0 is simple (here, just adding 1 to x). However, if that work is complex, or if there are several conditions of this sort, the code quickly gets convoluted.

Is there a way to get the behavior "if y == 0, return x; if not, just keep running the function"? jax.lax.cond requires a new function for every condition that is applied.

For example, consider this function with three cases:

def function(x,y):
    if y==0:
       return x
    if y>0:
       return x-y
    if y<0:
       return x+y

Written with jax.lax.cond, it already gets messy:

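import jax

@jax.jit
def function(x, y):
    # The outer cond handles y == 0; the inner cond handles y > 0 vs. y < 0
    return jax.lax.cond(
        y == 0,
        lambda x, y: x,
        lambda x, y: jax.lax.cond(y > 0, lambda x, y: x - y, lambda x, y: x + y, x, y),
        x, y)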

Is there a better way?


Solution

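  • In short, no, there's no way to return early from a Python function conditioned on traced values: under jax.jit, Python control flow runs once at trace time, when y is an abstract tracer with no concrete value, so a Python "if y == 0:" raises an error instead of branching. The pattern I typically see to avoid messy nesting is to encapsulate the logic in helper functions and call them via lax.cond.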

    Alternatively, if you are branching based on multiple conditions, you may be able to better express the logic in terms of lax.switch; for example:

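    import jax
    import jax.numpy as jnp
    from jax import lax

    @jax.jit
    def function(x, y):
      branches = [lambda: x, lambda: x - y, lambda: x + y]
      # argmax picks the first True entry; the trailing True acts as the "else"
      conditions = jnp.array([y == 0, y > 0, True])
      return lax.switch(jnp.argmax(conditions), branches)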
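A runnable version of the lax.switch idea, as a minimal sketch: first_true is a hypothetical helper name, the inputs are assumed to be scalars, and the middle condition is written as y > 0 to match the plain Python function in the question. It relies on jnp.argmax returning the index of the first occurrence of the maximum, so the first True condition wins, just like an if/elif chain.

```python
import jax
import jax.numpy as jnp
from jax import lax


def first_true(*conditions):
    # Hypothetical helper: index of the first True entry. jnp.argmax returns the
    # first index of the maximum, so this behaves like an if/elif chain.
    return jnp.argmax(jnp.array(conditions))


@jax.jit
def function(x, y):
    # Each branch receives the operands explicitly instead of closing over them.
    branches = [
        lambda x, y: x,      # y == 0
        lambda x, y: x - y,  # y > 0
        lambda x, y: x + y,  # otherwise (y < 0)
    ]
    index = first_true(y == 0, y > 0, True)  # trailing True is the "else" case
    return lax.switch(index, branches, x, y)


for y in (0.0, 3.0, -3.0):
    print(y, function(jnp.asarray(10.0), jnp.asarray(y)))
```

Passing x and y as operands rather than closing over them is only a stylistic choice; both forms trace the same way under jit. All branches must return arrays of the same shape and dtype. Also note that, as the lax.cond documentation describes, a batched predicate under jax.vmap is converted to a select that evaluates every branch, and the same is expected for a batched lax.switch index.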
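The helper-function pattern mentioned in the answer can be sketched like this for the three-case example from the question. The helper names (_identity, _subtract, _add, _when_nonzero) are hypothetical, and the sketch assumes the inner test is meant to be y > 0, as in the plain Python version.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical helpers: each one makes a single decision or computation,
# so the lax.cond calls no longer need nested lambdas.
def _identity(x, y):
    return x


def _subtract(x, y):
    return x - y


def _add(x, y):
    return x + y


def _when_nonzero(x, y):
    # Only reached when y != 0: subtract for positive y, add for negative y.
    return lax.cond(y > 0, _subtract, _add, x, y)


@jax.jit
def function(x, y):
    # The top-level decision mirrors the early "return x" of the Python version.
    return lax.cond(y == 0, _identity, _when_nonzero, x, y)


for y in (0.0, 2.0, -2.0):
    print(y, function(jnp.asarray(10.0), jnp.asarray(y)))
```

Each lax.cond now reads as one named decision. The trade-off is a few more small functions, and every branch must still return arrays of the same shape and dtype.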