
Installing jaxlib for cuda 11.8


I'm trying to install jax and jaxlib on Ubuntu 18 with Python 3.8 for SNeRG (https://github.com/google-research/google-research/tree/master/snerg). Unfortunately, when I try to install jax and jaxlib for CUDA 11.8 with the following command:

pip install --upgrade jax jaxlib==0.1.69+cuda118 -f https://storage.googleapis.com/jax-releases/jax_releases.html 

I get the following error:

ERROR: Ignored the following versions that require a different python version: 0.4.14 Requires-Python >=3.9
ERROR: Could not find a version that satisfies the requirement jaxlib==0.1.69+cuda118 (from versions: 0.1.32, 0.1.40, 0.1.41, 0.1.42, 0.1.43, 0.1.44, 0.1.46, 0.1.50, 0.1.51, 0.1.52, 0.1.55, 0.1.56, 0.1.57, 0.1.58, 0.1.59, 0.1.60, 0.1.61, 0.1.62, 0.1.63, 0.1.64, 0.1.65, 0.1.66, 0.1.67, 0.1.68, 0.1.69, 0.1.70, 0.1.71, 0.1.72, 0.1.73, 0.1.74, 0.1.75, 0.1.76, 0.3.0, 0.3.2, 0.3.5, 0.3.7, 0.3.8, 0.3.10, 0.3.14, 0.3.15, 0.3.18, 0.3.20, 0.3.22, 0.3.24, 0.3.25, 0.4.0, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.6, 0.4.7, 0.4.9, 0.4.10, 0.4.11, 0.4.12, 0.4.13)
ERROR: No matching distribution found for jaxlib==0.1.69+cuda118

Would appreciate any help. Thanks


Solution

  • jaxlib version 0.1.69 is quite old: it was released in July 2021, and CUDA 11.8 was released over a year later, in September 2022. So I would not expect any pre-built jaxlib 0.1.69 binaries targeting CUDA 11.8 to exist.

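    If possible, your best bet is to install a newer version of jaxlib that has builds targeting CUDA 11.8. Keep in mind that on Python 3.8 the newest jaxlib pip can offer you is 0.4.13, because 0.4.14 requires Python >= 3.9 (that is what the first line of your error message says). The current jaxlib + CUDA GPU installation instructions can be found here.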

    If for some reason you absolutely need this old jaxlib version, you'll probably have to install an older CUDA version on your system first. The CUDA installation instructions for jaxlib 0.1.69 can be found here; it looks like that release was built to target CUDA 10.1-10.2, 11.0, or 11.1-11.3.
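To check which of these two cases applies on your machine, you can read the local CUDA version from nvcc and compare it with the ranges listed above. The script below is a minimal sketch, not part of jax: it assumes nvcc is on your PATH and prints a line like "Cuda compilation tools, release 11.8, V11.8.89", and it hard-codes only the CUDA ranges mentioned in this answer for jaxlib 0.1.69. The helper names (parse_nvcc_version, supports_jaxlib_0_1_69, local_cuda_version) are hypothetical. It sticks to typing.Tuple so it still runs on Python 3.8.

```python
import re
import subprocess
import sys
from typing import Optional, Tuple

# CUDA ranges this answer lists for the jaxlib 0.1.69 GPU builds, as
# inclusive (major, minor) bounds. Nothing newer (e.g. 11.8) is covered.
JAXLIB_0_1_69_CUDA_RANGES = [
    ((10, 1), (10, 2)),
    ((11, 0), (11, 0)),
    ((11, 1), (11, 3)),
]


def parse_nvcc_version(output: str) -> Tuple[int, int]:
    """Extract (major, minor) from text such as 'release 11.8, V11.8.89'."""
    match = re.search(r"release (\d+)\.(\d+)", output)
    if match is None:
        raise ValueError("no 'release X.Y' found in nvcc output")
    return int(match.group(1)), int(match.group(2))


def supports_jaxlib_0_1_69(cuda: Tuple[int, int]) -> bool:
    """True if the CUDA version falls inside one of the ranges above."""
    # Tuples compare element-wise, so (11, 2) lies between (11, 1) and (11, 3).
    return any(low <= cuda <= high for low, high in JAXLIB_0_1_69_CUDA_RANGES)


def local_cuda_version() -> Optional[Tuple[int, int]]:
    """Run 'nvcc --version'; return None if it is missing, fails, or is unparsable."""
    try:
        result = subprocess.run(
            ["nvcc", "--version"], capture_output=True, text=True, check=True
        )
        return parse_nvcc_version(result.stdout)
    except (FileNotFoundError, subprocess.CalledProcessError, ValueError):
        return None


if __name__ == "__main__":
    print("Python", ".".join(map(str, sys.version_info[:3])))
    if sys.version_info[:2] < (3, 9):
        # Matches the pip error in the question: jaxlib 0.4.14 needs Python >= 3.9.
        print("Newest jaxlib pip will offer on this Python: 0.4.13")
    cuda = local_cuda_version()
    if cuda is None:
        print("Could not determine the CUDA version from nvcc")
    elif supports_jaxlib_0_1_69(cuda):
        print(f"CUDA {cuda[0]}.{cuda[1]} is in a range jaxlib 0.1.69 targeted")
    else:
        print(f"CUDA {cuda[0]}.{cuda[1]}: no jaxlib 0.1.69 build, use a newer jaxlib")
```

Run it with python3 on the target machine. Note that nvcc reports the installed CUDA toolkit version, while nvidia-smi shows the highest CUDA version the driver supports; the two can differ, and for these older jaxlib builds it is, as far as I know, the locally installed toolkit that matters.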