
Calling an initialized function from a list inside a jitted JAX function


I have a jitted function that calls another function which maps over a batch, and that function in turn calls inner_function to compute a certain property. I also have a dictionary of initialized functions, initialized_functions_dic, and I want to call the appropriate one based on information passed in as an argument, e.g. info_1. Is there a way to make this work? Thanks in advance.

    initialized_functions_dic = {1: init_function_1, 2: init_function_2, 3: init_function_3}

    def inner_function(info_1, info_2, info_3):
        return 5 + initialized_functions_dic[info_1](info_2, info_3)

Indexing initialized_functions_dic[info_1] raises an error, because under jit info_1 is a traced value rather than a concrete Python integer, so it cannot be used as a dictionary key.
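A minimal reproduction of the failure, assuming simple hypothetical stand-ins for init_function_1 to init_function_3 (the real initialized functions are not shown in the question). The exact exception raised when a tracer is used as a dictionary key may differ between JAX versions, so this sketch only checks that the lookup fails.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the question's initialized functions.
def init_function_1(x, y):
    return x + y

def init_function_2(x, y):
    return x * y

def init_function_3(x, y):
    return x - y

initialized_functions_dic = {1: init_function_1, 2: init_function_2, 3: init_function_3}

@jax.jit
def inner_function(info_1, info_2, info_3):
    # Under jit, info_1 is a tracer, not a hashable Python int,
    # so the dictionary lookup cannot succeed.
    return 5 + initialized_functions_dic[info_1](info_2, info_3)

def dict_lookup_fails():
    try:
        inner_function(jnp.array(2), 3.0, 4.0)
    except Exception:  # exact exception type may vary by JAX version
        return True
    return False
```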

Marking info_1 as static via static_argnums also fails: static arguments must be hashable, and info_1 is a JAX array (type ArrayImpl), which is unhashable.
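For comparison, a sketch of the static_argnums route with hypothetical stand-in functions. It works when the selector is a plain Python int, because static arguments are hashed and fixed at trace time; a JAX array is rejected. Inside a vmap over a batch, info_1 is a per-example array value, so it cannot be made static at all.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the question's initialized functions.
initialized_functions_dic = {
    1: lambda x, y: x + y,
    2: lambda x, y: x * y,
    3: lambda x, y: x - y,
}

@partial(jax.jit, static_argnums=0)
def inner_function(info_1, info_2, info_3):
    # info_1 is static, so it is a concrete, hashable Python value while tracing.
    return 5 + initialized_functions_dic[info_1](info_2, info_3)

# A Python int is hashable, so this traces and runs: 5 + 3.0 * 4.0
result = inner_function(2, 3.0, 4.0)

def static_array_fails():
    # Passing a JAX array as the static argument is rejected as non-hashable.
    try:
        inner_function(jnp.array(2), 3.0, 4.0)
    except (TypeError, ValueError):
        return True
    return False
```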


Solution

  • It sounds like you're looking for jax.lax.switch, which calls one function from a list, selected by an integer index, on the given operands:

    from jax import lax

    initialized_functions = [init_function_1, init_function_2, init_function_3]
    
    def inner_function(info_1, info_2, info_3):
        idx = info_1 - 1  # info_1 is 1-based; lax.switch indices are 0-based
        args = (info_2, info_3)  # operands passed to the selected function
        return 5 + lax.switch(idx, initialized_functions, *args)
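Putting the pieces together, a runnable sketch of the whole pipeline described in the question: a jitted outer function that vmaps inner_function over a batch, with lax.switch choosing the branch per example. The three init_function_* definitions and outer_function are hypothetical stand-ins; lax.switch requires every branch to accept the same operands and return outputs of matching shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for the question's initialized functions. All branches
# take the same operands and return the same shape/dtype, as lax.switch requires.
def init_function_1(x, y):
    return x + y

def init_function_2(x, y):
    return x * y

def init_function_3(x, y):
    return x - y

initialized_functions = [init_function_1, init_function_2, init_function_3]

def inner_function(info_1, info_2, info_3):
    idx = info_1 - 1  # info_1 is 1-based; lax.switch indices are 0-based
    return 5 + lax.switch(idx, initialized_functions, info_2, info_3)

@jax.jit
def outer_function(info_1, info_2, info_3):
    # Map inner_function over the leading (batch) axis of every argument.
    return jax.vmap(inner_function)(info_1, info_2, info_3)

info_1 = jnp.array([1, 2, 3])
info_2 = jnp.array([1.0, 2.0, 3.0])
info_3 = jnp.array([10.0, 10.0, 10.0])
result = outer_function(info_1, info_2, info_3)  # values 16.0, 25.0, -2.0
```

Two behaviours worth knowing: lax.switch clamps an out-of-range index into [0, len(branches) - 1] rather than raising, so an unexpected info_1 silently selects the first or last function; and when the index is batched by vmap, JAX lowers the switch to a select, so every branch is evaluated for every example and only the chosen result is kept.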