I am dealing with an issue related to an update of jax. I have a library which is supposed to solve a system of linear equations using the bicgstab algorithm.
The solver is implemented as follows:
def bicgstabsolver(A, b, eps):
'''Returns the loop initialization and iteration functions.'''
def init(z, b, x0):
'''Forms the args that will be used to update stuff.'''
x = x0
r = b - A(x, z)
rstrich = r
v = vecfield.zeros(b.shape)
p = vecfield.zeros(b.shape)
alpha = 1
rho = 1
omega = 1
term_err = eps * vecfield.norm(b)
return x, r, rstrich, v, p, alpha, rho, omega, term_err
@jax.jit
def iter(x, r, rstrich, v, p, alpha, rho, omega, z):
'''Run the iteration loop `n` times.'''
rhoold = rho
rho = vecfield.dot(vecfield.conj(rstrich),r)
beta = (rho / rhoold) * (alpha / omega)
p = r + beta * (p - omega * v)
v = A(p,z)
alpha = rho / vecfield.dot(vecfield.conj(rstrich),v)
h = x + alpha * p
s = r - alpha * v
t = A(s,z)
omega = vecfield.dot(vecfield.conj(t),s) / vecfield.dot(vecfield.conj(t),t)
x = h + omega * s
r = s - omega * t
err = vecfield.norm(r)
return x, r, rstrich, v, p, alpha, rho, omega, err
return init, iter
The implementation of the VecField class:
import jax.numpy as np
from typing import Any, NamedTuple
class VecField(NamedTuple):
'''Represents a 3-tuple of arrays.'''
x: Any
y: Any
z: Any
@property
def shape(self):
assert self.x.shape == self.y.shape == self.z.shape
return self.x.shape
@property
def dtype(self):
assert self.x.dtype == self.y.dtype == self.z.dtype
return self.x.dtype
def as_array(self):
return VecField(*(np.array(a) for a in self))
def __add__(x, y):
return VecField(*(a + b for a, b in zip(x, y)))
def __sub__(x, y):
return VecField(*(a - b for a, b in zip(x, y)))
def __mul__(x, y):
return VecField(*(a * b for a, b in zip(x, y)))
def __rmul__(y, x):
return VecField(*(x * b for b in y))
def zeros(shape):
return VecField(*(np.zeros(shape, np.complex128) for _ in range(3)))
def ones(shape):
return VecField(*(np.ones(shape, np.complex128) for _ in range(3)))
# TODO: Check if this hack is still necessary to obtain good performance.
def dot(x, y):
z = VecField(*(a * b for a, b in zip(x, y)))
return sum(np.sum(np.real(c)) + 1j * np.sum(np.imag(c)) for c in z)
def norm(x):
return np.sqrt(sum(np.square(np.linalg.norm(a)) for a in x))
def conj(x):
return VecField(*(np.conj(a) for a in x))
def real(x):
return VecField(*(np.real(a) for a in x))
def from_tuple(x):
return VecField(*(np.reshape(a, (1, 1) + a.shape) for a in x))
def to_tuple(x):
return tuple(np.reshape(a, a.shape[2:]) for a in x)
The code is running perfectly fine using jax and jaxlib version 0.3.10. However, if I update jax to 0.4.13 it stops working with a cryptic error:
File "***", line 66, in iter
p = r + beta * (p - omega * v)
File "***/python3.8/site-packages/jax/_src/numpy/array_methods.py", line 791, in op
return getattr(self.aval, f"_{name}")(self, *args)
File "***/python3.8/site-packages/jax/_src/numpy/array_methods.py", line 260, in deferring_binary_op
raise TypeError(f"unsupported operand type(s) for {opchar}: "
jax._src.traceback_util.UnfilteredStackTrace: TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'VecField'
I have no clue so far how to migrate this code to be compatible with the newer version of jax. Probably I'm missing something very obvious. Any help would be greatly appreciated!
It looks like JAX's array __mul__
methods are raising a TypeError
on unsupported input rather than returning NotImplemented
, which means that omega * v
is not correctly dispatching to v.__rmul__()
.
This is a bug in JAX: I would suggest reporting this in a new issue at http://github.com/google/jax/issues/
In the meantime, you should be able to work around this by making sure that every time you operate between a VecField
by a JAX array, the VecField
appears on the left of the operand; e.g. change this:
p = r + beta * (p - omega * v)
to this:
p = (p - v * omega) * beta + r
Edit: it looks like the bug was introduced in https://github.com/google/jax/pull/11234 (meaning it's present in all JAX versions 0.3.14 and newer) and only affects subtypes of builtin collections (which includes NamedTuple
).
Edit 2: this has been fixed in https://github.com/google/jax/pull/17406, which should be part of a future JAX 0.4.16 release.