Search code examples
docker, tensorflow, go, pytorch, nvidia

How to pass `--gpus all` option to Docker with Go SDK?


I have seen how to run some basic commands, such as running a container, pulling images, and listing images, in the SDK examples.

I am working on a project where I need to use the GPU from within the container.

My system has a GPU, I have installed the drivers, and I have also installed nvidia-container-runtime.

Setting the Go SDK aside for a moment, I can run the following command on my host system to get the nvidia-smi output:

docker run -it --rm --gpus all nvidia/cuda:10.0-base nvidia-smi

I have to do this via the SDK. Here is the code I am starting with. It prints "hello world", but in practice I will run the nvidia-smi command in its place:

package main

import (
    "context"
    "io"
    "os"

    "github.com/docker/docker/api/types"
    "github.com/docker/docker/api/types/container"
    "github.com/docker/docker/client"
    "github.com/docker/docker/pkg/stdcopy"
)

// main creates a Docker client configured from the environment (with API
// version negotiation enabled) and passes it to RunContainer. It panics if
// the client cannot be constructed.
func main() {
    cli, err := client.NewClientWithOpts(
        client.FromEnv,
        client.WithAPIVersionNegotiation(),
    )
    if err != nil {
        panic(err)
    }

    RunContainer(context.Background(), cli)
}

// RunContainer pulls the nvidia/cuda:10.0-base image, creates and starts a
// container from it that runs `echo "hello world"`, waits until the container
// is no longer running, and then copies the container's stdout log stream to
// os.Stdout. Any error along the way causes a panic.
func RunContainer(ctx context.Context, cli *client.Client) {
    reader, err := cli.ImagePull(ctx, "nvidia/cuda:10.0-base", types.ImagePullOptions{})
    if err != nil {
        panic(err)
    }
    defer reader.Close()

    // The pull is only finished once its progress stream has been read to
    // EOF; without draining it, ContainerCreate can race the download and
    // fail because the image is not available yet.
    if _, err := io.Copy(io.Discard, reader); err != nil {
        panic(err)
    }

    resp, err := cli.ContainerCreate(ctx, &container.Config{
        Image: "nvidia/cuda:10.0-base",
        Cmd:   []string{"echo", "hello world"},
        // Tty:   false,
    }, nil, nil, nil, "")
    if err != nil {
        panic(err)
    }

    if err := cli.ContainerStart(ctx, resp.ID, types.ContainerStartOptions{}); err != nil {
        panic(err)
    }

    // Block until the container stops or the wait itself reports an error.
    statusCh, errCh := cli.ContainerWait(ctx, resp.ID, container.WaitConditionNotRunning)
    select {
    case err := <-errCh:
        if err != nil {
            panic(err)
        }
    case <-statusCh:
    }

    out, err := cli.ContainerLogs(ctx, resp.ID, types.ContainerLogsOptions{ShowStdout: true})
    if err != nil {
        panic(err)
    }
    defer out.Close()

    // The log stream multiplexes stdout and stderr; StdCopy splits it back
    // into the two writers. Surface a failed copy instead of dropping it.
    if _, err := stdcopy.StdCopy(os.Stdout, os.Stderr, out); err != nil {
        panic(err)
    }
}

Solution

  • See how the Docker CLI itself turns the `--gpus` flag into device requests: https://github.com/docker/cli/blob/9ac8584acfd501c3f4da0e845e3a40ed15c85041/cli/command/container/opts.go#L594

    import "github.com/docker/cli/opts"
    
    // ...
    
    // Parse the same value the CLI would receive for `--gpus all`.
    gpuOpts := opts.GpuOpts{}
    if err := gpuOpts.Set("all"); err != nil {
        // Do not continue with an empty device request if parsing failed.
        panic(err)
    }
    
    // Pass the parsed device requests through HostConfig.Resources so the
    // container is created with the requested GPUs.
    resp, err := cli.ContainerCreate(ctx, &container.Config{
        Image: "nvidia/cuda:10.0-base",
        Cmd:   []string{"echo", "hello world"},
        // Tty:   false,
    }, &container.HostConfig{
        Resources: container.Resources{DeviceRequests: gpuOpts.Value()},
    }, nil, nil, "")