
Callback in JAX fori_loop


Is it possible to have callbacks inside a function passed to JAX fori_loop?

In my case, the callback will save to disk some of the intermediate results produced in the function.

I tried something like this:

def callback(values):
  # do something with `values`, e.g. save them to disk
  ...

def diffusion_loop(i, args):
  # do something
  callback(results)
  return results

final_result, _ = jax.lax.fori_loop(0, num_steps, diffusion_loop, (arg1, arg2))

But then, if I use final_result or whatever was saved by the callback, I get an error like this:

UnfilteredStackTrace: jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[1,4,64,64] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was scanned_fun at /usr/local/lib/python3.8/dist-packages/jax/_src/lax/control_flow/loops.py:1606 traced for scan.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:508 (_get_prev_sample). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<timed exec>:81 (<module>)
<timed exec>:67 (diffusion_loop)
/usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:264 (step)
/usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:472 (step_plms)
/usr/local/lib/python3.8/dist-packages/diffusers/schedulers/scheduling_pndm_flax.py:508 (_get_prev_sample)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
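As the message says, values computed inside a traced loop have to leave it as explicit outputs instead of being stashed in global state. One callback-free way to do that is to replace fori_loop with jax.lax.scan, which stacks a per-step output along a new leading axis, and then save the stacked array after the loop returns. This is a minimal sketch that assumes all intermediate results fit in memory. diffusion_step and its arithmetic are hypothetical stand-ins for the real scheduler step.

```python
import numpy as np
import jax
import jax.numpy as jnp

def diffusion_step(x, i):
    # Hypothetical stand-in for one diffusion/scheduler step.
    x = x * 0.5 + i
    # scan expects (new_carry, per_step_output); here both are the new x.
    return x, x

num_steps = 5
x0 = jnp.ones((2,), dtype=jnp.float32)

# history has shape (num_steps, 2): one row per loop iteration.
final_x, history = jax.lax.scan(diffusion_step, x0, jnp.arange(num_steps))

# Outside the traced function these are concrete arrays, so they can be
# written to disk safely, e.g. np.save("steps.npy", history_np).
history_np = np.asarray(history)
```

The trade-off is memory: every step's value is kept until the loop ends, so for large arrays or many steps a host callback that writes each step as it is produced may be preferable.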

Solution

  • It sounds like you want to make an impure callback to the host (i.e. one whose side effect is saving values to disk) that does not return any values to the computation. One option for that is jax.experimental.host_callback.id_tap, described in the jax.experimental.host_callback documentation.

    For example:

    import jax
    from jax.experimental import host_callback as hcb
    
    def callback(value, transforms):
      # do something
      print(f"callback: {value}")
    
    def diffusion_loop(i, args):
      hcb.id_tap(callback, i)
      return args
    
    args = (1, 2)
    result, _ = jax.lax.fori_loop(0, 5, diffusion_loop, args)

    Running this prints:

    callback: 0
    callback: 1
    callback: 2
    callback: 3
    callback: 4
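The host_callback module has since been deprecated in newer JAX releases in favor of jax.debug.callback (for debugging output) and jax.experimental.io_callback (for general side effects such as file I/O). Below is a minimal sketch of the same pattern with io_callback, assuming a recent JAX version. save_step, the saved dict, and the x * 0.5 update are hypothetical placeholders; replacing the dict assignment with np.save would write each step to disk instead.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

saved = {}  # hypothetical store: step index -> NumPy copy of the value

def save_step(i, x):
    # Runs on the host with concrete values. To write to disk instead,
    # use e.g. np.save(f"step_{int(i)}.npy", x).
    saved[int(i)] = np.asarray(x)

def diffusion_loop(i, x):
    x = x * 0.5  # hypothetical stand-in for one diffusion step
    # The callback returns nothing, so result_shape_dtypes is None.
    io_callback(save_step, None, i, x, ordered=True)
    return x

final = jax.lax.fori_loop(0, 5, diffusion_loop, jnp.ones(3))
jax.effects_barrier()  # wait until all pending host callbacks have run
```

ordered=True asks JAX to run the callbacks in program order, which matters when each step appends to the same file. jax.effects_barrier() makes sure every callback has finished before saved is read.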