Tags: python, machine-learning, deep-learning, importerror, jax

Cannot import name 'linear_util' from 'jax'


I'm trying to reproduce the experiments of the S5 model (https://github.com/lindermanlab/S5), but I ran into issues while setting up the environment. When I run the shell script ./run_lra_cifar.sh, I get the following error:

Traceback (most recent call last):
  File "/Path/S5/run_train.py", line 3, in <module>
    from s5.train import train
  File "/Path/S5/s5/train.py", line 7, in <module>
    from .train_helpers import create_train_state, reduce_lr_on_plateau,\
  File "/Path/train_helpers.py", line 6, in <module>
    from flax.training import train_state
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/__init__.py", line 19, in <module>
    from . import core
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/__init__.py", line 15, in <module>
    from .axes_scan import broadcast
  File "/Path/miniconda3/lib/python3.12/site-packages/flax/core/axes_scan.py", line 22, in <module>
    from jax import linear_util as lu
ImportError: cannot import name 'linear_util' from 'jax' (/Path/miniconda3/lib/python3.12/site-packages/jax/__init__.py)

I'm running this on an RTX 4090 with CUDA 11.8. My jax version is 0.4.25 and my jaxlib version is 0.4.25+cuda11.cudnn86.

I first tried to install the dependencies using the authors' requirements file:

pip install -r requirements_gpu.txt

However, this didn't work in my case: I couldn't even import jax. So I installed jax following the instructions at https://jax.readthedocs.io/en/latest/installation.html:

pip install --upgrade pip
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

So far I've tried:

  1. Using an older GPU (3060 and 2070)
  2. Downgrading python to 3.9

Does anyone know what could be wrong? Any help is appreciated.


Solution

  • jax.linear_util was deprecated in JAX v0.4.16 and removed in JAX v0.4.24.

    The traceback shows that the linear_util import comes from flax (flax/core/axes_scan.py), which means you are using an older flax version together with a newer jax version.

    To fix the issue, either install an older JAX release that still has jax.linear_util (0.4.23 or earlier, with a matching jaxlib), or upgrade flax to a newer version that is compatible with recent JAX releases.
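To confirm the version mismatch before changing anything, you can check which jax, jaxlib, and flax versions are actually installed. The snippet below is a minimal sketch, not part of flax or JAX: the helper names (parse_release, has_linear_util, installed) are hypothetical, and the 0.4.24 cutoff is taken from the removal version stated in this answer. It only uses the standard library, so it runs even when importing jax itself fails.

```python
# Minimal diagnostic sketch (hypothetical helpers): report installed versions and
# whether the installed jax release still ships jax.linear_util.
from importlib.metadata import PackageNotFoundError, version

# Assumption: jax.linear_util was removed in JAX 0.4.24, as stated in the answer.
LINEAR_UTIL_REMOVED_IN = (0, 4, 24)


def parse_release(ver: str) -> tuple:
    """Turn '0.4.25' or '0.4.25+cuda11.cudnn86' into (0, 4, 25).

    The local suffix after '+' is dropped, and parsing stops at the first
    non-numeric character of a component (e.g. '24rc1' -> 24).
    """
    release = ver.split("+", 1)[0]
    parts = []
    for piece in release.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def has_linear_util(jax_version: str) -> bool:
    """True if this JAX release predates the removal of jax.linear_util."""
    return parse_release(jax_version) < LINEAR_UTIL_REMOVED_IN


def installed(pkg: str):
    """Installed version string of pkg, or None if it is not installed."""
    try:
        return version(pkg)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib", "flax"):
        print(f"{pkg}: {installed(pkg)}")
    jax_ver = installed("jax")
    if jax_ver is not None and not has_linear_util(jax_ver):
        print("This jax no longer has jax.linear_util; a flax release that "
              "still imports it will fail. Upgrade flax, or pin jax and "
              "jaxlib below 0.4.24.")
```

With the versions from the question (jax 0.4.25), has_linear_util returns False, which matches the ImportError in the traceback.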