Tags: python, machine-learning, neural-network, jax

"The truth value of an array with more than one element is ambiguous" when trying to train a new JAX+Equinox model a second time


TL;DR: I create a new instance of my equinox.Module model and fit it using Optax. Everything works fine. When I create a new instance of the same model and try to fit it from scratch, using the same code, same initial values, same everything, I get:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

...somewhere deep in Optax code. My code doesn't compare any arrays, and the error message doesn't show exactly where the comparison happens. What's wrong?

Code

# 1. Import dependencies.
import jax; jax.config.update("jax_enable_x64", True)
import jax.numpy as np, jax.random as rnd, equinox as eqx
import optax

# 2. Define loss function. I'm fairly confident this is correct.
def npdf(x, var):
    return np.exp(-0.5 * x**2 / var) / np.sqrt(2 * np.pi * var)

def mixpdf(x, ps, vars):
    return ps.dot(npdf(x, vars))

def loss(model, series):
    weights, condvars = model(series)
    return -jax.vmap(
        lambda x, vars: np.log(mixpdf(x, weights, vars))
    )(series[1:], condvars[:-1]).mean()

# 3. Define recurrent neural network.
class RNNCell(eqx.Module):
    bias: np.ndarray
    Wx: np.ndarray
    Wh: np.ndarray
    def __init__(self, ncomp: int, n_in: int=1, *, key: np.ndarray):
        k1, k2, k3 = rnd.split(key, 3)
        self.bias = rnd.uniform(k1, (ncomp, ))
        self.Wx = rnd.uniform(k2, (ncomp, n_in))
        self.Wh = 0.9 * rnd.uniform(k3, (ncomp, ))

    def __call__(self, vars_prev, obs):
        vars_new = self.bias + self.Wx @ obs + self.Wh * vars_prev
        return vars_new, vars_new

class RNN(eqx.Module):
    cell: RNNCell
    logits: np.ndarray
    vars0: np.ndarray = eqx.field(static=True)

    def __init__(self, vars0: np.ndarray, n_in=1, *, key: np.ndarray):
        self.vars0 = np.array(vars0)
        K = len(self.vars0)
        self.cell = RNNCell(K, n_in, key=key)
        self.logits = np.zeros(K)

    def __call__(self, series: np.ndarray):
        _, hist = jax.lax.scan(self.cell.__call__, self.vars0, series**2)
        return jax.nn.softmax(self.logits), abs(hist)

    def condvar(self, series):
        weights, variances = self(series)
        return variances @ weights

    def predict(self, series: np.ndarray):
        return self.condvar(series).flatten()[-1]

# 4. Training/fitting code.
def fit(model, logret, nepochs: int, optimizer, loss):
    loss_and_grad = eqx.filter_value_and_grad(loss)
    
    @eqx.filter_jit
    def make_step(model, opt_state):
        loss_val, grads = loss_and_grad(model, logret)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss_val, model, opt_state

    opt_state = optimizer.init(model)
    for epoch in range(nepochs):
        loss_val, model, opt_state = make_step(model, opt_state)
    print("Works!")
    return model

def experiment():
    series = rnd.normal(rnd.PRNGKey(8), (100, 1))
    model = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
    return fit(model, series, 100, optax.adam(0.01), loss)

# 5. Run the exact same code twice.
experiment() # 1st call, works
experiment() # 2nd call, error

Error message

> python my_RNN.py
Works!
Traceback (most recent call last):
  File "/Users/forcebru/test/my_RNN.py", line 75, in <module>
    experiment() # 2nd call, error
    ^^^^^^^^^^^^
  File "/Users/forcebru/test/my_RNN.py", line 72, in experiment
    return fit(model, series, 100, optax.adam(0.01), loss)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/test/my_RNN.py", line 65, in fit
    loss_val, model, opt_state = make_step(model, opt_state)
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_module.py", line 935, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 200, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 248, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 136, in _python_pjit_helper
    infer_params_fn(*args, **kwargs)
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/api.py", line 325, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 495, in common_infer_params
    jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
                                                                    ^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 1150, in _pjit_jaxpr
    jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
                                                   ^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/linear_util.py", line 350, in memoized_fun
    ans = call(fun, *args)
          ^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 1089, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
                                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/profiler.py", line 336, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2314, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2336, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 49, in fun_wrapped
    out = fun(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/test/my_RNN.py", line 59, in make_step
    updates, opt_state = optimizer.update(grads, opt_state)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/combine.py", line 59, in update_fn
    updates, new_s = fn(updates, s, params, **extra_args)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/base.py", line 337, in update
    return tx.update(updates, state, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/transform.py", line 369, in update_fn
    mu_hat = bias_correction(mu, b1, count_inc)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 248, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 136, in _python_pjit_helper
    infer_params_fn(*args, **kwargs)
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/api.py", line 325, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 491, in common_infer_params
    canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
                                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 4, in __eq__
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py", line 745, in __bool__
    check_bool_conversion(self)
  File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py", line 662, in check_bool_conversion
    raise ValueError("The truth value of an array with more than one element is "
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Problem

  • The error message says File "<string>", line 4, in __eq__, which doesn't help.
  • It refers to the line mu_hat = bias_correction(mu, b1, count_inc) in Optax code, but as far as I understand, that line doesn't compare any arrays.
  • It also refers to JAX code that's supposedly responsible for JIT compilation, but this seems outside my control.

Is there a bug in my model definition (RNNCell or RNN)? Did I implement the training loop wrong? I basically copied it straight from the Equinox docs, so it should be fine. Why does it work the first time I call experiment(), but not the second?


Solution

  • It appears this is a bug in equinox. The JAX-internal function _process_in_axis_resources (visible in the traceback) is decorated with functools.lru_cache, which means its arguments are checked for equality against the arguments of previous calls to look up a cache hit. On the second run, this triggers a call to equinox.Module.__eq__, which raises the error. Equinox modules are dataclasses, and the dataclass-generated __eq__ (that is the unhelpful File "<string>", line 4, in __eq__ frame) compares fields with ==, producing boolean arrays whose truth value is then ambiguous. You can see this problem by doing the equality check directly:

    model = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
    model2 = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
    model == model2
    # ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
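
    The underlying mechanism is easy to demonstrate in isolation (independent of the model above): comparing two JAX arrays yields a boolean array, and coercing that array to a single bool raises exactly this error:

    import jax.numpy as jnp

    a = jnp.array([0.4, 0.6, 0.8])
    b = jnp.array([0.4, 0.6, 0.8])
    a == b        # fine: elementwise comparison, Array([ True,  True,  True], dtype=bool)
    bool(a == b)  # ValueError: The truth value of an array with more than one element is ambiguous.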
    

    I would suggest reporting this bug at https://github.com/patrick-kidger/equinox/issues

    You could probably work around this issue by not storing an array (vars0) as a static attribute. I suspect that equinox assumes all static attributes are hashable, and arrays (whether NumPy or JAX) are not.
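
    You can check that directly (reusing the question's np = jax.numpy alias): tuples hash fine, arrays don't:

    hash((0.4, 0.6, 0.8))  # works: a tuple of Python floats is hashable
    hash(np.zeros(3))      # TypeError: unhashable type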

    Edit: I just checked, and changing this:

    vars0: np.ndarray = eqx.field(static=True)
    

    to this:

    vars0: np.ndarray
    

    resolves the issue.

    Edit 2: Indeed it looks like static fields in equinox must be hashable, so this is not an equinox bug but rather a usage error (see the discussion at https://github.com/patrick-kidger/equinox/issues/154#issuecomment-1561735995). You might try storing vars0 as a tuple (which is hashable) rather than an array (which isn't).