How do GPUs handle indirect branches?


My understanding is that GPUs handle branches by executing every path while suspending the instances (threads) that are not supposed to take that path. This works well for if/then/else constructs and for loops (an instance that has exited the loop can be suspended until every instance has exited).

This flat out does not work if the branch is indirect. Yet modern GPUs (Fermi and later for NVIDIA; I am not sure when it appeared for AMD, R600?) claim to support indirect branches (function pointers, virtual dispatch, ...).

The question is: what kind of magic is going on in the chip to make this happen?


Solution

  • PTX ISA 6.0 and later supports indirect branches through jump tables, using the brx.idx instruction.

    Unlike regular if (a == 2) { ... } else { ... } code (or a switch (...) { ... }), the brx.idx approach avoids a chain of setp.ne.s32 comparisons:

    #include <cstdio>
    #include <cstdint>
    #include <cstdlib>  // atoi, exit
    
    #if (defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)
    #define __LDG_PTR "l"
    #else
    #define __LDG_PTR "r"
    #endif /*(defined(_MSC_VER) && defined(_WIN64)) || defined(__LP64__) || defined(__CUDACC_RTC__)*/
    
    __global__ void kernel(uint32_t* tgt)
    {
            printf("tgt = %u\n", *tgt);
            asm volatile(".reg .u32 r_tgt;");
            asm volatile("ld.u32 r_tgt, [%0];" :: __LDG_PTR(tgt));
            asm volatile("ts: .branchtargets BLK0, BLK1, BEXIT;");
            asm volatile("brx.idx r_tgt, ts;");
            asm volatile("BLK0:");
            printf("BLK0\n");
            asm volatile("ret;\n");
            asm volatile("BLK1:");
            printf("BLK1\n");
            asm volatile("ret;\n");
            asm volatile("BEXIT:");
            printf("BEXIT\n");
            asm volatile("ret;\n");
    }
    
    int main(int argc, char* argv[])
    {
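            // The index must select one of the three labels in the jump table (0, 1 or 2).
            if (argc < 2 || atoi(argv[1]) < 0 || atoi(argv[1]) > 2)
            {
                    fprintf(stderr, "Usage: %s <0|1|2>\n", argv[0]);
                    return 1;
            }
            uint32_t* tgt = nullptr;
            cudaMalloc(&tgt, sizeof(uint32_t));
            uint32_t val = atoi(argv[1]);
            printf("Testing tgt = %u\n", val);
            cudaMemcpy(tgt, &val, sizeof(uint32_t), cudaMemcpyHostToDevice);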
            kernel<<<1, 1>>>(tgt);
            auto err = cudaDeviceSynchronize();
            if (err != cudaSuccess)
            {
                    fprintf(stderr, "CUDA error: code = %d\n", err);
                    exit(-1);
            }
            return 0;
    }
    

    Full example in Gist.