TL;DR: I create a new instance of my equinox.Module model and fit it using Optax. Everything works fine. When I create a new instance of the same model and try to fit it from scratch, using the same code, same initial values, same everything, I get:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
...somewhere deep inside Optax's code. My code doesn't compare any arrays, and the error message doesn't show where exactly the comparison happens. What's wrong?
# 1. Import dependencies.
import jax; jax.config.update("jax_enable_x64", True)
import jax.numpy as np, jax.random as rnd, equinox as eqx
import optax

# 2. Define loss function. I'm fairly confident this is correct.
def npdf(x, var):
    return np.exp(-0.5 * x**2 / var) / np.sqrt(2 * np.pi * var)

def mixpdf(x, ps, vars):
    return ps.dot(npdf(x, vars))

def loss(model, series):
    weights, condvars = model(series)
    return -jax.vmap(
        lambda x, vars: np.log(mixpdf(x, weights, vars))
    )(series[1:], condvars[:-1]).mean()

# 3. Define recurrent neural network.
class RNNCell(eqx.Module):
    bias: np.ndarray
    Wx: np.ndarray
    Wh: np.ndarray

    def __init__(self, ncomp: int, n_in: int = 1, *, key: np.ndarray):
        k1, k2, k3 = rnd.split(key, 3)
        self.bias = rnd.uniform(k1, (ncomp,))
        self.Wx = rnd.uniform(k2, (ncomp, n_in))
        self.Wh = 0.9 * rnd.uniform(k3, (ncomp,))

    def __call__(self, vars_prev, obs):
        vars_new = self.bias + self.Wx @ obs + self.Wh * vars_prev
        return vars_new, vars_new

class RNN(eqx.Module):
    cell: RNNCell
    logits: np.ndarray
    vars0: np.ndarray = eqx.field(static=True)

    def __init__(self, vars0: np.ndarray, n_in=1, *, key: np.ndarray):
        self.vars0 = np.array(vars0)
        K = len(self.vars0)
        self.cell = RNNCell(K, n_in, key=key)
        self.logits = np.zeros(K)

    def __call__(self, series: np.ndarray):
        _, hist = jax.lax.scan(self.cell.__call__, self.vars0, series**2)
        return jax.nn.softmax(self.logits), abs(hist)

    def condvar(self, series):
        weights, variances = self(series)
        return variances @ weights

    def predict(self, series: np.ndarray):
        return self.condvar(series).flatten()[-1]
# 4. Training/fitting code.
def fit(model, logret, nepochs: int, optimizer, loss):
    loss_and_grad = eqx.filter_value_and_grad(loss)

    @eqx.filter_jit
    def make_step(model, opt_state):
        loss_val, grads = loss_and_grad(model, logret)
        updates, opt_state = optimizer.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss_val, model, opt_state

    opt_state = optimizer.init(model)
    for epoch in range(nepochs):
        loss_val, model, opt_state = make_step(model, opt_state)
    print("Works!")
    return model

def experiment():
    series = rnd.normal(rnd.PRNGKey(8), (100, 1))
    model = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
    return fit(model, series, 100, optax.adam(0.01), loss)

# 5. Run the exact same code twice.
experiment()  # 1st call, works
experiment()  # 2nd call, error
> python my_RNN.py
Works!
Traceback (most recent call last):
File "/Users/forcebru/test/my_RNN.py", line 75, in <module>
experiment() # 2nd call, error
^^^^^^^^^^^^
File "/Users/forcebru/test/my_RNN.py", line 72, in experiment
return fit(model, series, 100, optax.adam(0.01), loss)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/test/my_RNN.py", line 65, in fit
loss_val, model, opt_state = make_step(model, opt_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 206, in __call__
return self._call(False, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_module.py", line 935, in __call__
return self.__func__(self.__self__, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 200, in _call
out = self._cached(dynamic_donate, dynamic_nodonate, static)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 248, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 136, in _python_pjit_helper
infer_params_fn(*args, **kwargs)
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/api.py", line 325, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 495, in common_infer_params
jaxpr, consts, out_shardings, out_layouts_flat, attrs_tracked = _pjit_jaxpr(
^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 1150, in _pjit_jaxpr
jaxpr, final_consts, out_type, attrs_tracked = _create_pjit_jaxpr(
^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/linear_util.py", line 350, in memoized_fun
ans = call(fun, *args)
^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 1089, in _create_pjit_jaxpr
jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/profiler.py", line 336, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2314, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts, attrs_tracked = trace_to_subjaxpr_dynamic(
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/interpreters/partial_eval.py", line 2336, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/linear_util.py", line 192, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/equinox/_jit.py", line 49, in fun_wrapped
out = fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/test/my_RNN.py", line 59, in make_step
updates, opt_state = optimizer.update(grads, opt_state)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/combine.py", line 59, in update_fn
updates, new_s = fn(updates, s, params, **extra_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/base.py", line 337, in update
return tx.update(updates, state, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/optax/_src/transform.py", line 369, in update_fn
mu_hat = bias_correction(mu, b1, count_inc)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 248, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 136, in _python_pjit_helper
infer_params_fn(*args, **kwargs)
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/api.py", line 325, in infer_params
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/pjit.py", line 491, in common_infer_params
canonicalized_in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<string>", line 4, in __eq__
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py", line 745, in __bool__
check_bool_conversion(self)
File "/Users/forcebru/.pyenv/versions/3.12.1/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/site-packages/jax/_src/core.py", line 662, in check_bool_conversion
raise ValueError("The truth value of an array with more than one element is "
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
File "<string>", line 4, in __eq__
, which doesn't help.mu_hat = bias_correction(mu, b1, count_inc)
in Optax code, but as far as I understand, it doesn't compare any arrays.Is there a bug in my model definition (RNNCell
or RNN
)? Did I implement the training loop wrong? I basically copied it straight from Equinox docs, so it should be fine. Why does it work when I call experiment()
the first time, but not the second?
It appears this is a bug in equinox. The function _process_in_axis_resources is decorated with functools.lru_cache, which means its arguments are checked for equality against the arguments of previous calls. On the second run, that equality check ends up calling equinox.Module.__eq__, which raises the error. You can reproduce the problem by performing the equality check directly:
model = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
model2 = RNN([0.4, 0.6, 0.8], key=rnd.PRNGKey(8))
model == model2
# ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
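The root cause is that elementwise == on an array returns another array, and anything that then tries to collapse that result into a single Python bool (which is presumably what the generated __eq__ in the File "<string>" frame does when comparing the modules' fields) raises exactly this error. A minimal sketch of the same failure, independent of Equinox and Optax:

import jax.numpy as jnp

a = jnp.array([0.4, 0.6, 0.8])
b = jnp.array([0.4, 0.6, 0.8])
print(a == b)  # [ True  True  True] -- an array, not a single bool
bool(a == b)   # ValueError: The truth value of an array with more than one element is ambiguous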
I would suggest reporting this bug at https://github.com/patrick-kidger/equinox/issues
You could probably work around this issue by not storing a numpy array (vars0) as a static attribute. I suspect that equinox assumes that all static attributes are hashable, and numpy arrays are not.
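For what it's worth, arrays are indeed unhashable, which is presumably why equinox expects static fields to hold hashable values. A quick illustration (the values are arbitrary):

import jax.numpy as jnp

hash((0.4, 0.6, 0.8))             # fine: a tuple of floats is hashable
hash(jnp.array([0.4, 0.6, 0.8]))  # TypeError: unhashable type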
Edit: I just checked, and changing this:
vars0: np.ndarray = eqx.field(static=True)
to this:
vars0: np.ndarray
resolves the issue.
Edit 2: Indeed, it looks like static fields in equinox must be hashable, so this is not an equinox bug but rather a usage error (see the discussion at https://github.com/patrick-kidger/equinox/issues/154#issuecomment-1561735995). You might try storing vars0 as a tuple (which is hashable) rather than an array (which isn't).
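Here is a sketch of that approach, assuming the rest of the code from the question stays as-is (only RNN changes; RNNCell, the imports, condvar and predict are unchanged):

class RNN(eqx.Module):
    cell: RNNCell
    logits: np.ndarray
    vars0: tuple = eqx.field(static=True)

    def __init__(self, vars0, n_in=1, *, key):
        # Store the initial variances as a plain tuple of floats: hashable,
        # so the static metadata can be compared and cached without error.
        self.vars0 = tuple(float(v) for v in vars0)
        K = len(self.vars0)
        self.cell = RNNCell(K, n_in, key=key)
        self.logits = np.zeros(K)

    def __call__(self, series: np.ndarray):
        # Convert back to an array only where an array is actually needed.
        init = np.asarray(self.vars0)
        _, hist = jax.lax.scan(self.cell.__call__, init, series**2)
        return jax.nn.softmax(self.logits), abs(hist)

This keeps vars0 out of the trainable parameters (unlike simply dropping static=True, which turns it into a regular array leaf), while avoiding the unhashable static field.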