
ERROR: No matching distribution found for jaxlib==0.1.67


I need jaxlib==0.1.67 for a project I'm working on, but I can't downgrade. At the moment I have jaxlib==0.1.75, and my program keeps failing with an error I can't find a solution to either. I compared the versions of all the important packages against another machine where my program runs with no problems, and the only difference is the jaxlib version (it's still 0.1.67 on the machine where it runs). I suspect that jaxlib is the issue because the error I get when it's not 0.1.67 is the following:

    from haiku import data_structures
  File "/net/home/justen/.local/lib/python3.10/site-packages/haiku/data_structures.py", line 17, in <module>
    from haiku._src.data_structures import to_immutable_dict
  File "/net/home/justen/.local/lib/python3.10/site-packages/haiku/_src/data_structures.py", line 30, in <module>
    from haiku._src import utils
  File "/net/home/justen/.local/lib/python3.10/site-packages/haiku/_src/utils.py", line 24, in <module>
    import jax
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/__init__.py", line 108, in <module>
    from .experimental.maps import soft_pmap
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/experimental/maps.py", line 25, in <module>
    from .. import numpy as jnp
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/numpy/__init__.py", line 16, in <module>
    from . import fft
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/numpy/fft.py", line 17, in <module>
    from jax._src.numpy.fft import (
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/numpy/fft.py", line 19, in <module>
    from jax import lax
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/lax/__init__.py", line 334, in <module>
    from jax._src.lax.parallel import (
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/lax/parallel.py", line 36, in <module>
    from jax._src.numpy import lax_numpy
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 51, in <module>
    from jax import ops
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/ops/__init__.py", line 16, in <module>
    from jax._src.ops.scatter import (
  File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 31, in <module>
    from typing import EllipsisType
ImportError: cannot import name 'EllipsisType' from 'typing' (/usr/lib/python3.10/typing.py)

haiku and typing are the same version on both machines, so I guess it must be jaxlib. On both machines I'm on pip==20.0.2 and in a Python 3.9.9 virtualenv.

When I try to downgrade to jaxlib==0.1.67 I get:

ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.67 (from versions: 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.10, 0.3.14, 0.3.15)
ERROR: No matching distribution found for jaxlib==0.1.67

I even tried pip install jaxlib==0.1.67 -f https://storage.googleapis.com/jax-releases/jax_releases.html, and it doesn't work either.

Has anyone experienced the same problem, or does anyone have a clue what the issue could be?


Solution

  • Based on the path in the exception (/usr/lib/python3.10), it looks like you are using Python 3.10. There are no Python 3.10 wheels for jaxlib==0.1.67 (see PyPI), which is why pip reports no matching distribution. You will have to use Python 3.6-3.9.

    If you think you are using python 3.9, then here's a way to clear up confusion when installing packages. Use

    python3.9 -m pip install
    

    to install packages into your Python 3.9 environment. Replace python3.9 with whichever Python interpreter you want to use.
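
To confirm which interpreter is actually running your program, a minimal sketch like the following may help. Run it with the same command you use to start your script. The function names are hypothetical, chosen here for illustration, and the 3.6-3.9 range is simply the one stated above, not something the code looks up on PyPI.

```python
import sys


def interpreter_report():
    """Return the running interpreter's path and its 'major.minor' version."""
    version = f"{sys.version_info.major}.{sys.version_info.minor}"
    return sys.executable, version


def supports_jaxlib_0_1_67(version_info=sys.version_info):
    """Check the version against the range given in the answer (3.6-3.9).

    Assumption: the range is hard-coded from the answer above; this does
    not query PyPI for the wheels that are actually published.
    """
    return (3, 6) <= tuple(version_info[:2]) <= (3, 9)


if __name__ == "__main__":
    exe, version = interpreter_report()
    print(f"Interpreter: {exe}")
    print(f"Python version: {version}")
    if supports_jaxlib_0_1_67():
        print("This version is within the 3.6-3.9 range.")
    else:
        print("This version is outside the 3.6-3.9 range.")
```

If the reported path or version is not the 3.9 virtualenv you expected, the program is being started with a different interpreter than the one you installed packages into.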