
JAX produces memory error for simple program on GPU


I installed JAX with

pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

but even for simple code like

import jax.numpy as jnp

a = jnp.array([1, 2, 3])
a.dot(a)

I get the following error:

2023-09-08 10:12:55.791658: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:445] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-09-08 10:12:55.791696: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:449] Memory usage: 8058437632 bytes free, 8513978368 bytes total.

It looks like a memory issue. I tried the tips from the JAX GPU memory allocation docs,

https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html

but without success.
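The linked page documents environment variables that change how JAX preallocates GPU memory. A minimal sketch of applying them, assuming they are set in the same process before JAX is first imported (the 0.50 fraction is only an illustrative value):

```python
import os

# These must be set before JAX initializes its GPU backend, i.e. before `import jax`.

# Option A: disable preallocation so JAX allocates GPU memory on demand.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# Option B (alternative to A): keep preallocation but reserve a smaller fraction
# of GPU memory. The 0.50 value is an arbitrary example.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"

# Option C (alternative): allocate exactly what is needed and free it afterwards.
# This gives the smallest footprint but is very slow.
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

# Only now import JAX, e.g.:
#   import jax.numpy as jnp
```

The same variables can also be exported in the shell before launching Python.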

This is the result of nvidia-smi on my system:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.199.02   Driver Version: 470.199.02   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0 Off |                  N/A |
| N/A   64C    P0    37W /  N/A |    364MiB /  8119MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

It seems strange to me that there would be an out-of-memory issue when it is just a single three-element array.
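For scale, a small calculation using the figures from the cuDNN log lines quoted above, assuming JAX's default 32-bit mode (in which `jnp.array([1, 2, 3])` is stored as int32):

```python
# Figures copied from the cuDNN log lines quoted above.
free_bytes = 8_058_437_632
total_bytes = 8_513_978_368

# Assumption: JAX's default 32-bit mode, in which jnp.array([1, 2, 3]) has dtype
# int32, i.e. 3 elements x 4 bytes each.
array_bytes = 3 * 4

free_fraction = free_bytes / total_bytes
print(f"array data: {array_bytes} bytes")
print(f"GPU free: {free_bytes / 2**20:.1f} MiB of {total_bytes / 2**20:.1f} MiB ({free_fraction:.1%})")
```

With roughly 7.5 GiB reported free, a 12-byte array cannot plausibly exhaust GPU memory, which suggests the cuDNN error has some other cause.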

Any ideas?


Solution

  • I don't believe this is a memory issue; rather, it looks like you have a mismatch between your CUDA and cuDNN versions.

    One way to ensure your CUDA versions are compatible is to use the pip-based installation (see the section "pip installation: GPU (CUDA, installed via pip, easier)" of the JAX installation guide). This should install mutually compatible CUDA, cuDNN, and jaxlib versions on your system. Installing JAX with pip-installed CUDA looks something like this:

    $ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    

    It looks like this may be the approach you used (with the cuda11_pip variant); if so, check that your system library path (e.g. LD_LIBRARY_PATH) is not pre-empting the pip-installed CUDA with a local version. There is some relevant discussion at https://github.com/google/jax/issues/17497.

    If you want to use a local CUDA installation, you can follow the section "pip installation: GPU (CUDA, installed locally, harder)", but then it is up to you to ensure that your CUDA, cuDNN, and jaxlib versions are mutually compatible. Installing JAX against a local CUDA installation looks something like this:

    $ pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    

    but if you take this approach, be sure to read the details in that section of the installation guide.
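To check which CUDA libraries an environment is likely to pick up, a rough diagnostic sketch (not an official JAX tool) can list the pip-installed NVIDIA wheels alongside any CUDA or cuDNN directories on LD_LIBRARY_PATH. It assumes the pip-based install pulls in distributions whose names start with `nvidia-` (for example `nvidia-cudnn-cu11`); both helper functions are hypothetical names introduced here for illustration.

```python
import importlib.metadata
import os


def pip_nvidia_packages():
    """Hypothetical helper: installed distributions whose name starts with 'nvidia-'."""
    found = {}
    for dist in importlib.metadata.distributions():
        name = dist.metadata.get("Name") or ""
        if name.lower().startswith("nvidia-"):
            found[name] = dist.version
    return dict(sorted(found.items()))


def cuda_dirs_on_ld_library_path(env=None):
    """Hypothetical helper: LD_LIBRARY_PATH entries that mention CUDA or cuDNN."""
    env = os.environ if env is None else env
    entries = env.get("LD_LIBRARY_PATH", "").split(os.pathsep)
    return [e for e in entries if e and ("cuda" in e.lower() or "cudnn" in e.lower())]


if __name__ == "__main__":
    # Libraries shipped as pip wheels (the "installed via pip" route).
    for name, version in pip_nvidia_packages().items():
        print(f"pip: {name}=={version}")
    # Local CUDA directories that the dynamic loader may search first.
    for path in cuda_dirs_on_ld_library_path():
        print(f"LD_LIBRARY_PATH: {path}")
```

If both lists are non-empty, the local directories on LD_LIBRARY_PATH may be loaded instead of the pip-installed wheels, which is the kind of pre-emption described in the answer.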