python, jax, equinox

Why is custom pytree 'aux_data' traced after jax.jit() for jnp.array but not for np.array?


I am trying to understand how pytrees work, so I registered my own class as a pytree. I noticed that if the aux_data in the pytree is a jax.numpy.ndarray, the auxiliary data is traced and comes back as a Traced<ShapedArray(...)>... object. However, if the aux_data is a numpy.ndarray (i.e. not a JAX array), it is not traced, and the jit-transformed function returns a plain array.

Now, I am aware of the tracing that happens during the jax.jit() transformation, but I do not understand why, on the level of pytrees, this results in the behaviour described above.

Here is an example that reproduces this behaviour (it multiplies both the aux_data and the tree leaves by two, which may be a problem in itself after the JIT transformation...?). For comparison I have used the custom pytree implementations of established libraries (equinox and simple_pytree) alongside my own. They all give the same result, so I am fairly sure that this is not a bug but a feature that I am trying to understand.

import jax
from jax.tree_util import tree_structure, tree_leaves
import numpy as np

def get_pytree_impl(base):
    if base == "equinox":
        import equinox as eqx
        Module = eqx.Module
        static_field = eqx.static_field
    elif base == "simple_pytree":
        from simple_pytree import Pytree, static_field
        Module = Pytree
    elif base == "dataclasses":
        from dataclasses import dataclass, field
        @dataclass
        class Module():
            pass
        static_field = field
    
    class PytreeImpl(Module):
        x: jax.numpy.ndarray
        y: jax.numpy.ndarray = static_field()

        def __init__(self, x, y):
            self.x = x
            self.y = y

    if base == 'dataclasses':
        from jax.tree_util import register_pytree_node
        
        def flatten(ptree):
            return ((ptree.x,), ptree.y)
        
        def unflatten(aux_data, children):
            return PytreeImpl(*children, aux_data)

        register_pytree_node(PytreeImpl, flatten, unflatten)
        
    return PytreeImpl

def times_two(ptree):
    return type(ptree)(ptree.x*2, ptree.y*2)

times_two_jitted = jax.jit(times_two)

bases = ['dataclasses', 'equinox', 'simple_pytree']
for base in bases:
    print("========  " + base + "  ========")
    for lib_name, array_lib in zip(['jnp', 'np'], [jax.numpy, np]):
        print("====  " + lib_name)
        PytreeImpl = get_pytree_impl(base)
        x = jax.numpy.array([1,2])
        y = array_lib.array([3,4])
        input_tree = PytreeImpl(x, y)
        for tag, pytree in zip(["input", "no_jit", "jit"],[input_tree, times_two(input_tree), times_two_jitted(input_tree)]):
            print(f' {tag}:')
            print(f'\t Structure: {tree_structure(pytree)}')
            print(f'\t Leaves: {tree_leaves(pytree)}')

This produces the following, where dataclasses is my naive custom implementation of a pytree:

========  dataclasses  ========
====  jnp
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[3 4]], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[6 8]], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=1/0)>], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
====  np
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[3 4]], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[6 8]], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[[6 8]], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
========  equinox  ========
====  jnp
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (Array([3, 4], dtype=int32),)], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (Array([6, 8], dtype=int32),)], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=1/0)>,)], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
====  np
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (array([3, 4]),)], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (array([6, 8]),)], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[('x',), ('y',), (array([6, 8]),)], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
========  simple_pytree  ========
====  jnp
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': Array([3, 4], dtype=int32), '_pytree__initialized': True})], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': Array([6, 8], dtype=int32), '_pytree__initialized': True})], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': Traced<ShapedArray(int32[2])>with<DynamicJaxprTrace(level=1/0)>, '_pytree__initialized': True})], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
====  np
 input:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': array([3, 4]), '_pytree__initialized': True})], [*]))
     Leaves: [Array([1, 2], dtype=int32)]
 no_jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': array([6, 8]), '_pytree__initialized': True})], [*]))
     Leaves: [Array([2, 4], dtype=int32)]
 jit:
     Structure: PyTreeDef(CustomNode(PytreeImpl[(('x',), {'y': array([6, 8]), '_pytree__initialized': True})], [*]))
     Leaves: [Array([2, 4], dtype=int32)]

I ran this example using Python 3.12.1 with equinox 0.11.4, jax 0.4.28, jaxlib 0.4.28 and simple-pytree 0.1.5.


Solution

  • From the JAX docs:

    When defining unflattening functions, in general children should contain all the dynamic elements of the data structure (arrays, dynamic scalars, and pytrees), while aux_data should contain all the static elements that will be rolled into the treedef structure.

    aux_data in a pytree flattening must contain only static elements, and static elements must be hashable and immutable. Neither np.ndarray nor jax.Array satisfies these requirements, so they should not be included in aux_data. If you do include such values in aux_data, you'll get unsupported, poorly-defined behavior.

    With that background: the answer to your question of why you're seeing the results you're seeing is that you are defining your pytrees incorrectly. If you define aux_data to only contain static (i.e. hashable and immutable) attributes, you will no longer see this behavior.
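
To make this concrete, here is a minimal sketch (not taken from the JAX docs). It assumes a hypothetical `Scaled` class and a hypothetical `probe` function, both invented for illustration. The first part checks the hashability claim and records, at trace time, what `y * 2` produces for a closed-over jax.Array versus a np.ndarray. That appears to be the mechanism behind the output in the question: inside jit, an operation on a jax.Array is a JAX operation and gets staged into the trace, while an operation on a np.ndarray is ordinary NumPy and simply runs. The second part stores the static data as a tuple of Python ints instead.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node, tree_structure


def is_hashable(obj):
    """Return True if hash(obj) succeeds."""
    try:
        hash(obj)
        return True
    except TypeError:
        return False


# --- Part 1: arrays are not valid static data ---------------------------
y_jnp = jnp.array([3, 4])
y_np = np.array([3, 4])
print(is_hashable(y_jnp), is_hashable(y_np), is_hashable((3, 4)))

observed = {}


@jax.jit
def probe(x):
    # This body runs once, at trace time. Like aux_data, neither y is a
    # jit argument; both are just Python objects the function can reach.
    observed["jnp_is_tracer"] = isinstance(y_jnp * 2, jax.core.Tracer)
    observed["np_is_ndarray"] = isinstance(y_np * 2, np.ndarray)
    return x * 2


probe(jnp.array([1, 2]))
print(observed)


# --- Part 2: keep aux_data hashable and immutable -----------------------
class Scaled:
    """Hypothetical container: `x` is dynamic, `factors` is static."""

    def __init__(self, x, factors):
        self.x = x
        self.factors = tuple(factors)  # tuple of Python ints


def flatten(tree):
    # children: dynamic arrays; aux_data: static, hashable metadata
    return (tree.x,), tree.factors


def unflatten(aux_data, children):
    return Scaled(children[0], aux_data)


register_pytree_node(Scaled, flatten, unflatten)


@jax.jit
def times_two(tree):
    # tree.factors holds plain Python ints, so this is ordinary Python
    # arithmetic evaluated while tracing; only tree.x is traced.
    return Scaled(tree.x * 2, tuple(f * 2 for f in tree.factors))


out = times_two(Scaled(jnp.array([1, 2]), (3, 4)))
print(tree_structure(out))
print(out.x, out.factors)
```

Because the tuple is hashable and compares by value, it can be folded into the treedef safely. If a value needs to change between calls without affecting the tree structure, it probably belongs in `children` rather than in aux_data.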