linux, docker, jax

Why does installing JAX with Docker create such a large image?


I am trying to pip install JAX in a Docker image, and I found that doing so blows up the size of the image: it is currently 4.82 GB.

I made sure to bypass pip's cache while installing packages by passing --no-cache-dir. While that did reduce the size, the image is still unreasonably huge.

Here is my Dockerfile:

FROM ubuntu:22.04

WORKDIR /app

RUN apt-get update && apt-get install -y \
    libosmesa6-dev \
    sudo \
    wget \
    curl \
    unzip \
    gcc \
    g++

ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PATH="/root/miniconda3/bin:${PATH}"

RUN mkdir -p ~/miniconda3
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
RUN bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
RUN rm -rf ~/miniconda3/miniconda.sh
RUN ~/miniconda3/bin/conda init bash
RUN conda init

RUN pip install --no-cache-dir --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

This is how I built it:

docker build -t tbd_jax .

When I run docker images, I get this:

REPOSITORY   TAG       IMAGE ID       CREATED          SIZE
tbd_jax      latest    812292e2264e   7 minutes ago    4.82GB

Running docker history --no-trunc tbd_jax:latest shows that the pip install layer alone accounts for 3.54 GB:

IMAGE   CREATED   CREATED BY   SIZE      COMMENT
sha256:812292e2264e4340b7715956824055d7409f9546f8dfa54ccad1da056febf300   8 minutes ago    RUN |1 PATH=/root/miniconda3/bin:/root/miniconda3/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin /bin/sh -c pip install --no-cache-dir --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # buildkit   3.54GB    buildkit.dockerfile.v0

Is there something I can do to reduce the size? I am a bit of a Docker and Linux newbie, so pardon my slowness.


Solution

  • Note that jax[cuda12_pip] installs all of the CUDA libraries listed here (CUDA toolkit components, cuDNN, and NCCL; the GPU driver itself still comes from the host):

    'cuda12_pip': [
      ...
      "nvidia-cublas-cu12>=12.2.5.6",
      "nvidia-cuda-cupti-cu12>=12.2.142",
      "nvidia-cuda-nvcc-cu12>=12.2.140",
      "nvidia-cuda-runtime-cu12>=12.2.140",
      "nvidia-cudnn-cu12>=8.9",
      "nvidia-cufft-cu12>=11.0.8.103",
      "nvidia-cusolver-cu12>=11.5.2",
      "nvidia-cusparse-cu12>=12.1.2.141",
      "nvidia-nccl-cu12>=2.18.3",
    ],

    These NVIDIA packages are quite large: for example, the nvidia-cublas-cu12 wheel is over 400 MB and the nvidia-cudnn-cu12 wheel is over 700 MB, and their installed size is larger than the compressed wheels. You may be able to do better by setting up your Docker image with system-native CUDA and cuDNN libraries installed via apt. You can find a description of the requirements here. You can also use NVIDIA's prebuilt GPU containers, as mentioned here.
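
The extras list quoted above comes from one JAX release, and it has changed between versions. A minimal sketch for checking what jax[cuda12_pip] actually pulled into a given image: it reads the installed jax distribution's Requires-Dist entries with the standard-library importlib.metadata.requires and keeps those whose environment marker names the extra. The function names are hypothetical, and the marker handling is a simple regex rather than a full PEP 508 parser (the third-party packaging library provides one).

```python
"""List the requirements that an installed package's optional extra pulls in."""
from importlib import metadata
import re

# Matches marker clauses such as: extra == "cuda12_pip"  or  extra == 'cuda12_pip'
_EXTRA_MARKER = re.compile(r"""extra\s*==\s*["']([^"']+)["']""")


def requirements_for_extra(requirements: list[str], extra: str) -> list[str]:
    """Return the requirement specifiers (markers stripped) gated on `extra`."""
    selected = []
    for req in requirements:
        spec, _, marker = req.partition(";")
        if extra in _EXTRA_MARKER.findall(marker):
            selected.append(spec.strip())
    return selected


def installed_extra_requirements(dist_name: str, extra: str) -> list[str]:
    """Read Requires-Dist of an installed distribution and filter by extra."""
    # metadata.requires() returns None when there are no requirements.
    return requirements_for_extra(metadata.requires(dist_name) or [], extra)


if __name__ == "__main__":
    try:
        for spec in installed_extra_requirements("jax", "cuda12_pip"):
            print(spec)
    except metadata.PackageNotFoundError:
        print("jax is not installed in this environment")
```

Run inside the container, this prints one requirement per line; for a JAX release like the one quoted above, the output should resemble the list in the answer, which tells you exactly which NVIDIA wheels the extra is responsible for.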