I installed JAX (pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html) and even for code as simple as
import jax.numpy as jnp
a = jnp.array([1, 2, 3])
a.dot(a)
I get the following error:
2023-09-08 10:12:55.791658: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:445] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-09-08 10:12:55.791696: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:449] Memory usage: 8058437632 bytes free, 8513978368 bytes total.
It looks like a memory issue. I tried the tips from
https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
but without success.
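Concretely, what I tried looks roughly like this (the environment variables documented on that page, set before importing jax):

import os

# disable XLA's up-front preallocation of most of the GPU memory
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# or alternatively, cap the preallocated fraction
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"

import jax.numpy as jnp

a = jnp.array([1, 2, 3])
print(a.dot(a))  # still fails with the same CUDNN error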
This is the result of nvidia-smi on my system:
+-------------------------------+----------------------+----------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
| N/A   64C    P0    37W /  N/A |    364MiB /  8119MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
It seems strange to me that there would be an out-of-memory issue when all I am creating is a single small array.
Any ideas?
I don't believe this is a memory issue; rather, it looks like you have a mismatch between your CUDA and CUDNN versions.
One way to ensure your CUDA versions are compatible is to use the pip-based installation (see JAX pip installation: GPU (CUDA, installed via pip, easier)); this installs mutually-compatible CUDA, CUDNN, and jaxlib versions on your system. Installing JAX using pip-installed CUDA looks something like this:
$ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
It looks like this may be the approach you used; if so, you should check that your system path (e.g. LD_LIBRARY_PATH) is not pre-empting the pip-installed CUDA with a local version. There is some relevant discussion at https://github.com/google/jax/issues/17497.
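As a quick sanity check (just a sketch; the exact layout of your environment may differ), you can compare LD_LIBRARY_PATH against where the pip-installed NVIDIA libraries live, which for the pip wheels is typically under site-packages/nvidia/*/lib:

import os, site, glob

# If LD_LIBRARY_PATH points at a local CUDA/CUDNN install, it can shadow the pip wheels
print("LD_LIBRARY_PATH:", os.environ.get("LD_LIBRARY_PATH", "<not set>"))

# pip-installed CUDA/CUDNN libraries usually land under site-packages/nvidia/<pkg>/lib
for sp in site.getsitepackages():
    for lib_dir in glob.glob(os.path.join(sp, "nvidia", "*", "lib")):
        print("pip-installed NVIDIA libs:", lib_dir)

If LD_LIBRARY_PATH contains a system CUDA/CUDNN path (your nvidia-smi output reports a CUDA 11.4 driver), that could be what is shadowing the pip-installed libraries.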
If you want to use a local CUDA installation, you can follow JAX pip installation: GPU (CUDA, installed locally, harder), but then it is up to you to ensure that your CUDA, CUDNN, and jaxlib versions are mutually compatible. Installing JAX using local CUDA looks something like this:
$ pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
but if you use this approach, be sure to read the details at the above link.
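Whichever route you take, after reinstalling it is worth sanity-checking that JAX actually sees the GPU and that your original repro runs cleanly; something like:

import jax
import jax.numpy as jnp

print(jax.__version__)
print(jax.devices())   # should list a CUDA device, not just CPU

a = jnp.array([1, 2, 3])
print(a.dot(a))        # should print 14 with no CUDNN errors

Recent JAX versions also provide jax.print_environment_info(), whose output is useful to include if you end up filing an issue.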