I need jaxlib==0.1.67 for a project I'm working on, but I can't downgrade. At the moment I have jaxlib==0.1.75, and my program keeps failing with an error I can't find a solution to either. I compared the versions of all the important packages with those on another machine where my program runs without problems, and the only difference is the jaxlib version (it's still 0.1.67 on the machine where it works). I suspect jaxlib is the issue because, whenever it's not 0.1.67, I get the following error:
from haiku import data_structures
File "/net/home/justen/.local/lib/python3.10/site-packages/haiku/data_structures.py", line 17, in <module>
from haiku._src.data_structures import to_immutable_dict
File "/net/home/justen/.local/lib/python3.10/site-packages/haiku/_src/data_structures.py", line 30, in <module>
from haiku._src import utils
File "/net/home/justen/.local/lib/python3.10/site-packages/haiku/_src/utils.py", line 24, in <module>
import jax
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/__init__.py", line 108, in <module>
from .experimental.maps import soft_pmap
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/experimental/maps.py", line 25, in <module>
from .. import numpy as jnp
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/numpy/__init__.py", line 16, in <module>
from . import fft
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/numpy/fft.py", line 17, in <module>
from jax._src.numpy.fft import (
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/numpy/fft.py", line 19, in <module>
from jax import lax
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/lax/__init__.py", line 334, in <module>
from jax._src.lax.parallel import (
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/lax/parallel.py", line 36, in <module>
from jax._src.numpy import lax_numpy
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 51, in <module>
from jax import ops
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/ops/__init__.py", line 16, in <module>
from jax._src.ops.scatter import (
File "/net/home/justen/.local/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 31, in <module>
from typing import EllipsisType
ImportError: cannot import name 'EllipsisType' from 'typing' (/usr/lib/python3.10/typing.py)
haiku and typing are the same versions on both machines, so I guess it must be jaxlib. On both machines I'm using pip==20.0.2 in a Python 3.9.9 virtualenv.
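When comparing environments across machines, it can help to print which interpreter is actually running the program along with the package versions it sees, since the interpreter that executes a script is not always the one a virtualenv suggests. A minimal sketch, assuming Python 3.8+ (for importlib.metadata) and the PyPI distribution names jax, jaxlib and dm-haiku (haiku is published as dm-haiku); adjust the list as needed:

```python
import sys
from importlib import metadata

# PyPI distribution names (assumed); haiku is distributed as "dm-haiku".
PACKAGES = ["jax", "jaxlib", "dm-haiku"]


def environment_report(packages=PACKAGES):
    """Return the running interpreter and the installed versions of `packages`."""
    report = {
        "python": sys.version.split()[0],  # e.g. "3.9.9"
        "executable": sys.executable,      # path of the interpreter actually running
    }
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed for *this* interpreter.
            report[name] = None
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

Running the same script on both machines, with the same command used to launch the failing program, gives a like-for-like comparison that includes the interpreter itself.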
When I try to downgrade to jaxlib==0.1.67, I get:
ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.67 (from versions: 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.10, 0.3.14, 0.3.15)
ERROR: No matching distribution found for jaxlib==0.1.67
I even tried
pip install jaxlib==0.1.67 -f https://storage.googleapis.com/jax-releases/jax_releases.html
but it doesn't work either.
Has anyone run into the same problem, or does anyone have a clue what the issue could be?
Based on the path in the exception (/usr/lib/python3.10), it looks like you are using Python 3.10. There are no Python 3.10 wheels for jaxlib==0.1.67 (see PyPI), so you will have to use Python 3.6-3.9.
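To confirm this from inside the failing program, you can compare the running interpreter's version against that range. A minimal sketch; the set of supported versions is hard-coded from this answer (3.6-3.9) rather than queried from PyPI, and the helper names are hypothetical:

```python
import sys

# Python versions with jaxlib==0.1.67 wheels, per this answer (3.6-3.9).
SUPPORTED = {(3, 6), (3, 7), (3, 8), (3, 9)}


def cpython_tag(version_info=sys.version_info):
    """Return the CPython wheel tag, e.g. 'cp310' for Python 3.10."""
    return f"cp{version_info[0]}{version_info[1]}"


def can_install_jaxlib_0_1_67(version_info=sys.version_info):
    """True if this (major, minor) version is in the supported set."""
    return (version_info[0], version_info[1]) in SUPPORTED


if __name__ == "__main__":
    tag = cpython_tag()
    if can_install_jaxlib_0_1_67():
        print(f"{tag}: jaxlib==0.1.67 wheels should be available")
    else:
        print(f"{tag}: no jaxlib==0.1.67 wheel; use Python 3.6-3.9")
```

Under /usr/lib/python3.10 this reports cp310, which is why pip only offers newer jaxlib releases.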
If you think you are using Python 3.9, here's a way to avoid confusion when installing packages: use
python3.9 -m pip install
to install packages into your Python 3.9 environment. Replace python3.9 with whichever Python interpreter you want to use.
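The same idea works from inside a script: invoking pip through sys.executable guarantees that packages are installed for the interpreter that is running the code. A minimal sketch; pip_install_command and pip_install are hypothetical helper names, and the first one only builds the command without running it:

```python
import subprocess
import sys


def pip_install_command(*requirements, find_links=None):
    """Build a pip command bound to the interpreter running this script."""
    cmd = [sys.executable, "-m", "pip", "install", *requirements]
    if find_links:
        # Same as passing pip's -f / --find-links option on the command line.
        cmd += ["-f", find_links]
    return cmd


def pip_install(*requirements, find_links=None):
    """Run pip for this interpreter; raises CalledProcessError on failure."""
    subprocess.run(pip_install_command(*requirements, find_links=find_links), check=True)


if __name__ == "__main__":
    # Only prints the command; call pip_install(...) to actually run it.
    print(pip_install_command(
        "jaxlib==0.1.67",
        find_links="https://storage.googleapis.com/jax-releases/jax_releases.html",
    ))
```

If the printed path points at a Python 3.10 interpreter, the install will still fail for the reason given above, regardless of the -f URL.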