Tags: cuda, cuda-driver

How do the CUDA Runtime's current device and the driver context stack interact?


The CUDA Runtime has a notion of a "current device", while the CUDA Driver does not. Instead, the Driver has a stack of contexts, where the "current context" is the one at the top of the stack.

How do the two interact? That is, how do Driver API calls affect the Runtime API's current device, and how does changing the current device affect the Driver API's context stack or other state?

Somewhat related question: How can I mix the CUDA Driver API with the CUDA Runtime API?


Solution

  • Runtime current device -> Driver context stack

    If you set the current device (with cudaSetDevice()), then the primary context of the chosen device is placed at the top of the stack.

    • If the stack had been empty, it's pushed onto the stack.
    • If the stack had been non-empty, it replaces the top of the stack.
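    A toy model of this rule in plain C++ may make the two cases concrete. It makes no CUDA calls: `ToyContext`, `ToyContextStack` and `toy_set_device` are hypothetical names, and the code only restates the rule above. It is not how the Runtime is actually implemented.

    ```cpp
    #include <cassert>
    #include <vector>

    // Toy model, not real CUDA: a context is reduced to the device it belongs to
    // plus a flag saying whether it is that device's primary context.
    struct ToyContext {
        int device;
        bool is_primary;
    };

    // Toy driver context stack; the back of the vector is the current context.
    using ToyContextStack = std::vector<ToyContext>;

    // Restates the rule for cudaSetDevice(device): the device's primary context
    // is pushed if the stack is empty, and otherwise replaces the top entry.
    void toy_set_device(ToyContextStack& stack, int device) {
        assert(device >= 0);
        ToyContext primary{device, true};
        if (stack.empty()) {
            stack.push_back(primary);
        } else {
            stack.back() = primary;
        }
    }
    ```

    Starting from an empty stack, `toy_set_device(stack, 1)` leaves one entry; calling it again after some other context was pushed overwrites that top entry rather than growing the stack.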

  • Driver context stack -> Runtime current device

    (This part I'm not 100% sure about, so take it with a grain of salt.)

    The Runtime reports the current device as the device of the current context, regardless of whether that is a primary context.

    If the context stack is empty, the Runtime's current device will be reported as 0.
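    The same kind of toy model can restate this direction. In the sketch below (plain C++, no CUDA calls; `ToyContextStack` and `toy_runtime_current_device` are hypothetical names), a context is reduced to its device index, since under this rule it does not matter whether it is a primary context. The query returns the device at the top of the stack, or 0 when the stack is empty, which is what the rule above predicts the Runtime's current-device query (e.g. `cudaGetDevice()`) would report.

    ```cpp
    #include <cassert>
    #include <vector>

    // Toy model, not real CUDA: each context is reduced to the index of the
    // device it belongs to; the back of the vector is the current context.
    using ToyContextStack = std::vector<int>;

    // Restates the proposed rule: report the device of the context at the top
    // of the stack, or 0 if the stack is empty.
    int toy_runtime_current_device(const ToyContextStack& stack) {
        return stack.empty() ? 0 : stack.back();
    }
    ```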

    A program to illustrate this behavior:

    #include <cuda/api.hpp>
    #include <iostream>
    
    void report_current_device()
    {
        std::cout << "Runtime reports a current device index of: "
            << cuda::device::current::detail_::get_id() << '\n';
    }
    
    int main()
    {
        namespace context = cuda::context::detail_;
        namespace cur_dev = cuda::device::current::detail_;
        namespace pc = cuda::device::primary_context::detail_;
        namespace cur_ctx = cuda::context::current::detail_;
        using std::cout;
    
        cuda::device::id_t dev_idx[2];
        cuda::context::handle_t pc_handle[2];
        
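        cuda::initialize_driver();
        dev_idx[0] = cur_dev::get_id();
        report_current_device();
        // Pick another device; this assumes the system has at least two GPUs
        dev_idx[1] = (dev_idx[0] == 0) ? 1 : 0;
        // Retaining a primary context does not push it onto the stack
        pc_handle[0] = pc::obtain_and_increase_refcount(dev_idx[0]);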
        cout << "Obtained primary context handle for device " 
             << dev_idx[0]<< '\n';
        pc_handle[1] = pc::obtain_and_increase_refcount(dev_idx[1]);
        cout << "Obtained primary context handle for device " 
             << dev_idx[1]<< '\n';
        report_current_device();
        cur_ctx::push(pc_handle[1]);
        cout << "Pushed primary context handle for device " 
             << dev_idx[1] << " onto the stack\n";
        report_current_device();
        auto ctx = context::create_and_push(dev_idx[0]);
        cout << "Created a new context for device " << dev_idx[0]
             << " and pushed it onto the stack\n";
        report_current_device();
        cur_ctx::push(pc_handle[0]);
        cout << "Pushed primary context handle for device " << dev_idx[0]
             << " onto the stack\n";
        report_current_device();
        cur_ctx::push(pc_handle[1]);
        cout << "Pushed primary context for device " << dev_idx[1] 
             << " onto the stack\n";
        report_current_device();
        pc::decrease_refcount(dev_idx[1]);
        cout << "Deactivated/destroyed primary context for device " 
             << dev_idx[1] << '\n';
        report_current_device();
    }
    

    ... which results in:

    Runtime reports a current device index of: 0
    Obtained primary context handle for device 0
    Obtained primary context handle for device 1
    Runtime reports a current device index of: 0
    Pushed primary context handle for device 1 onto the stack
    Runtime reports a current device index of: 1
    Created a new context for device 0 and pushed it onto the stack
    Runtime reports a current device index of: 0
    Pushed primary context handle for device 0 onto the stack
    Runtime reports a current device index of: 0
    Pushed primary context for device 1 onto the stack
    Runtime reports a current device index of: 1
    Deactivated/destroyed primary context for device 1
    Runtime reports a current device index of: 1
    

    The program uses my cuda-api-wrappers library.
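    As a sanity check on the proposed rule, one can replay just the stack operations of the program above in a toy model and compare the predicted devices with the output. This is plain C++ with no CUDA calls, and `toy_runtime_current_device` and `toy_replay_program` are hypothetical names. It assumes the program picked device 0 first (as the output shows), and it assumes that retaining or releasing a primary context leaves the stack unchanged; that is consistent with the output above but not proven by it.

    ```cpp
    #include <cassert>
    #include <vector>

    // Toy model, not real CUDA: a context is reduced to its device index and
    // the back of the vector is the current context.
    using ToyContextStack = std::vector<int>;

    // The proposed rule: report the top context's device, or 0 if empty.
    int toy_runtime_current_device(const ToyContextStack& stack) {
        return stack.empty() ? 0 : stack.back();
    }

    // Replays the program's stack operations with devices 0 and 1 and returns
    // the device the rule predicts at each report_current_device() call.
    std::vector<int> toy_replay_program() {
        ToyContextStack stack;
        std::vector<int> reported;
        reported.push_back(toy_runtime_current_device(stack)); // after driver init
        // Assumption: retaining both primary contexts does not touch the stack
        reported.push_back(toy_runtime_current_device(stack));
        stack.push_back(1); // push device 1's primary context
        reported.push_back(toy_runtime_current_device(stack));
        stack.push_back(0); // create and push a new context on device 0
        reported.push_back(toy_runtime_current_device(stack));
        stack.push_back(0); // push a context belonging to device 0
        reported.push_back(toy_runtime_current_device(stack));
        stack.push_back(1); // push device 1's primary context again
        reported.push_back(toy_runtime_current_device(stack));
        // Assumption: releasing device 1's primary context leaves the stack as-is
        reported.push_back(toy_runtime_current_device(stack));
        return reported;
    }
    ```

    The returned sequence is 0, 0, 1, 0, 0, 1, 1, matching the seven "Runtime reports" lines of the output.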