I am trying to use `@jit` with nested functions and am having a problem. I have a class `One` that takes in another class, `Plant`, which has a method `func`. I would like to call this jitted method `func` from `One`.
I believe I followed the "How to use jit with methods?" section of the JAX FAQ:
https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods
However, I get the following error:

    TypeError: One.__init__() got multiple values for argument 'plant'

Could anyone tell me how to solve this?
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import numpy as np
from functools import partial
from jax import tree_util
class One:
    """Adds ``x`` to the output of ``plant.func`` (the question's original version).

    NOTE(review): as written, the pytree hooks below are broken, and calling
    ``call_plant_func`` reproduces the TypeError quoted in the question.
    """

    def __init__(self, plant,x):
        self.plant = plant
        self.x = x

    # jit on a method: ``self`` is passed through jit as a pytree argument,
    # which is why One is registered with register_pytree_node below.
    @jit
    def call_plant_func(self,y):
        """Return ``self.plant.func(y) + self.x``."""
        out = self.plant.func(y) + self.x
        return out

    def _tree_flatten(self):
        """Split a One into ``(children, aux_data)`` for JAX."""
        children = (self.x,) # arrays / dynamic values
        # BUG: plant is treated as static auxiliary data, but Plant is itself
        # a pytree whose ``z`` is a dynamic value; it belongs in children.
        aux_data = {'plant':self.plant} # intended as static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        """Rebuild a One from ``aux_data`` and ``children``."""
        # NOTE(review): leftover debugging breakpoint; it pauses execution
        # every time JAX rebuilds a One through this hook.
        import pdb; pdb.set_trace();
        # BUG: children is (x,), so x binds to the positional ``plant``
        # parameter and ``plant`` is then passed again as a keyword, raising
        # "One.__init__() got multiple values for argument 'plant'".
        return cls(*children, **aux_data)


tree_util.register_pytree_node(One,
                               One._tree_flatten,
                               One._tree_unflatten)
class Plant:
    """Adds ``z`` to its input (the question's original version).

    NOTE(review): the pytree hooks below are inconsistent with ``__init__``;
    see the BUG notes.
    """

    def __init__(self, z,kk):
        # BUG: ``kk`` is accepted but never stored, so it is neither flattened
        # nor available to pass back in from ``_tree_unflatten``.
        self.z =z

    @jit
    def func(self,y):
        """Return ``y + self.z``."""
        y = y + self.z
        return y

    def _tree_flatten(self):
        """Split a Plant into ``(children, aux_data)`` for JAX."""
        children = (self.z,) # arrays / dynamic values
        aux_data = None # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, children):
        """Rebuild a Plant from its children."""
        # BUG: register_pytree_node calls the unflatten hook as
        # (aux_data, children), so this signature is missing ``aux_data``.
        # It also forwards only ``z``, while __init__ requires ``kk`` too.
        return cls(*children)


tree_util.register_pytree_node(Plant,
                               Plant._tree_flatten,
                               Plant._tree_unflatten)
# Reproduction from the question: the jitted call on the last line fails with
# the TypeError quoted above (presumably after first stopping at the pdb
# breakpoint left in One._tree_unflatten).
plant = Plant(5,2)
one = One(plant,2)
print(one.call_plant_func(10))
The last line raises the error described above.
There are issues in the `_tree_flatten` and `_tree_unflatten` code of both classes:
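- `One._tree_flatten` treats `plant` as static data, but it is not: it is a pytree that has non-static elements.
- `One._tree_unflatten` builds `One` with the arguments in the wrong places. It passes `x` positionally, so `x` binds to the `plant` parameter, and then passes `plant` again as a keyword argument. This is the error you're seeing. (It also still contains the leftover `import pdb; pdb.set_trace()`, which should be removed.)
- `Plant.__init__` does nothing with the `kk` argument.
- `Plant._tree_unflatten` is missing the `aux_data` argument, and it fails to pass the `kk` argument to `Plant.__init__`.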
With these issues fixed, your code executes without error:
class One:
    """Pytree container that offsets ``plant.func`` by a stored value ``x``.

    Both ``plant`` (itself a pytree) and ``x`` are flattened as children, so
    ``jit`` receives them as dynamic inputs rather than static data.
    """

    def __init__(self, plant, x):
        self.plant = plant
        self.x = x

    @jit
    def call_plant_func(self, y):
        """Return ``self.plant.func(y) + self.x``."""
        return self.plant.func(y) + self.x

    def _tree_flatten(self):
        # Children follow __init__'s positional order; nothing is static.
        return (self.plant, self.x), None

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # aux_data is always None (see _tree_flatten), so only children matter.
        return cls(*children)


tree_util.register_pytree_node(
    One,
    One._tree_flatten,
    One._tree_unflatten,
)
class Plant:
    """Pytree that adds its offset ``z`` to an input.

    ``kk`` is stored and carried through flattening, but is not otherwise
    used here.
    """

    def __init__(self, z, kk):
        self.kk = kk
        self.z = z

    @jit
    def func(self, y):
        """Return ``y + self.z``."""
        return y + self.z

    def _tree_flatten(self):
        # Children follow __init__'s positional order: (z, kk).
        return (self.z, self.kk), None

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # aux_data is always None (see _tree_flatten).
        return cls(*children)


tree_util.register_pytree_node(
    Plant,
    Plant._tree_flatten,
    Plant._tree_unflatten,
)
# Build the pytrees and call the jitted method: 10 + z (5) + x (2) -> prints 17.
plant = Plant(z=5, kk=2)
one = One(plant=plant, x=2)
print(one.call_plant_func(10))