Is it possible to have callbacks inside a function passed to JAX fori_loop?
In my case, the callback will save some of the intermediate results produced in the function to disk.
I tried something like this:
```python
def callback(values):
    # do something

def diffusion_loop(i, args):
    # do something
    callback(results)
    return results

final_result, _ = jax.lax.fori_loop(0, num_steps, diffusion_loop, (arg1, arg2))
```
But then, if I use `final_result` or whatever was saved from the callback, I get an error like this:
```
UnfilteredStackTrace: jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[1,4,64,64] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was scanned_fun at /usr/local/lib/python3.8/dist-packages/jax/_src/lax/control_flow/loops.py:1606 traced for scan.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:508 (_get_prev_sample).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<timed exec>:81 (<module>)
<timed exec>:67 (diffusion_loop)
/usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:264 (step)
/usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:472 (step_plms)
/usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:508 (_get_prev_sample)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
```
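A minimal sketch of why this happens (the list `saved` and the update `x + 1.0` are hypothetical stand-ins for the real callback and diffusion step): `fori_loop` does not call the Python body once per iteration. It traces the body, typically once, with abstract tracer values, and then runs the compiled loop. Anything the body stores in Python state is therefore a tracer, not a concrete array, and using it later triggers the `UnexpectedTracerError` above.

```python
import jax
import jax.numpy as jnp

saved = []  # hypothetical stand-in for "save to disk" via Python-side state

def diffusion_loop(i, x):
    # This body runs while JAX is *tracing* the loop, so x is a tracer,
    # not a concrete array; appending it leaks the tracer out of the trace.
    saved.append(x)
    return x + 1.0  # hypothetical stand-in for the real diffusion step

final_result = jax.lax.fori_loop(0, 5, diffusion_loop, jnp.float32(0.0))

print(float(final_result))          # the compiled loop still runs 5 iterations
print(len(saved), type(saved[0]))   # one entry per trace, not per iteration
```

The returned `final_result` is fine because it was returned explicitly through the carry; only the value captured in `saved` is a leaked tracer.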
It sounds like you want to make a callback to the host that is impure (i.e. it has the side effect of saving values to disk) and does not return any values to the runtime. For that, one option is `jax.experimental.host_callback.id_tap`, discussed in the `host_callback` documentation.
For example:
```python
import jax
from jax.experimental import host_callback as hcb

def callback(value, transforms):
    # do something
    print(f"callback: {value}")

def diffusion_loop(i, args):
    hcb.id_tap(callback, i)
    return args

args = (1, 2)
result, _ = jax.lax.fori_loop(0, 5, diffusion_loop, args)
```
Output:

```
callback: 0
callback: 1
callback: 2
callback: 3
callback: 4
```
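Note that `jax.experimental.host_callback` has been deprecated in newer JAX releases, and `jax.debug.callback` is one replacement for this kind of fire-and-forget host call. Below is a minimal sketch of the original use case, saving each intermediate result to disk, under these assumptions: `OUT_DIR` and `save_to_disk` are hypothetical names, and the update `carry * 0.5 + 1.0` stands in for the real diffusion step. The callback receives concrete arrays on the host, so no tracer escapes the loop, and `jax.effects_barrier()` waits for pending callbacks before the files are read.

```python
import os
import tempfile

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical output directory for this sketch; replace with your own path.
OUT_DIR = tempfile.mkdtemp()

def save_to_disk(step, value):
    # Runs on the host with concrete array values, not tracers.
    path = os.path.join(OUT_DIR, f"step_{int(step):03d}.npy")
    np.save(path, np.asarray(value))

def diffusion_loop(i, carry):
    # Hypothetical stand-in for the real diffusion step.
    x = carry * 0.5 + 1.0
    # Send the intermediate result to the host; nothing comes back into the trace.
    jax.debug.callback(save_to_disk, i, x)
    return x

x0 = jnp.zeros((1, 4, 8, 8), dtype=jnp.float32)
final_result = jax.lax.fori_loop(0, 5, diffusion_loop, x0)

# Block until all pending callbacks have finished before reading the files.
jax.effects_barrier()
print(sorted(os.listdir(OUT_DIR)))
```

Because each file name includes the step index, it does not matter if unordered callbacks run out of order. If the side effects must happen in program order, `jax.experimental.io_callback` with `ordered=True` is another option; check its documentation for your JAX version regarding use inside loops.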