I'm trying to reproduce the experiments of the S5 model, https://github.com/lindermanlab/S5, but I ran into some issues while setting up the environment. When I run the shell script ./run_lra_cifar.sh, I get the following error:
Traceback (most recent call last):
  File "/Path/S5/run_train.py", line 3, in <module>
    from s5.train import train
  File "/Path/S5/s5/train.py", line 7, in <module>
    from .train_helpers import create_train_state, reduce_lr_on_plateau,\
  File "/Path/train_helpers.py", line 6, in <module>
    from flax.training import train_state
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/__init__.py", line 19, in <module>
    from . import core
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/axes_scan.py", line 22, in <module>
    from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax' (/Path/miniconda3/lib/python3.12/site-packages/jax/__init__.py)
I'm running this on an RTX 4090 with CUDA 11.8. My jax version is 0.4.25 and my jaxlib version is 0.4.25+cuda11.cudnn86.
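Since the error comes down to which jax and flax versions are installed together, it can help to print all three versions from the same interpreter that runs run_train.py. A minimal sketch using only the standard library's importlib.metadata; the function name installed_versions is hypothetical, and a missing package is reported as None instead of raising:

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "flax")):
    """Hypothetical helper: map each package name to its installed version, or None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed in the interpreter running this script.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running this with the same python that the shell script uses avoids confusing a conda environment with the base miniconda3 install.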
I first tried to install the dependencies using the author's requirements file:
pip install -r requirements_gpu.txt
However, this doesn't seem to work in my case, since I can't even import jax. So I installed jax according to the instructions at https://jax.readthedocs.io/en/latest/installation.html by typing:
pip install --upgrade pip
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
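To confirm that this install produced an importable jax that actually sees the GPU, a small check can call jax.default_backend(), which reports "gpu" when a CUDA device is in use and "cpu" otherwise. This is a sketch; the function name jax_backend is hypothetical, and it returns None when jax itself fails to import:

```python
def jax_backend():
    """Hypothetical helper: return jax's default backend name, or None if jax can't be imported."""
    try:
        import jax
    except ImportError:
        # Covers both "jax not installed" and a jax/jaxlib install that can't be imported.
        return None
    return jax.default_backend()


if __name__ == "__main__":
    print("jax backend:", jax_backend())
```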
So far I've tried:
Does anyone know what could be wrong? Any help is appreciated.
jax.linear_util was deprecated in JAX v0.4.16 and removed in JAX v0.4.24.
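Those two version boundaries can be turned into a quick check of whether a given JAX version still provides jax.linear_util. A minimal sketch, assuming version strings start with MAJOR.MINOR.PATCH (so local build tags like the "+cuda11.cudnn86" on jaxlib are ignored); the function names are hypothetical:

```python
import re

DEPRECATED_IN = (0, 4, 16)  # jax.linear_util deprecated (still importable, with a warning)
REMOVED_IN = (0, 4, 24)     # jax.linear_util removed


def release_tuple(version):
    """Extract (major, minor, patch) from strings like '0.4.25+cuda11.cudnn86'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(group) for group in match.groups())


def linear_util_status(jax_version):
    """Return 'available', 'deprecated', or 'removed' for jax.linear_util."""
    release = release_tuple(jax_version)
    if release >= REMOVED_IN:
        return "removed"
    if release >= DEPRECATED_IN:
        return "deprecated"
    return "available"


print(linear_util_status("0.4.25"))  # removed
```

For the setup in the question (jax 0.4.25), this reports "removed", which matches the ImportError.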
It appears that flax is the source of the linear_util import, meaning that you are using an older flax version with a newer jax version.
To fix your issue, you'll either need to install an older version of JAX which still has jax.linear_util, or update to a newer version of flax which is compatible with more recent JAX versions.