
cuDNN error when running JAX on GPU with Apptainer


I have an application written in Python 3.10+ with JAX that I would like to run on GPU. I can run containers on my local compute cluster, which has NVIDIA A40 GPUs, using Apptainer (but not Docker). Based on the proposed Dockerfile for JAX, I built an Ubuntu-based image from the following Dockerfile:

FROM nvidia/cuda:11.8.0-devel-ubuntu22.04

RUN apt update && apt install python3-pip -y
RUN pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

I then convert the Docker image to an Apptainer image with apptainer pull docker://my-image and run the container with apptainer run --nv docker://my-image, as described in the Apptainer GPU docs.

Error

When I run the following code:

import jax
jax.numpy.array(1.)

JAX immediately crashes with the following error message:

2023-05-11 14:41:50.580441: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 1993, in array
    out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py", line 537, in _convert_element_type
    return convert_element_type_p.bind(operand, new_dtype=new_dtype,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 360, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 363, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 817, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 117, in apply_primitive
    compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/util.py", line 253, in wrapper
    return cached(config._trace_context(), *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/util.py", line 246, in cached
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
    compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
    return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
    self._executable = UnloadedMeshExecutable.from_hlo(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
    xla_executable = dispatch.compile_or_get_cached(
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
    return backend_compile(backend, serialized_computation, compile_options,
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 471, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

What I've tried

Based on a GitHub thread with a similar error message (https://github.com/google/jax/issues/4920), I tried adding the CUDA paths:

export PATH=/usr/local/cuda-11/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}

However, this did not resolve my problem.
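Exporting these variables only helps if the listed directories actually contain the libraries JAX needs at runtime. A minimal sketch for checking that: it lists, for every directory on LD_LIBRARY_PATH plus any extra directories passed in, which cuDNN and NVRTC shared objects it contains. The helper name find_cuda_libs and the glob patterns are assumptions for illustration, not part of any CUDA or JAX tooling.

```python
import os
from pathlib import Path


def find_cuda_libs(patterns=("libcudnn*.so*", "libnvrtc*.so*"), extra_dirs=()):
    """Map each search directory to the matching library files it contains.

    A value of None means the directory is listed but does not exist;
    an empty list means it exists but contains no matching libraries.
    """
    dirs = [d for d in os.environ.get("LD_LIBRARY_PATH", "").split(":") if d]
    dirs.extend(extra_dirs)
    found = {}
    for d in dirs:
        path = Path(d)
        if not path.is_dir():
            found[d] = None  # listed on the search path, but missing on disk
            continue
        found[d] = sorted(p.name for pat in patterns for p in path.glob(pat))
    return found


if __name__ == "__main__":
    for d, libs in find_cuda_libs(extra_dirs=["/usr/local/cuda-11/lib64"]).items():
        print(d, "MISSING" if libs is None else libs or "no cuDNN/NVRTC libraries")
```

Note that the dynamic loader reads LD_LIBRARY_PATH when a process starts, so the variable has to be exported before launching Python; changing os.environ inside an already running interpreter does not affect which libraries JAX can load.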

Local Docker container works without GPU

When I test the image built from the Dockerfile on my local machine without GPU, everything works fine:

$ docker run -ti my-image python3 -c 'import jax; jax.numpy.array(1.)'
$
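A silent exit here means JAX ran the computation, most likely on the CPU backend since no GPU is present locally. A short sketch that makes the choice explicit (the helper name describe_backend is hypothetical): it prints the installed jax/jaxlib versions together with the backend and devices JAX selected, so the same script can be run both locally and inside the Apptainer container for comparison.

```python
import importlib.metadata

import jax


def describe_backend():
    """Collect installed versions plus the backend and devices JAX selected."""
    return {
        "jax": importlib.metadata.version("jax"),
        "jaxlib": importlib.metadata.version("jaxlib"),
        # Platform name of the default backend, e.g. "cpu" or "gpu".
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for key, value in describe_backend().items():
        print(f"{key}: {value}")
```

Inside the container, setting the environment variable JAX_PLATFORMS=cpu (supported by recent JAX releases) restricts JAX to the CPU backend, which can help separate GPU library problems from problems in the application itself.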

Apptainer container detects the GPUs

I can confirm that the GPUs are detected inside the Apptainer container. Running nvidia-smi gives the following output:

$ apptainer run --nv docker://my-image nvidia-smi
INFO:    Using cached SIF image

==========
== CUDA ==
==========

CUDA Version 11.8.0

Container image Copyright (c) 2016-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license

A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.

Thu May 11 14:34:18 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A40          On   | 00000000:00:08.0 Off |                    0 |
|  0%   31C    P8    31W / 300W |      0MiB / 46068MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A40          On   | 00000000:00:09.0 Off |                    0 |
|  0%   31C    P8    31W / 300W |      0MiB / 46068MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  NVIDIA A40          On   | 00000000:00:0A.0 Off |                    0 |
|  0%   30C    P8    30W / 300W |      0MiB / 46068MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  NVIDIA A40          On   | 00000000:00:0B.0 Off |                    0 |
|  0%   31C    P8    31W / 300W |      0MiB / 46068MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  NVIDIA A40          On   | 00000000:00:0C.0 Off |                    0 |
|  0%   31C    P8    30W / 300W |      0MiB / 46068MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   5  NVIDIA A40          On   | 00000000:00:0D.0 Off |                    0 |
|  0%   32C    P8    31W / 300W |      0MiB / 46068MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   6  NVIDIA A40          On   | 00000000:00:0E.0 Off |                    0 |
|  0%   31C    P8    30W / 300W |      0MiB / 46068MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   7  NVIDIA A40          On   | 00000000:00:0F.0 Off |                    0 |
|  0%   30C    P8    31W / 300W |      0MiB / 46068MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
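nvidia-smi talks to the driver directly, so this table shows that --nv passes the GPUs through, but not that JAX's CUDA and cuDNN libraries work. A minimal per-device sketch, assuming a JAX version that provides jax.devices() and jax.device_put (the helper name smoke_test_devices is hypothetical): it compiles and runs a tiny reduction on every device JAX reports and records any exception instead of stopping at the first failure. Errors that abort the whole process, such as a failed library load inside cuDNN, cannot be caught this way.

```python
import jax
import jax.numpy as jnp
import numpy as np


def smoke_test_devices():
    """Compile and run a tiny reduction on every device JAX reports.

    Returns a dict mapping each device's description to None on success
    or to a string describing the failure.
    """
    results = {}
    for device in jax.devices():
        try:
            # Place the input explicitly on this device; jit then runs there.
            x = jax.device_put(np.arange(4, dtype=np.float32), device)
            # Compilation happens on this call, which is where the
            # "DNN library initialization failed" error above was raised.
            value = float(jax.jit(jnp.sum)(x))
            results[str(device)] = None if value == 6.0 else f"unexpected sum {value}"
        except Exception as exc:  # e.g. jaxlib's XlaRuntimeError
            results[str(device)] = repr(exc)
    return results


if __name__ == "__main__":
    for name, error in smoke_test_devices().items():
        print(name, "OK" if error is None else error)
```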

Edit

CUDA 11 image with cuDNN 8.7 gives a different error

When I use a different base image that ships cuDNN 8.7.0:

FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

RUN apt update && apt install python3-pip -y
RUN pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

I get a different error:

$ apptainer run --nv docker://my-image python3 -c 'import jax; jax.numpy.array(1.)'

Could not load library libcudnn_ops_infer.so.8. Error: libnvrtc.so: cannot open shared object file: No such file or directory
Aborted (core dumped)
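This message comes from the dynamic loader: cuDNN tried to open libnvrtc.so and could not find it on the library search path. One way to check which libraries are resolvable inside the container, without involving JAX, is to load them from Python with ctypes, which goes through the same loader search (LD_LIBRARY_PATH, the ldconfig cache, default directories). A sketch with a hypothetical helper check_dlopen; libcudnn_ops_infer.so.8 and libnvrtc.so come from the error message, while libcudnn.so.8 is an assumption for cuDNN 8, and other CUDA/cuDNN versions will use different names.

```python
import ctypes

# Library names to probe; adjust for other CUDA/cuDNN versions.
LIBS = ("libcudnn.so.8", "libcudnn_ops_infer.so.8", "libnvrtc.so")


def check_dlopen(names=LIBS):
    """Return {name: None if the loader can open it, else the loader's error text}."""
    status = {}
    for name in names:
        try:
            ctypes.CDLL(name)  # calls dlopen() with the standard search rules
            status[name] = None
        except OSError as exc:
            status[name] = str(exc)
    return status


if __name__ == "__main__":
    for name, error in check_dlopen().items():
        print(f"{name}: {'OK' if error is None else error}")
```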

Solution

  • Based on pointers from jakevdp, I managed to find a solution. Two changes were needed:

    • The CUDA 11 image with cuDNN 8.7.
    • The devel image instead of the runtime image (presumably because it provides the libnvrtc.so that the runtime image was missing).

    With both changes, I could successfully run JAX under Apptainer using the following Dockerfile:

    FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
    
    RUN apt update && apt install python3-pip -y
    RUN pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
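jnp.array(1.) compiles only a trivial computation. To exercise cuDNN itself in the final image, a small convolution is a reasonable smoke test, on the assumption that XLA dispatches GPU convolutions to cuDNN (it usually does, but it may choose other implementations). A sketch with a hypothetical helper conv_smoke_test that compares jax.lax.conv_general_dilated against a NumPy reference:

```python
import jax
import numpy as np


def conv_smoke_test():
    """Run a small 2D convolution with JAX and compare it to a NumPy reference."""
    # Input in NCHW layout: batch 1, 1 channel, 4x4 image with values 0..15.
    x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
    # Kernel in OIHW layout: 1 output channel, 1 input channel, 3x3 of ones.
    k = np.ones((1, 1, 3, 3), dtype=np.float32)
    # With default dimension numbers, conv_general_dilated uses NCHW/OIHW/NCHW.
    y = jax.lax.conv_general_dilated(x, k, window_strides=(1, 1), padding="VALID")
    # An all-ones 3x3 kernel with "VALID" padding sums each 3x3 window.
    expected = np.array(
        [[x[0, 0, i:i + 3, j:j + 3].sum() for j in range(2)] for i in range(2)]
    )
    return bool(np.allclose(np.asarray(y)[0, 0], expected))


if __name__ == "__main__":
    print("backend:", jax.default_backend())
    print("convolution OK" if conv_smoke_test() else "convolution result mismatch")
```

Running it with apptainer run --nv inside the final image should report the GPU backend and a matching result if cuDNN initializes correctly.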