
Using Jax Jit on a method as decorator versus applying jit function directly


I guess most people familiar with JAX have seen this example in the documentation and know that it does not work:

import jax.numpy as jnp
from jax import jit

class CustomClass:
  def __init__(self, x: jnp.ndarray, mul: bool):
    self.x = x
    self.mul = mul

  @jit  # <---- How to do this correctly?
  def calc(self, y):
    if self.mul:
      return self.x * y
    return y


c = CustomClass(2, True)
c.calc(3)  # fails: self (a CustomClass instance) is not a valid JAX type

Three workarounds are mentioned, but applying jit directly as a function, rather than as a decorator, appears to work fine as well. That is, JAX does not complain about not knowing how to handle the CustomClass type of self:

import jax.numpy as jnp
from jax import jit

class CustomClass:
  def __init__(self, x: jnp.ndarray, mul: bool):
    self.x = x
    self.mul = mul

  # No decorator here !
  def calc(self, y):
    if self.mul:
      return self.x * y
    return y


c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
6 # works fine!

Although this is not documented (perhaps it should be?), it appears to behave identically to marking self as static via @partial(jax.jit, static_argnums=0), in that changing self has no effect on subsequent calls, i.e.:

c = CustomClass(2, True)
jitted_calc = jit(c.calc)
print(jitted_calc(3))
c.mul = False 
print(jitted_calc(3))
6
6 # no update

So I originally assumed that decorators in general might treat self as a static parameter when applied directly, because the method might be saved to another variable along with a specific instance (copy) of self. As a sanity check, I tested whether non-jit decorators do this as well, but this appears not to be the case: the non-jit "decorated" function below happily picks up changes to self:

def decorator(func):
    def wrapper(*args, **kwargs):
        x = func(*args, **kwargs)
        return x
    return wrapper

custom = CustomClass(2, True)
decorated_calc = decorator(custom.calc)
print(decorated_calc(3))
custom.mul = False
print(decorated_calc(3))
6
3

I saw some other questions about applying decorators directly as functions versus in decorator style (e.g. here and here), and they mention that there is a slight difference between the two versions, but that it should almost never matter. I am left wondering what it is about the jit decorator that makes these versions behave so differently, in that jax.jit can deal with the type of self when it is not applied in decorator style. If anyone has an answer, that would be much appreciated.


Solution

  • Decorators have nothing to do with static arguments: static arguments are a concept specific to jax.jit.

    Backing up, keep in mind that whenever jax.jit compiles a function, it caches the compilation artifact based on several quantities, including:

    1. the ID of the function or callable being compiled
    2. the static attributes of any non-static arguments, such as shape and dtype
    3. the hash of any arguments marked static via static_argnums or static_argnames
    4. the value of any global configurations that would affect outputs
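One way to watch this cache key at work is to put a Python side effect in the function body: side effects run only while JAX traces the function, so each one marks a new cache entry. A minimal sketch, assuming JAX's default 32-bit mode; `double` and `trace_log` are illustrative names:

```python
import jax
import jax.numpy as jnp

# Illustrative log: one entry is appended per trace, i.e. per new cache entry.
trace_log = []


@jax.jit
def double(y):
    # This Python side effect runs only while JAX traces the function,
    # not on calls that are served from the compilation cache.
    trace_log.append((y.shape, str(y.dtype)))
    return 2 * y


double(3)            # scalar int32: traced and compiled
double(4)            # same shape and dtype: cache hit, no retrace
double(3.0)          # scalar float32: new dtype, retraced
double(jnp.ones(3))  # shape (3,): new shape, retraced

print(len(trace_log))  # → 3
```

The value of a non-static argument (3 versus 4) never enters the key; only its shape and dtype do, which is item 2 above.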

    With this in mind, let's examine this snippet:

    c = CustomClass(2, True)
    jitted_calc = jit(c.calc)
    print(jitted_calc(3))
    c.mul = False 
    print(jitted_calc(3))
    

    The reason that jitted_calc doesn't update when you update attributes of c is that nothing related to the cache key has changed: (1) the function ID is the same, (2) the shape and dtype of the argument are unchanged, (3) there are no static arguments, and (4) no global configurations have changed. Thus the previously cached compilation artifact (with the previous value of mul baked in) is executed again. This is the primary reason I didn't mention this strategy in the doc you linked to: it's rarely the behavior that users would want.
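If you do want attribute updates to take effect, one option is to make that state an explicit part of the cache key by passing it as arguments to a jitted function, with the Python bool marked static. A minimal sketch; the module-level helper `_calc` is a hypothetical name:

```python
from functools import partial

import jax


# Hypothetical helper: the state that can change (x and mul) is passed in
# explicitly, so it becomes part of the JIT cache key on every call.
@partial(jax.jit, static_argnames="mul")
def _calc(x, y, mul):
    if mul:  # mul is static, so branching on it in Python is allowed
        return x * y
    return y


class CustomClass:
    def __init__(self, x, mul: bool):
        self.x = x
        self.mul = mul

    def calc(self, y):
        # The method itself is not jitted; it forwards the current
        # attribute values to the jitted helper each time it is called.
        return _calc(self.x, y, mul=self.mul)


c = CustomClass(2, True)
print(c.calc(3))  # → 6
c.mul = False
print(c.calc(3))  # → 3 (a new static value of mul triggers a retrace)
```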

    This approach of wrapping the bound method in JIT is incidentally similar to wrapping the method definition with @partial(jit, static_argnums=0), but the details are not the same: in the static_argnums version, self is marked as a static argument, and so its hash becomes part of the JIT cache. The default __hash__ method for a class is simply based on the ID of the instance, and so changing c.mul does not change the hash, and does not trigger re-compilation. You can see an example of how to rectify this under Strategy 2 in the doc you linked to: basically, define appropriate __hash__ and __eq__ methods for the class:

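    from functools import partial

    import jax.numpy as jnp
    from jax import jit

    class CustomClass:
      def __init__(self, x: jnp.ndarray, mul: bool):
        self.x = x
        self.mul = mul
    
      @partial(jit, static_argnums=0)
      def calc(self, y):
        if self.mul:
          return self.x * y
        return y
    
      # Hash and compare by value, so that mutating an attribute changes the
      # JIT cache key. This requires hashable attributes (e.g. the Python int 2
      # used above); JAX arrays are not hashable.
      def __hash__(self):
        return hash((self.x, self.mul))
    
      def __eq__(self, other):
        return (isinstance(other, CustomClass) and
                (self.x, self.mul) == (other.x, other.mul))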
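To see the difference the hash makes, here is a sketch contrasting the default identity-based hash with the value-based one above. The class names `IdentityHashed` and `ValueHashed` are illustrative, and the value-based version assumes hashable attributes such as Python scalars (JAX arrays are not hashable):

```python
from functools import partial

import jax


class IdentityHashed:
    # Uses object's default __hash__/__eq__, which are identity-based.
    def __init__(self, x, mul: bool):
        self.x = x
        self.mul = mul

    @partial(jax.jit, static_argnums=0)
    def calc(self, y):
        if self.mul:
            return self.x * y
        return y


class ValueHashed(IdentityHashed):
    # Hash and compare on attribute values instead of object identity.
    def __hash__(self):
        return hash((self.x, self.mul))

    def __eq__(self, other):
        return (isinstance(other, ValueHashed) and
                (self.x, self.mul) == (other.x, other.mul))


a = IdentityHashed(2, True)
first_a = a.calc(3)    # 6
a.mul = False
second_a = a.calc(3)   # still 6: same hash and a == a, so the cache is hit

b = ValueHashed(2, True)
first_b = b.calc(3)    # 6
b.mul = False
second_b = b.calc(3)   # 3: the hash changed, so JAX retraces with mul=False

print(second_a, second_b)
```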
    

    In your last example, you define this:

    def decorator(func):
        def wrapper(*args, **kwargs):
            x = func(*args, **kwargs)
            return x
        return wrapper
    

    This code does not use jax.jit at all. The fact that changes to custom.mul lead to changes in outputs has nothing to do with decorator syntax, but rather with the fact that there is no JIT cache in play here.
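The difference between the two versions in the question comes from what jit is wrapping rather than from decorator syntax: decorating `calc` in the class body is equivalent to `jax.jit(CustomClass.calc)`, where `self` is an ordinary traced argument, while `jax.jit(c.calc)` wraps a bound method that already carries `c`. A minimal sketch of both:

```python
import jax


class CustomClass:
    def __init__(self, x, mul: bool):
        self.x = x
        self.mul = mul

    def calc(self, y):
        if self.mul:
            return self.x * y
        return y


c = CustomClass(2, True)

# Same as writing @jax.jit above `def calc`: jit wraps the plain function,
# so `self` is passed as a non-static argument and must be a valid JAX type.
jitted_function = jax.jit(CustomClass.calc)
raised = False
try:
    jitted_function(c, 3)
except TypeError:
    raised = True
print("TypeError raised:", raised)  # → TypeError raised: True

# jit applied to the bound method: `self` is already bound to `c`, so JAX
# only sees `y`, and `c` is read as a plain Python object at trace time.
jitted_method = jax.jit(c.calc)
result = jitted_method(3)
print(result)  # → 6
```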

    I hope that's all clear!