I'm developing a code using JAX, and I wanted to JIT some parts of that had big loops. I didn't want the code to be unrolled so I used fori_loop, but I'm getting an error and can't figure out what I am doing wrong.
The error is:
self.arr = self.arr.reshape(new_shape+new_shape)
TypeError: 'aval_method' object is not callable
I was able to reduce the code to the following:
import jax.numpy as jnp
import jax
class UB():
def __init__(self, arr, new_shape):
self.arr = arr
self.shape = new_shape
if type(arr) is not object:
self.arr = self.arr.reshape(new_shape+new_shape)
def _tree_flatten(self):
children = (self.arr,) # arrays / dynamic values
aux_data = {
'new_shape': self.shape
} # static values
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
return cls(*children, **aux_data)
class UM():
def __init__(self, arr, r=None):
self.arr = arr
self.r = tuple(r)
def _tree_flatten(self):
children = (self.arr,) # arrays / dynamic values
aux_data = {
'r': self.r
} # static values
return (children, aux_data)
@classmethod
def _tree_unflatten(cls, aux_data, children):
return cls(*children, **aux_data)
for C in [UB, UM]:
jax.tree_util.register_pytree_node(
C,
C._tree_flatten,
C._tree_unflatten,
)
def s_w(ub, ums):
e = jnp.identity(2)
u = UM(e, [2])
ums[0] = u
return ub, ums
def s_c(t, uns):
n = 20
ums = []
for un in uns:
ums.append(UM(un, [2]))
tub = UB(t.arr, t.r)
s_loop_body = lambda i,x: s_w( ub=x[0], ums=x[1])
tub, ums = jax.lax.fori_loop(0, n, s_loop_body, (tub, ums))
# for i in range(n):
# tub, ums = s_loop_body(i, (tub, ums))
return jnp.array([u.arr.flatten() for u in ums])
uns = jnp.array([jnp.array([1, 2, 3, 4]) for _ in range(6)])
t = UM(jnp.array([1, 0, 0, 1]), r=[2])
uns = s_c(t, uns)
Has anyone encountered this issue or can explain how to fix it?
The issue is discussed here: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization
Namely, in JAX pytrees are used as general containers, and are sometimes initialized with abstract values or other place-holders, and so you cannot assume that arguments to a custom PyTree will be of array type. You might account for this by doing something like the following:
class UB():
def __init__(self, arr, new_shape):
self.arr = arr
self.shape = new_shape
if isinstance(arr, jnp.ndarray):
self.arr = self.arr.reshape(new_shape+new_shape)
When I run your code with this modification, it gets past the error you asked about, but unfortunately does trigger another error due to the body function of the fori_loop
not having a valid signature (namely, the arr
attributes of the ums
have different shapes on input and output, which is not supported by fori_loop
).
Hopefully this gets you on the path toward working code!