I have a script that performs some calculations on some given arrays. These calculations are performed thousands of times, so naturally I want to use JAX's JIT decorator to speed them up. I have several functions that are called from a "master function," and I want to JIT-compile the master function. However, there is one function I don't want to be JIT-compiled, because it can't be made JIT-compatible (or, at least, I don't know how to make it so). Below is an example:
```python
import jax
from functools import partial
import numpy as np

def function(params, X):
    # create an array of zeros with same length as x (not X)
    # set values to -1 if corresponding value of x (not X) is between specified limits
    # otherwise set values to zero
    values = jax.numpy.zeros(len(x))
    for i in range(len(x)):
        if x[i] < params[1] and x[i] > params[0]:
            values = values.at[i].set(-1)
    X.val = values
    return X

# @jax.jit
def master_function(params):
    # vmap previous function onto x
    partial_function = partial(function, params)
    return jax.vmap(partial_function)(x)

# define some variables
params = [4, 6]
x = np.linspace(0, 10, 100)

# run master function
new_x = master_function(params)

# print new_x array
print(new_x)
```
In this simple example, I have some array `x`. I want to create a copy of that array, called `new_x`, where each value is either 0 or -1. If a value in `x` is between some bounds (specified by `params`), its value in `new_x` should be -1, and 0 otherwise. When I don't JIT-compile `master_function()`, this script works perfectly. However, when I JIT-compile `master_function` (and, by extension, `function`), I get the following error:
```
Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function.
The error occurred while tracing the function master_function at temp.py:28 for jit. This concrete value was not available in Python because it depends on the value of the argument 'params'.
```
I understand that this error is caused by the way JIT compilation works, so I want to un-JIT-compile `function()` while still JIT-compiling `master_function`, if possible.
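You cannot normally¹ call an un-jitted function from within a jit-compiled function. In your case, the best solution looks like rewriting your function so that it is JIT-compatible. The error comes from the Python `if`: under `jit`, `params` is a traced value, so `x[i] < params[1]` cannot be converted to a concrete `bool`. You can replace your for-loop with this (where `jnp` is `jax.numpy`), which does the selection element-wise inside JAX instead:

```python
values = jnp.where((x < params[1]) & (x > params[0]), -1.0, 0.0)
```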
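Putting that together, a minimal sketch of a fully jitted version might look like the following. It assumes the goal is just the 0/-1 array described in the question, so it drops `function` and the `X.val` assignment entirely. It also passes `x` in as an argument; that is my choice here, since closing over the global `x` also works but bakes `x` into the compiled function as a constant.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def master_function(params, x):
    # Boolean mask, True where params[0] < x < params[1]; built from array ops,
    # so no Python `bool` is needed while tracing.
    in_range = (x > params[0]) & (x < params[1])
    # Select -1.0 inside the interval and 0.0 outside it, element-wise.
    return jnp.where(in_range, -1.0, 0.0)

params = jnp.array([4.0, 6.0])
x = np.linspace(0, 10, 100)
new_x = master_function(params, x)
print(new_x)
```

Since the branching is expressed as data (the mask) rather than as Python control flow, `jax.jit` can trace it with abstract values.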
Side note: it looks like you're doing in-place modifications of the `val` attribute of a batch tracer, which is not a supported operation and will probably have unexpected consequences. I'd suggest writing your code using standard operations, but the intent of your code is not clear to me, so I'm not sure what change to suggest.
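If you do want to keep the per-element `vmap` structure, one option (assuming each element of `x` should simply map to -1 or 0) is to have the mapped function return the new value instead of assigning to the tracer. The name `element_value` is just illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np

def element_value(params, xi):
    # Under vmap, xi is a single scalar element of x.
    # Return a new value rather than mutating an attribute of the tracer.
    return jnp.where((xi > params[0]) & (xi < params[1]), -1.0, 0.0)

@jax.jit
def master_function(params, x):
    # Map over the elements of x only; params is shared (in_axes=None).
    return jax.vmap(element_value, in_axes=(None, 0))(params, x)

params = jnp.array([4.0, 6.0])
x = np.linspace(0, 10, 100)
new_x = master_function(params, x)
print(new_x)
```

Here `in_axes=(None, 0)` replaces the `functools.partial` trick: `params` is broadcast to every call and only `x` is batched.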
¹ This actually is possible using `jax.pure_callback`, but it is probably not what you want, because it comes with performance penalties: the callback runs as ordinary Python on the host, outside the compiled computation.
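For completeness, here is a rough sketch of the footnote's approach, assuming you want to keep the original loop as plain NumPy. `jax.pure_callback` needs the output shape and dtype up front (via `jax.ShapeDtypeStruct`) because JAX cannot trace into the callback. The helper name `numpy_function` is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def numpy_function(params, x):
    # Runs on the host with concrete values, so the Python loop and `if`
    # from the question are fine here.
    params = np.asarray(params)
    x = np.asarray(x)
    # Match the dtype declared in result_shape below.
    values = np.zeros(len(x), dtype=x.dtype)
    for i in range(len(x)):
        if params[0] < x[i] < params[1]:
            values[i] = -1
    return values

@jax.jit
def master_function(params, x):
    # Declare what the callback returns, since JAX cannot infer it by tracing.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(numpy_function, result_shape, params, x)

params = jnp.array([4.0, 6.0])
x = np.linspace(0, 10, 100)
new_x = master_function(params, x)
print(new_x)
```

Every call now round-trips through Python on the host, which is where the performance penalty comes from. `pure_callback` also has no automatic derivative, so `jax.grad` won't differentiate through it unless you supply a custom JVP rule.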