Local JAX variable not updating in `jit`ted function, but updating in the non-jitted one?


So, I've got some code and I could really use help deciphering the behavior and how to get it to do what I want.

See my code as follows:

from typing import Callable, List

import chex
import jax.numpy as jnp
import jax

Weights = List[jnp.ndarray]


@chex.dataclass(frozen=True)
class Model:
    mult: Callable[
        [jnp.ndarray],
        jnp.ndarray
    ]

    jitted_mult: Callable[
        [jnp.ndarray],
        jnp.ndarray
    ]

    weight_updater: Callable[
        [jnp.ndarray], None
    ]


def create_weight():
    return jnp.ones((2, 5))


def wrapper():
    weights = create_weight()

    def mult(input_var):
        return weights.dot(input_var)

    @jax.jit
    def jitted_mult(input_var):
        return weights.dot(input_var)

    def update_locally_created(new_weights):
        nonlocal weights
        weights = new_weights
        return weights

    return Model(
        mult=mult,
        jitted_mult=jitted_mult,
        weight_updater=update_locally_created
    )


if __name__ == '__main__':
    tester = wrapper()
    to_mult = jnp.ones((5, 2))
    for i in range(5):
        print(jnp.sum(tester.mult(to_mult)))
        print(jnp.sum(tester.jitted_mult(to_mult)))

        if i % 2 == 0:
            tester.weight_updater(jnp.zeros((2, 5)))
        else:
            tester.weight_updater(jnp.ones((2, 5)))

        print("*" * 10)

TL;DR: I'm defining some "weights" within a function closure, and I'm trying to modify the weights via nonlocal. The problem seems to be that the jit-ted version of the function (jitted_mult) doesn't recognize the updated weights, whereas the non-jit function (mult) does.

What can I do to make it recognize the update? I think I might be able to do what "Build your own Haiku" does, but that seems like a lot of work for an experiment.


Solution

  • This is working as expected: the jitted function ignores the update because it is not pure (see JAX Sharp Bits: Pure Functions). Its output depends on weights, a value that is never explicitly passed to the function. That violates the assumptions made by jit and other JAX transformations: jit traces the function on its first call and embeds the closed-over weights in the compiled computation as a constant, so later calls with the same input shape and dtype reuse that computation and never see the rebinding done through nonlocal.

    To fix it I would make this implicit input explicit, so that your function is pure. It might look something like this:

    def wrapper():
        def mult(input_var, weights):
            return weights.dot(input_var)
    
        @jax.jit
        def jitted_mult(input_var, weights):
            return weights.dot(input_var)
    
        return Model(
            mult=mult,
            jitted_mult=jitted_mult,
            weight_updater=None
        )
    
    
    if __name__ == '__main__':
        tester = wrapper()
        to_mult = jnp.ones((5, 2))
        weights = create_weight()
        for i in range(5):
            print(jnp.sum(tester.mult(to_mult, weights)))
            print(jnp.sum(tester.jitted_mult(to_mult, weights)))
    
            if i % 2 == 0:
                weights = jnp.zeros((2, 5))
            else:
                weights = jnp.ones((2, 5))
    
            print("*" * 10)
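To see the trace-time capture in isolation, here is a minimal sketch that is separate from the code above. The names make_closure_fns, scale, update and jitted_explicit are made up for illustration. It assumes the second call reuses the cached compilation, which is what happens when the input shape and dtype are unchanged.

```python
import jax
import jax.numpy as jnp


def make_closure_fns():
    # Hypothetical helper for illustration: `scale` lives in the closure.
    scale = jnp.ones(3)

    @jax.jit
    def jitted_closure(x):
        # `scale` is read while tracing (on the first call) and embedded
        # in the compiled computation as a constant.
        return jnp.sum(scale * x)

    def update(new_scale):
        # Rebinds the Python variable; the compiled computation is untouched.
        nonlocal scale
        scale = new_scale

    return jitted_closure, update


@jax.jit
def jitted_explicit(x, scale):
    # `scale` is a traced argument, so each call uses the value passed in.
    return jnp.sum(scale * x)


x = jnp.ones(3)
jitted_closure, update = make_closure_fns()

before = float(jitted_closure(x))  # first call: traces with scale == ones
update(jnp.zeros(3))
after = float(jitted_closure(x))   # same shape/dtype: cached computation reused

explicit_ones = float(jitted_explicit(x, jnp.ones(3)))
explicit_zeros = float(jitted_explicit(x, jnp.zeros(3)))

print(before, after)                  # 3.0 3.0  (the update is not seen)
print(explicit_ones, explicit_zeros)  # 3.0 0.0  (the update is seen)
```

The closure version keeps returning the stale result, while the version that takes scale as an argument follows the new value without recompiling, since the shape and dtype stay the same.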