Tags: python, jax, haiku, acme-deepmind

acme error - AttributeError: module 'jax' has no attribute 'linear_util'


I am using the Acme framework to run some experiments, and I installed Acme following its documentation. However, I get an AttributeError that seems to come from JAX or Haiku, and the related GitHub issue I found has no solution yet. Can anyone take a look at which package dependency is causing this?

Here is my venv spec:

dm-acme                      0.4.0
dm-control                   0.0.364896371
dm-env                       1.6
dm-haiku                     0.0.10
dm-launchpad                 0.5.0
dm-reverb                    0.7.0
dm-tree                      0.1.8
acme                         2.10.0
jax                          0.4.26
jaxlib                       0.4.26+cuda12.cudnn89
python -V                    Python 3.9.5
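A listing like this can be regenerated from inside the venv with the standard library alone, which makes it easy to compare environments. A minimal sketch, assuming Python 3.8+ (for `importlib.metadata`); the helper name `installed_versions` is hypothetical, and the distribution names are the ones from the list above:

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Distribution names taken from the venv listing above.
PACKAGES = ("dm-acme", "dm-haiku", "dm-launchpad", "dm-reverb", "dm-tree", "jax", "jaxlib")


def installed_versions(names: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    report: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name:<20} {ver or 'not installed'}")
```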

error details:

  File "/data/acme/examples/baselines/rl_discrete/run_dqn.py", line 18, in <module>
    from acme.agents.jax import dqn
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/__init__.py", line 18, in <module>
    from acme.agents.jax.dqn.actor import behavior_policy
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/actor.py", line 20, in <module>
    from acme.agents.jax import actor_core as actor_core_lib
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/actor_core.py", line 22, in <module>
    from acme.jax import networks as networks_lib
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/__init__.py", line 18, in <module>
    from acme.jax.networks.atari import AtariTorso
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/atari.py", line 29, in <module>
    from acme.jax.networks import base
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/base.py", line 24, in <module>
    import haiku as hk
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/haiku/__init__.py", line 20, in <module>
    from haiku import experimental
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/haiku/experimental/__init__.py", line 34, in <module>
    from haiku._src.dot import abstract_to_dot
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/haiku/_src/dot.py", line 163, in <module>
    @jax.linear_util.transformation
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/jax/_src/deprecations.py", line 54, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'linear_util'

It seems to be raised from Haiku and JAX. How can this be fixed? Any quick thoughts?
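To check that the failure is the attribute lookup on the top-level `jax` module itself, independently of Acme and Haiku, the lookup can be probed in isolation. A minimal sketch; the helper `jax_exposes` is hypothetical, and it only assumes that JAX may or may not be installed in the current interpreter:

```python
import importlib
import importlib.util
from typing import Optional


def jax_exposes(attr: str) -> Optional[bool]:
    """Report whether the installed top-level `jax` module provides `attr`.

    Returns None when JAX is not installed in the current environment.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    jax = importlib.import_module("jax")
    # hasattr() swallows the AttributeError raised by JAX's deprecation hook,
    # so a removed name simply reports False instead of a traceback.
    return hasattr(jax, attr)


if __name__ == "__main__":
    print("jax.linear_util available:", jax_exposes("linear_util"))
```

Run inside the venv above, this should report False if the installed JAX no longer provides `linear_util`, which would point at the JAX version rather than at Acme.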

Updated attempt

Based on @jakevdp's suggestion, I reinstalled jax and jaxlib, but now I am getting a different error:

Traceback (most recent call last):
  File "/data/acme/examples/baselines/rl_discrete/run_dqn.py", line 18, in <module>
    from acme.agents.jax import dqn
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/__init__.py", line 18, in <module>
    from acme.agents.jax.dqn.actor import behavior_policy
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/dqn/actor.py", line 20, in <module>
    from acme.agents.jax import actor_core as actor_core_lib
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/agents/jax/actor_core.py", line 22, in <module>
    from acme.jax import networks as networks_lib
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/__init__.py", line 45, in <module>
    from acme.jax.networks.multiplexers import CriticMultiplexer
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/networks/multiplexers.py", line 20, in <module>
    from acme.jax import utils
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/acme/jax/utils.py", line 190, in <module>
    devices: Optional[Sequence[jax.xla.Device]] = None,
  File "/data/acme/acme_venv_new/lib/python3.9/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'xla'

Here is my pip freeze output in a public gist: acme pip list

I looked into this GitHub issue: jax xla attribute issue

@jakevdp, do you have any further comments or a possible workaround for this jax.xla issue? Thanks.


Solution

  • jax.linear_util was deprecated in JAX v0.4.16 and removed in JAX v0.4.24.

    It sounds like your JAX version is too new for the framework code you are using. I'd try installing an older version, e.g.:

    pip install --upgrade "jax[cuda12_pip]<0.4.24" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    

    See JAX installation for more installation options.

    If you're hoping to update the framework code for compatibility with more recent JAX versions, you might find replacements for previous functionality in jax.extend.linear_util.
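If you patch the framework code, one way to keep it importable on JAX releases from both before and after that removal is a small compatibility import. The sketch below is hypothetical: the helper names `parse_release`, `has_jax_linear_util` and `load_linear_util` are illustrative and not part of JAX, Haiku or Acme. It relies only on the version numbers given above and on `jax.extend.linear_util`, falling back to the old `jax.linear_util` location on older releases. Replacing `jax.linear_util.transformation` with `lu.transformation` assumes the installed release still exports that name from `jax.extend.linear_util`.

```python
import importlib
import re
from types import ModuleType
from typing import Optional, Tuple

# JAX release that removed jax.linear_util (per the answer above).
LINEAR_UTIL_REMOVED_IN: Tuple[int, int, int] = (0, 4, 24)


def parse_release(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch), e.g. '0.4.26+cuda12.cudnn89' -> (0, 4, 26)."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    major, minor, patch = (int(part) for part in match.groups())
    return major, minor, patch


def has_jax_linear_util(jax_version: str) -> bool:
    """True if this JAX release predates the removal of jax.linear_util."""
    return parse_release(jax_version) < LINEAR_UTIL_REMOVED_IN


def load_linear_util() -> Optional[ModuleType]:
    """Import linear_util from its new location, falling back to the old one.

    Returns None if neither module can be imported (e.g. JAX is not installed).
    """
    for module_name in ("jax.extend.linear_util", "jax.linear_util"):
        try:
            return importlib.import_module(module_name)
        except ImportError:
            continue
    return None


if __name__ == "__main__":
    lu = load_linear_util()
    print("linear_util module:", lu.__name__ if lu is not None else "not found")
```

For example, a launcher script could call `has_jax_linear_util(jax.__version__)` before importing Haiku and print a clear message instead of failing deep inside the import chain.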