
Is it possible to obtain values from JAX traced arrays with a DynamicJaxprTrace level greater than 1 using any of the callback functions?


I have a program with multiple functions, each making its own JAX calls. Here is the main function:

from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("numberOfVoxels",))
def process_valid_voxels(numberOfVoxels, voxelPositions, voxelLikelihoods, ps, t, M, tmp):
    # process_voxel is defined elsewhere in the program; map it over every voxel index.
    func = lambda tmp_val: process_voxel(tmp_val, voxelPositions, voxelLikelihoods, ps, t, M, tmp)
    ys, likelihoods = jax.vmap(func)(jnp.arange(numberOfVoxels))

    return ys, likelihoods

This is the output of ys and likelihoods:

(Pdb) ys
Traced<ShapedArray(int32[3700,3,1])>with<DynamicJaxprTrace(level=3/0)>
(Pdb) likelihoods
Traced<ShapedArray(float32[3700,7,1])>with<DynamicJaxprTrace(level=3/0)>

I want to get the values out of the traced arrays ys and likelihoods so that I can modify them. I have tried using the io_callback function:

def callback1(x):
    return jax.experimental.io_callback(process_voxel, x , x)
a = callback1(jnp.arange(numberOfVoxels))

but the output is the same except for the shape of the array:

Traced<ShapedArray(int32[3700])>with<DynamicJaxprTrace(level=3/0)>

Solution

    This is similar to one of JAX's FAQs, "How can I convert a tracer to a numpy array?" That answer mentions callbacks, which you use above, but I think you have the wrong mental model of what io_callback does.

    When you run transformed JAX code, there are essentially two stages of execution:

    1. Tracing happens within the Python runtime, using abstract representations of the arrays (tracers) to extract the sequence of operations implied by your code. During tracing, array values are in most cases unavailable by design.

    2. Execution happens within the XLA runtime once tracing has encoded the sequence of operations to be run. Array values are available during XLA execution, but this is not a Python runtime, and so Python debugging, printing, etc. is not available.
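
    The two stages above can be observed in a small sketch. The function `double` and the list `trace_log` are hypothetical names chosen for illustration and are not part of the original question. The Python body runs once, during tracing, where `x` is a tracer; later calls with the same shape and dtype reuse the compiled computation.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical log, filled only while Python traces the function

@jax.jit
def double(x):
    # This line runs during tracing (stage 1): x is a tracer with no concrete values.
    trace_log.append(type(x).__name__)
    return x * 2

first = double(jnp.arange(3))   # traces and compiles, then executes in XLA (stage 2)
second = double(jnp.arange(3))  # same shape and dtype: cached, the Python body is not re-run

print(trace_log)  # a single entry, naming a tracer class
print(first)      # a concrete array, available once the call has returned
```

    Ordinary Python statements inside the jitted function, including `print` and `pdb` breakpoints, therefore only ever see tracers.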

    Your question amounts to "How can I access array values from stage 2 by inserting breakpoints into the runtime during stage 1?" The answer is: you can't!

    But you can use callbacks and JAX debugging tools to encode an instruction that tells XLA, during stage 2 execution, to pause and pass its values to a callback function, which lets you interact with the array values from within Python.
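
    As a minimal sketch of the callback route, assuming the goal is to modify values on the Python side and hand them back: `io_callback` takes the host function, then a description of the result's shape and dtype, then the arguments. The names `host_modify`, `seen_on_host`, and `compute` are hypothetical stand-ins, not the question's `process_voxel`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import io_callback

seen_on_host = []  # hypothetical container for values captured at run time

def host_modify(ys):
    # Runs in Python during stage 2, so ys holds concrete values here.
    ys = np.asarray(ys)
    seen_on_host.append(ys.copy())
    # The returned array must match the declared shape and dtype.
    return (ys + 1).astype(ys.dtype)

@jax.jit
def compute(x):
    ys = x * 2  # a tracer while tracing
    # Second argument: the shape and dtype of what host_modify returns.
    result_spec = jax.ShapeDtypeStruct(ys.shape, ys.dtype)
    return io_callback(host_modify, result_spec, ys)

out = compute(jnp.arange(4, dtype=jnp.float32))
print(seen_on_host[0])  # values as seen by the host function
print(out)              # values after the host-side modification
```

    Inside `compute`, the return value of `io_callback` is still a tracer; the concrete values exist only inside `host_modify` and in `out` after the call returns. That is why the attempt in the question printed another traced array.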

    One way to do so might look like this:

    @partial(jax.jit, static_argnames=("numberOfVoxels",))    
    def process_valid_voxels(numberOfVoxels, voxelPositions, voxelLikelihoods, ps, t, M, tmp):
        func = lambda tmp_val: process_voxel(tmp_val, voxelPositions, voxelLikelihoods, ps, t, M, tmp)
        ys, likelihoods = jax.vmap(func)(jnp.arange(numberOfVoxels))
        jax.debug.breakpoint()
        return ys, likelihoods
    

    This will tell XLA during stage 2 to pause execution, call back to a Python-side debugging tool, and let you interact with the values there.
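
    If interactive inspection is not required, note also that the values a jitted function returns are concrete arrays once the call has finished, so they can be modified outside of `jit`. The sketch below uses a hypothetical stand-in function `compute` rather than the question's `process_valid_voxels`.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def compute(x):
    return x * 2  # a tracer inside jit, a concrete array for the caller

ys = compute(jnp.arange(4))  # concrete jax.Array after the call returns

# Option 1: copy into a writable NumPy array and modify it in place.
ys_np = np.array(ys)
ys_np[0] = -1

# Option 2: stay in JAX and build an updated copy (JAX arrays are immutable).
ys_updated = ys.at[0].set(-1)
```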

    For more intuition on the mental model of JAX program execution, you may find How to think in JAX useful.