I have an application written in Python 3.10+ with JAX that I would like to run on GPU. I can run containers on my local compute cluster, which has NVIDIA A40 GPUs, using apptainer (but not Docker). Based on the proposed Dockerfile for JAX, I made an Ubuntu-based image from the following Dockerfile:
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
RUN apt update && apt install python3-pip -y
RUN pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I then convert the Docker image to an apptainer image using
apptainer pull docker://my-image
and run the container with GPU support using
apptainer run --nv docker://my-image
as described in the apptainer GPU docs.
When I run the following code
import jax
jax.numpy.array(1.)
JAX immediately crashes with the following error message:
2023-05-11 14:41:50.580441: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:429] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py", line 1993, in array
out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/lax/lax.py", line 537, in _convert_element_type
return convert_element_type_p.bind(operand, new_dtype=new_dtype,
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 360, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 363, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/core.py", line 817, in process_primitive
return primitive.impl(*tracers, **params)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 117, in apply_primitive
compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args),
File "/usr/local/lib/python3.10/dist-packages/jax/_src/util.py", line 253, in wrapper
return cached(config._trace_context(), *args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/util.py", line 246, in cached
return f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 208, in xla_primitive_callable
compiled = _xla_callable_uncached(lu.wrap_init(prim_fun), prim.name,
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 254, in _xla_callable_uncached
return computation.compile(_allow_propagation_to_outputs=allow_prop).unsafe_call
File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 2816, in compile
self._executable = UnloadedMeshExecutable.from_hlo(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/interpreters/pxla.py", line 3028, in from_hlo
xla_executable = dispatch.compile_or_get_cached(
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 526, in compile_or_get_cached
return backend_compile(backend, serialized_computation, compile_options,
File "/usr/local/lib/python3.10/dist-packages/jax/_src/profiler.py", line 314, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/jax/_src/dispatch.py", line 471, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
Based on a GitHub issue with a similar error message (https://github.com/google/jax/issues/4920), I tried adding some CUDA paths:
export PATH=/usr/local/cuda-11/bin${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
However, this did not resolve my problem.
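As a further diagnostic (just a sketch; the sonames below are my assumption for a CUDA 11.8 / cuDNN 8 setup), one can check from inside the container whether the NVIDIA libraries are resolvable at all, independently of JAX:

# Diagnostic sketch: try to load the NVIDIA libraries directly.
# The sonames are assumptions for a CUDA 11.8 / cuDNN 8 setup.
import ctypes
import os

print("LD_LIBRARY_PATH =", os.environ.get("LD_LIBRARY_PATH"))

for name in ["libcuda.so.1", "libcudnn.so.8", "libnvrtc.so"]:
    try:
        ctypes.CDLL(name)
        print(name, "OK")
    except OSError as exc:
        print(name, "FAILED:", exc)

If libcuda.so.1 loads but the others do not, the problem is likely in the image itself rather than in the --nv GPU binding.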
When I test the image built from the Dockerfile on my local machine without GPU, everything works fine:
$ docker run -ti my-image python3 -c 'import jax; jax.numpy.array(1.)'
$
I can confirm that the GPUs are detected in the apptainer container. I get the following output when I run nvidia-smi:
$ apptainer run --nv docker://my-image nvidia-smi
INFO: Using cached SIF image
==========
== CUDA ==
==========
CUDA Version 11.8.0
Container image Copyright (c) 2016-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
This container image and its contents are governed by the NVIDIA Deep Learning Container License.
By pulling and using the container, you accept the terms and conditions of this license:
https://developer.nvidia.com/ngc/nvidia-deep-learning-container-license
A copy of this license is made available in this container at /NGC-DL-CONTAINER-LICENSE for your convenience.
Thu May 11 14:34:18 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01 Driver Version: 515.65.01 CUDA Version: 11.7 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 NVIDIA A40 On | 00000000:00:08.0 Off | 0 |
| 0% 31C P8 31W / 300W | 0MiB / 46068MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 NVIDIA A40 On | 00000000:00:09.0 Off | 0 |
| 0% 31C P8 31W / 300W | 0MiB / 46068MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 2 NVIDIA A40 On | 00000000:00:0A.0 Off | 0 |
| 0% 30C P8 30W / 300W | 0MiB / 46068MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 3 NVIDIA A40 On | 00000000:00:0B.0 Off | 0 |
| 0% 31C P8 31W / 300W | 0MiB / 46068MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 4 NVIDIA A40 On | 00000000:00:0C.0 Off | 0 |
| 0% 31C P8 30W / 300W | 0MiB / 46068MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 5 NVIDIA A40 On | 00000000:00:0D.0 Off | 0 |
| 0% 32C P8 31W / 300W | 0MiB / 46068MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 6 NVIDIA A40 On | 00000000:00:0E.0 Off | 0 |
| 0% 31C P8 30W / 300W | 0MiB / 46068MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 7 NVIDIA A40 On | 00000000:00:0F.0 Off | 0 |
| 0% 30C P8 31W / 300W | 0MiB / 46068MiB | 0% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
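Since the driver clearly sees the GPUs, the question is whether XLA does too. A minimal check (sketch only) is to enumerate devices before running any computation; as the traceback above suggests, the failure happens only at compile time, so enumeration should still work:

# Sketch: check whether JAX/XLA enumerates the GPUs, before any
# computation triggers cuDNN initialization.
import jax

print(jax.default_backend())  # 'gpu' (or 'cuda', depending on the JAX version)
print(jax.devices())          # should list the A40 devices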
Edit: When I use a different base image with cuDNN 8.7.0:
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
RUN apt update && apt install python3-pip -y
RUN pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I get a different error:
$ apptainer run --nv docker://my-image python3 -c 'import jax; jax.numpy.array(1.)'
Could not load library libcudnn_ops_infer.so.8. Error: libnvrtc.so: cannot open shared object file: No such file or directory
Aborted (core dumped)
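This error points at NVRTC rather than cuDNN itself. A quick way to see which NVRTC files the image actually ships (a sketch; the directories below are guesses based on where the NVIDIA base images usually put CUDA libraries):

# Sketch: look for NVRTC libraries in the usual CUDA install locations.
# The directories are assumptions and not exhaustive.
import glob

for pattern in ["/usr/local/cuda*/lib64/libnvrtc*",
                "/usr/lib/x86_64-linux-gnu/libnvrtc*"]:
    print(pattern, "->", glob.glob(pattern))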
Based on the pointers from jakevdp, I managed to find a solution. What was needed was the devel image instead of the runtime image. With that change, I could successfully run JAX under apptainer with the following Dockerfile:
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04
RUN apt update && apt install python3-pip -y
RUN pip install "jax[cuda11_cudnn86]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
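To verify the final image, a small test like the following (a minimal sketch; saved as, say, check_gpu.py and run with apptainer run --nv docker://my-image python3 check_gpu.py) forces one computation to execute on the GPU:

# Verification sketch: run one computation and confirm it executes on GPU.
import jax
import jax.numpy as jnp

print(jax.devices())              # should list the A40s
x = jnp.ones((1000, 1000))
y = (x @ x).block_until_ready()   # forces execution (and DNN/BLAS initialization)
print(type(y), y.shape)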