I have followed the objax documentation to install the library with GPU support: https://objax.readthedocs.io/en/stable/installation_setup.html
i.e.
pip install --upgrade objax
CUDA_VERSION=11.6
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`
However, the last step doesn't work. I get the following error message:
ERROR: Could not find a version that satisfies the requirement jaxlib==0.3.15+cuda116 (from versions: 0.1.32, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.46, 0.1.50, 0.1.51, 0.1.52, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.8, 0.3.10, 0.3.14, 0.3.15)
ERROR: No matching distribution found for jaxlib==0.3.15+cuda116
I have tried multiple versions of Python and CUDA, but I always get this error.
Running pip install --upgrade pip beforehand does not help.
JAX recently updated its GPU installation instructions, which you can find here: https://github.com/google/jax#pip-installation-gpu-cuda
In particular, the CUDA wheels are now located at https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
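To see why the old command failed, here is a minimal sketch that mirrors what the shell one-liner in the question computes. The function name `cuda_requirement` is hypothetical, the version list is shortened to the tail of the one in the error message, and plain string membership is only a rough stand-in for pip's version matching:

```python
def cuda_requirement(jaxlib_version: str, cuda_version: str) -> str:
    """Mirror the shell one-liner: drop the dots from the CUDA version
    and append it to the jaxlib version as a '+cuda...' suffix."""
    return f"jaxlib=={jaxlib_version}+cuda{cuda_version.replace('.', '')}"


# Tail of the "from versions" list in the error message (shortened here).
available = ["0.3.10", "0.3.14", "0.3.15"]

requirement = cuda_requirement("0.3.15", "11.6")
pinned = requirement.split("==", 1)[1]

print(requirement)          # jaxlib==0.3.15+cuda116
print(pinned in available)  # False: no listed version carries a +cuda suffix
```

None of the versions pip found carries a `+cuda` suffix, so the exact pin cannot be satisfied no matter which Python or CUDA version is used.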
So, for example, you can install JAX with
$ pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
and replace cuda11 and cudnn805 with the appropriate CUDA and cuDNN versions for your system, making sure they match the versions listed in the index at the above URL.
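After installing, one way to check whether JAX picked up the GPU is to ask it which backend and devices it sees. This is a small sketch, not part of the official instructions; the helper name `describe_jax_backend` is hypothetical, and it assumes only the public `jax.devices()` and `jax.default_backend()` calls:

```python
def describe_jax_backend() -> str:
    """Report which backend and devices JAX sees, or say that jax is missing."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    # default_backend() names the active platform (for example "cpu" or "gpu");
    # devices() lists the devices available on that platform.
    return f"backend={jax.default_backend()}, devices={jax.devices()}"


print(describe_jax_backend())
```

If the reported backend is still the CPU, the installed jaxlib is most likely a CPU-only build, or the CUDA/cuDNN versions in the chosen extra do not match the system.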
I've sent a pull request to the objax repository to update the instructions you were following: https://github.com/google/objax/pull/246