
JAX update breaks working code of linear algebra solver


I am running into an issue after updating JAX. I have a library that solves a system of linear equations using the BiCGSTAB algorithm.

The solver is implemented as follows:

import jax
import vecfield  # assumed module name for the VecField file shown below

def bicgstabsolver(A, b, eps):
  '''Returns the loop initialization and iteration functions.'''

  def init(z, b, x0):
    '''Forms the args that will be used to update stuff.'''

    x = x0
    r = b - A(x, z)
    rstrich = r
    v = vecfield.zeros(b.shape)
    p = vecfield.zeros(b.shape)
    alpha = 1
    rho = 1
    omega = 1

    term_err = eps * vecfield.norm(b)

    return x, r, rstrich, v, p, alpha, rho, omega, term_err
    
  @jax.jit
  def iter(x, r, rstrich, v, p, alpha, rho, omega, z):
    '''Run a single BiCGSTAB iteration.'''
    rhoold = rho
    rho = vecfield.dot(vecfield.conj(rstrich),r)
    beta = (rho / rhoold) * (alpha / omega)
    p = r + beta * (p - omega * v)
    v = A(p,z)
    alpha = rho / vecfield.dot(vecfield.conj(rstrich),v)
    h = x + alpha * p
    s = r - alpha * v
    t = A(s,z)
    omega = vecfield.dot(vecfield.conj(t),s) / vecfield.dot(vecfield.conj(t),t)
    x = h + omega * s
    r = s - omega * t
    err = vecfield.norm(r)
    return x, r, rstrich, v, p, alpha, rho, omega, err

  return init, iter

The implementation of the VecField class:

import jax.numpy as np
from typing import Any, NamedTuple


class VecField(NamedTuple):
  '''Represents a 3-tuple of arrays.'''
  x: Any
  y: Any
  z: Any

  @property
  def shape(self):
    assert self.x.shape == self.y.shape == self.z.shape
    return self.x.shape

  @property
  def dtype(self):
    assert self.x.dtype == self.y.dtype == self.z.dtype
    return self.x.dtype

  def as_array(self):
    return VecField(*(np.array(a) for a in self))

  def __add__(x, y):
    return VecField(*(a + b for a, b in zip(x, y)))

  def __sub__(x, y):
    return VecField(*(a - b for a, b in zip(x, y)))

  def __mul__(x, y):
    return VecField(*(a * b for a, b in zip(x, y)))

  def __rmul__(y, x):
    return VecField(*(x * b for b in y))


def zeros(shape):
  return VecField(*(np.zeros(shape, np.complex128) for _ in range(3)))

def ones(shape):
  return VecField(*(np.ones(shape, np.complex128) for _ in range(3)))

# TODO: Check if this hack is still necessary to obtain good performance.
def dot(x, y):
  z = VecField(*(a * b for a, b in zip(x, y)))
  return sum(np.sum(np.real(c)) + 1j * np.sum(np.imag(c)) for c in z)


def norm(x):
  return np.sqrt(sum(np.square(np.linalg.norm(a)) for a in x))


def conj(x):
  return VecField(*(np.conj(a) for a in x))


def real(x):
  return VecField(*(np.real(a) for a in x))


def from_tuple(x):
  return VecField(*(np.reshape(a, (1, 1) + a.shape) for a in x))


def to_tuple(x):
  return tuple(np.reshape(a, a.shape[2:]) for a in x)

The code runs fine with jax and jaxlib version 0.3.10. However, if I update jax to 0.4.13, it stops working with a cryptic error:

File "***", line 66, in iter
    p = r + beta * (p - omega * v)
  File "***/python3.8/site-packages/jax/_src/numpy/array_methods.py", line 791, in op
    return getattr(self.aval, f"_{name}")(self, *args)
  File "***/python3.8/site-packages/jax/_src/numpy/array_methods.py", line 260, in deferring_binary_op
    raise TypeError(f"unsupported operand type(s) for {opchar}: "
jax._src.traceback_util.UnfilteredStackTrace: TypeError: unsupported operand type(s) for *: 'DynamicJaxprTracer' and 'VecField'

I have no clue so far how to migrate this code to be compatible with the newer version of jax. Probably I'm missing something very obvious. Any help would be greatly appreciated!


Solution

  • It looks like JAX's array __mul__ methods are raising a TypeError on unsupported input rather than returning NotImplemented, which means that omega * v is not correctly dispatching to v.__rmul__().

    This is a bug in JAX: I would suggest reporting this in a new issue at http://github.com/google/jax/issues/
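To illustrate the dispatch rule described above, here is a minimal pure-Python sketch. `Vec`, `PoliteScalar`, and `StrictScalar` are hypothetical stand-ins (not JAX types): the first scalar returns NotImplemented for operands it does not understand, so Python falls back to the right operand's `__rmul__`; the second raises TypeError directly, so the fallback never runs.

```python
from typing import Any, NamedTuple


class Vec(NamedTuple):
    """Hypothetical stand-in for VecField: a tuple subclass with __rmul__."""
    x: Any
    y: Any

    def __rmul__(self, scalar):
        s = getattr(scalar, "value", scalar)  # unwrap the toy scalar types below
        return Vec(*(s * a for a in self))


class PoliteScalar:
    """Hypothetical scalar whose __mul__ defers on unknown operand types."""

    def __init__(self, value):
        self.value = value

    def __mul__(self, other):
        if isinstance(other, (int, float, complex)):
            return type(self)(self.value * other)
        # Returning NotImplemented tells Python to try other.__rmul__(self).
        return NotImplemented


class StrictScalar(PoliteScalar):
    """Hypothetical scalar whose __mul__ raises instead of deferring."""

    def __mul__(self, other):
        if isinstance(other, (int, float, complex)):
            return type(self)(self.value * other)
        # Raising here stops dispatch: Vec.__rmul__ is never consulted.
        raise TypeError("unsupported operand type(s) for *")


v = Vec(1.0, 3.0)
print(PoliteScalar(2.0) * v)  # falls back to Vec.__rmul__

try:
    StrictScalar(2.0) * v
except TypeError as err:
    print("TypeError:", err)
```

The second case mirrors the traceback in the question: the left operand raises before the right operand gets a chance to handle the multiplication.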

    In the meantime, you should be able to work around this by making sure that whenever you combine a VecField with a JAX array, the VecField appears on the left-hand side of the operator; e.g. change this:

    p = r + beta * (p - omega * v)
    

    to this:

    p = (p - v * omega) * beta + r
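The reordered expression depends on `VecField.__mul__` accepting a scalar on its right. As posted, `__mul__` pairs components with `zip`, which appears to assume that both operands are VecFields. The following is a minimal sketch (an assumption, not the author's code) of a `__mul__` that handles both cases. Plain Python floats stand in for JAX arrays and tracers so that the example runs without JAX.

```python
from typing import Any, NamedTuple


class VecField(NamedTuple):
    """Reduced version of the question's VecField, for illustration only."""
    x: Any
    y: Any
    z: Any

    def __add__(self, other):
        return VecField(*(a + b for a, b in zip(self, other)))

    def __sub__(self, other):
        return VecField(*(a - b for a, b in zip(self, other)))

    def __mul__(self, other):
        if isinstance(other, VecField):
            # Component-wise product of two fields, as in the original class.
            return VecField(*(a * b for a, b in zip(self, other)))
        # Assumed extension: scale every component by a scalar on the right.
        return VecField(*(a * other for a in self))

    def __rmul__(self, other):
        return VecField(*(other * a for a in self))


r = VecField(1.0, 2.0, 3.0)
p = VecField(4.0, 5.0, 6.0)
v = VecField(1.0, 1.0, 1.0)
omega, beta = 2.0, 0.5

# Workaround ordering: the VecField is always the left operand.
p_new = (p - v * omega) * beta + r
print(p_new)
```

With Python floats both orderings give the same result, because `float.__mul__` returns NotImplemented for a VecField. The reordering only matters for left operands that raise instead.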
    

    Edit: it looks like the bug was introduced in https://github.com/google/jax/pull/11234 (meaning it's present in all JAX versions 0.3.14 and newer) and only affects subtypes of builtin collections (which includes NamedTuple).

    Edit 2: this has been fixed in https://github.com/google/jax/pull/17406, which should be part of a future JAX 0.4.16 release.
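If the library has to support several JAX versions, the workaround could be gated on the installed version. The sketch below is an assumption built only on the version bounds quoted in the two edits above (affected from 0.3.14, fixed in 0.4.16). `needs_left_operand_workaround` is a hypothetical helper name, and in practice one would pass it `jax.__version__`.

```python
import re

# Bounds taken from the answer's edits; treat them as assumptions.
FIRST_AFFECTED = (0, 3, 14)
FIRST_FIXED = (0, 4, 16)


def parse_version(version):
    """Return the leading (major, minor, patch) integers of a version string."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def needs_left_operand_workaround(version):
    """True if the given JAX version falls in the assumed affected range."""
    return FIRST_AFFECTED <= parse_version(version) < FIRST_FIXED


for candidate in ("0.3.10", "0.4.13", "0.4.16"):
    print(candidate, needs_left_operand_workaround(candidate))
```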