I have a JAX function that, given the order and the index, selects a polynomial from a pre-defined dictionary, as follows:
```python
poly_dict = {
    (0, 0): lambda x, y, z: 1.,
    (1, 0): lambda x, y, z: x,
    (1, 1): lambda x, y, z: y,
    (1, 2): lambda x, y, z: z,
    (2, 0): lambda x, y, z: x*x,
    (2, 1): lambda x, y, z: y*y,
    (2, 2): lambda x, y, z: z*z,
    (2, 3): lambda x, y, z: x*y,
    (2, 4): lambda x, y, z: y*z,
    (2, 5): lambda x, y, z: z*x
}

def poly_func(order: int, index: int):
    try:
        return poly_dict[(order, index)]
    except KeyError:
        print("(order, index) must be a key in poly_dict!")
        return
```
Now I want to jit `poly_func()`, but it gives the error `TypeError: unhashable type: 'DynamicJaxprTracer'`.
Moreover, if I just do

```python
def poly_func(order: int, index: int):
    return poly_dict[(order, index)]
```

it still gives the same error. Is there a way to resolve this issue?
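There are two issues with your approach. First, inside a JIT-compiled function, `order` and `index` are abstract tracers rather than concrete Python integers, so they cannot be used as keys or indices into Python collections like dicts or lists. Second, a JIT-compiled function can only return arrays (or pytrees of arrays), not functions.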
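Both failure modes can be reproduced with a small sketch. The names `eval_via_dict`, `returns_function`, and `fails_under_jit` are hypothetical helpers for illustration only. Because the exact exception type and message can differ between JAX releases, the sketch reports whatever `jax.jit` raises rather than assuming a specific one.

```python
import jax

# Trimmed-down version of the question's poly_dict
poly_dict = {
    (0, 0): lambda x, y, z: 1.0,
    (1, 0): lambda x, y, z: x,
}


def eval_via_dict(order, index, x):
    # Under jit, order and index are tracers, not Python ints,
    # so they cannot be used as dict keys
    return poly_dict[(order, index)](x, x, x)


def returns_function(x):
    # Outputs of a jitted function must be arrays (or pytrees of arrays)
    return lambda y: x + y


def fails_under_jit(fn, *args):
    """Return the exception raised by jax.jit(fn)(*args), or None on success."""
    try:
        jax.jit(fn)(*args)
    except Exception as err:
        return err
    return None


# Without jit, the dict lookup works on plain Python ints
print(eval_via_dict(1, 0, 2.0))  # 2.0

for name, fn, args in [
    ("dict lookup", eval_via_dict, (1, 0, 2.0)),
    ("returning a function", returns_function, (1.0,)),
]:
    err = fails_under_jit(fn, *args)
    print(f"{name}: {type(err).__name__}")
```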
With that in mind, it is impossible to make the function you propose work with JAX transforms like `jit`. Instead, you could modify your approach to use a list of functions rather than a dict, and then use `lax.switch` to dynamically select which function to call. Here's how it might look:
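```python
import jax

def get_index(order, index):
    # Map (order, index) to a unique flat position in poly_list
    return order * 5 + index

# 16 slots cover the largest flat index, get_index(2, 5) == 15;
# unused slots hold a place-holder function
poly_list = 16 * [lambda x, y, z: 0.0]
for key, val in poly_dict.items():
    poly_list[get_index(*key)] = val

def eval_poly_func(order: int, index: int, args):
    ind = get_index(order, index)
    # lax.switch selects the branch at run time from a traced integer index
    return jax.lax.switch(ind, poly_list, *args)

result = jax.jit(eval_poly_func)(2, 0, (5.0, 6.0, 7.0))
print(result)
# 25.0
```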
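If `order` and `index` are always ordinary Python ints at call time (not values computed inside another traced function), another option, not part of the approach above, is to mark them as static with the `static_argnums` argument of `jax.jit`. This is a sketch under that assumption, and `eval_poly_static` is a hypothetical name. The jitted function still evaluates the polynomial and returns an array, rather than returning the polynomial itself.

```python
from functools import partial

import jax

poly_dict = {
    (0, 0): lambda x, y, z: 1.0,
    (1, 0): lambda x, y, z: x,
    (1, 1): lambda x, y, z: y,
    (1, 2): lambda x, y, z: z,
    (2, 0): lambda x, y, z: x * x,
    (2, 1): lambda x, y, z: y * y,
    (2, 2): lambda x, y, z: z * z,
    (2, 3): lambda x, y, z: x * y,
    (2, 4): lambda x, y, z: y * z,
    (2, 5): lambda x, y, z: z * x,
}


@partial(jax.jit, static_argnums=(0, 1))
def eval_poly_static(order: int, index: int, args):
    # order and index are static, so they arrive as plain Python ints at
    # trace time and the ordinary dict lookup works (a missing key raises
    # KeyError during tracing). Only args is traced.
    return poly_dict[(order, index)](*args)


print(eval_poly_static(2, 0, (5.0, 6.0, 7.0)))  # 25.0
print(eval_poly_static(2, 3, (5.0, 6.0, 7.0)))  # 30.0
```

The trade-off is that with `static_argnums`, JAX retraces and recompiles for every distinct `(order, index)` pair, and the pair cannot itself be a traced value (so it cannot be mapped over with `jax.vmap`, for example). The `lax.switch` version compiles once and accepts traced indices. One caveat of the `lax.switch` version is that it never raises for an unsupported pair: `(0, 1)` lands on a place-holder slot and returns `0.0`, and because `lax.switch` clamps an out-of-range index into bounds, `(3, 1)` (flat index 16) silently evaluates the last branch.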