JAX @jit for nested class method


I am trying to use @jit with a nested class method and am running into a problem. I have a class One that takes in an instance of another class, Plant, which has a method func. I would like to call the jitted method func from One. I think I followed the "How to use jit with methods?" section of the JAX FAQ: https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods However, I get the error TypeError: One.__init__() got multiple values for argument 'plant'. Could anyone tell me how to solve this?

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
from functools import partial
from jax import tree_util

class One:
    def __init__(self, plant,x):
        self.plant = plant
        self.x = x
    
    @jit
    def call_plant_func(self,y):
        out = self.plant.func(y) + self.x
        return out
    
    def _tree_flatten(self):
        children = (self.x,)  # arrays / dynamic values
        aux_data = {'plant':self.plant}  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        import pdb; pdb.set_trace();
        return cls(*children, **aux_data)
        
tree_util.register_pytree_node(One,
                               One._tree_flatten,
                               One._tree_unflatten)    
    
class Plant:
    def __init__(self, z,kk):
        self.z =z
    
    @jit
    def func(self,y):
        y = y + self.z
        return y
    
    def _tree_flatten(self):
        children = (self.z,)  # arrays / dynamic values
        aux_data = None # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, children):
        return cls(*children)
   
tree_util.register_pytree_node(Plant,
                               Plant._tree_flatten,
                               Plant._tree_unflatten)

plant = Plant(5,2)
one = One(plant,2)
print(one.call_plant_func(10))

The last line raises the error described above.


Solution

  • You have issues in the _tree_flatten and _tree_unflatten code in both classes:

    • One._tree_flatten treats plant as static data, but it is not: plant is itself a pytree with non-static elements, so it belongs in children.
    • One._tree_unflatten calls cls(*children, **aux_data), which expands to One(x, plant=plant). The positional x fills the plant slot, and plant is then passed a second time as a keyword, which raises the TypeError you're seeing.
    • Plant.__init__ does nothing with the kk argument.
    • Plant._tree_unflatten is missing the aux_data argument and does not pass the kk argument to Plant.__init__.
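    The TypeError itself can be reproduced without JAX. The sketch below uses a stand-in class with the same __init__ signature as One; plant_placeholder is a hypothetical object standing in for a Plant instance. It makes the same call as the question's _tree_unflatten and then shows that passing plant first and x second succeeds.

```python
# Stand-in with the same signature as One.__init__(self, plant, x); no JAX needed.
class One:
    def __init__(self, plant, x):
        self.plant = plant
        self.x = x


plant_placeholder = object()             # hypothetical stand-in for a Plant instance
children = (2,)                          # what the question's _tree_flatten returns: (self.x,)
aux_data = {"plant": plant_placeholder}  # ...and its aux_data: {'plant': self.plant}

message = None
try:
    # Same call as the question's _tree_unflatten: cls(*children, **aux_data).
    # It expands to One(2, plant=...): 2 fills the positional `plant` slot,
    # then `plant` is supplied a second time as a keyword argument.
    One(*children, **aux_data)
except TypeError as err:
    message = str(err)
print(message)

# Supplying plant first and x second matches the signature and succeeds.
ok = One(aux_data["plant"], *children)
print(ok.x)
```

    This is why the fixed code below puts plant into children ahead of x, so that cls(*children) lines up with One.__init__(self, plant, x).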

    With these issues fixed, your code executes without error:

    from jax import jit, tree_util

    class One:
        def __init__(self, plant, x):
            self.plant = plant
            self.x = x

        @jit
        def call_plant_func(self, y):
            out = self.plant.func(y) + self.x
            return out

        def _tree_flatten(self):
            # plant is itself a pytree, so it goes in children (dynamic), not aux_data
            children = (self.plant, self.x)
            aux_data = None
            return (children, aux_data)

        @classmethod
        def _tree_unflatten(cls, aux_data, children):
            # children arrive in the order they were flattened: (plant, x)
            return cls(*children)

    tree_util.register_pytree_node(One,
                                   One._tree_flatten,
                                   One._tree_unflatten)

    class Plant:
        def __init__(self, z, kk):
            self.kk = kk
            self.z = z

        @jit
        def func(self, y):
            y = y + self.z
            return y

        def _tree_flatten(self):
            children = (self.z, self.kk)
            aux_data = None
            return (children, aux_data)

        @classmethod
        def _tree_unflatten(cls, aux_data, children):
            # aux_data is required by the signature even though it is unused here
            return cls(*children)

    tree_util.register_pytree_node(Plant,
                                   Plant._tree_flatten,
                                   Plant._tree_unflatten)

    plant = Plant(5, 2)
    one = One(plant, 2)
    print(one.call_plant_func(10))  # prints 17
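    JAX also provides the tree_util.register_pytree_node_class decorator, which registers a class whose methods are named tree_flatten and tree_unflatten(cls, aux_data, children). The sketch below is one possible variant of the fixed code, not the only correct one. It assumes kk is a plain Python value that never needs to be traced (func does not use it), so kk is kept in aux_data as static metadata, which must then be hashable; z and x stay as dynamic children. It also flattens and rebuilds one to check the round trip.

```python
from jax import jit, tree_util


@tree_util.register_pytree_node_class
class Plant:
    def __init__(self, z, kk):
        self.z = z
        self.kk = kk

    @jit
    def func(self, y):
        return y + self.z

    def tree_flatten(self):
        # z is traced by jit; kk is assumed to be a static Python value (hashable)
        children = (self.z,)
        aux_data = (self.kk,)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


@tree_util.register_pytree_node_class
class One:
    def __init__(self, plant, x):
        self.plant = plant
        self.x = x

    @jit
    def call_plant_func(self, y):
        return self.plant.func(y) + self.x

    def tree_flatten(self):
        # plant is itself a pytree, so it goes in children, not aux_data
        return (self.plant, self.x), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


plant = Plant(5, 2)
one = One(plant, 2)

# Only the dynamic leaves appear: plant.z, then one.x (kk is static).
leaves, treedef = tree_util.tree_flatten(one)
rebuilt = tree_util.tree_unflatten(treedef, leaves)
print(leaves)

result = one.call_plant_func(10)
print(result)
```

    With this split, changing kk would change the static aux_data and so trigger recompilation of the jitted methods, while changing z or x would not; if kk needs to vary freely, keep it in children as the fixed code above does.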