I have a jitted function that calls another function, which maps over a batch and in turn calls a function, inner_function, to compute a certain property. I also have a dictionary of initialized functions, initialized_functions_dic, and I want to call the appropriate one based on information passed as an argument, e.g. info_1. Is there a way to make this work? Thanks in advance.
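```python
initialized_functions_dic = {1: init_function_1, 2: init_function_2, 3: init_function_3}

def inner_function(info_1, info_2, info_3):
    # Look up the function for this key and call it on the remaining arguments
    return 5 + initialized_functions_dic[info_1](info_2, info_3)
```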
Indexing the dictionary with info_1 fails, because inside the jitted and mapped code info_1 is a traced value rather than a concrete Python int, so it cannot be used as a dictionary key.
Marking info_1 as static via static_argnums also fails with unhashable type: 'ArrayImpl', because static arguments must be hashable and info_1 is a JAX array.
It sounds like you're looking for jax.lax.switch, which takes an integer index and a sequence of functions, and calls the function at that index with the remaining arguments:
```python
from jax import lax

initialized_functions = [init_function_1, init_function_2, init_function_3]

def inner_function(info_1, info_2, info_3):
    idx = info_1 - 1  # lists are zero-indexed
    args = (info_2, info_3)  # tuple of arguments to pass to the function
    return 5 + lax.switch(idx, initialized_functions, *args)
```
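A minimal end-to-end sketch of the structure described in the question (a jitted function that vmaps over a batch and dispatches with lax.switch) follows. The three init_function_* bodies are hypothetical stand-ins, since the question does not show them, and outer_function is likewise a hypothetical name for the jitted function. The one real requirement is that every branch accepts the same arguments and returns outputs of the same shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for the question's initialized functions; each takes
# (info_2, info_3) and returns a scalar of the same dtype, as lax.switch requires.
def init_function_1(x, y):
    return x + y

def init_function_2(x, y):
    return x * y

def init_function_3(x, y):
    return x - y

initialized_functions = [init_function_1, init_function_2, init_function_3]

def inner_function(info_1, info_2, info_3):
    idx = info_1 - 1  # keys 1..3 map to list indices 0..2
    return 5 + lax.switch(idx, initialized_functions, info_2, info_3)

@jax.jit
def outer_function(info_1, info_2, info_3):
    # Map inner_function over the leading (batch) axis of all three arguments.
    return jax.vmap(inner_function)(info_1, info_2, info_3)

info_1 = jnp.array([1, 2, 3, 2])
info_2 = jnp.array([2.0, 2.0, 2.0, 4.0])
info_3 = jnp.array([3.0, 3.0, 3.0, 0.5])
result = outer_function(info_1, info_2, info_3)
print(result)  # [10. 11.  4.  7.]
```

Two behaviours are worth knowing. lax.switch clamps an out-of-range index to the first or last branch instead of raising, so an unexpected info_1 silently selects init_function_1 or init_function_3. And because the index is batched under vmap, JAX evaluates every branch for every element and then selects the results, so the cost is roughly that of all three functions rather than just the chosen one.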