
Convergence barrier for branchless CUDA conditional select


I implemented a simple path tracing renderer and I've been profiling my code via ncu-ui. I found something confusing in the SASS generated by the CUDA compiler. Here is the snippet that generates the confusing SASS instructions:

// CPT_GPU means '__device__' here
CPT_GPU float ray_intersect(
    const Ray& ray,
    ConstShapePtr shapes,
    ConstAABBPtr s_aabbs,
    ShapeIntersectVisitor& shape_visitor,
    int& min_index,
    const int remain_prims,
    const int cp_base_5,
    float min_dist
) {
    float aabb_tmin = 0;
    #pragma unroll
    for (int idx = 0; idx < remain_prims; idx++) {
        if (s_aabbs[idx].intersect(ray, aabb_tmin) && aabb_tmin <= min_dist) {
            shape_visitor.set_index(idx);
            float dist = variant::apply_visitor(shape_visitor, shapes[cp_base_5 + idx]);
            bool valid = dist > EPSILON && dist < min_dist;
            min_dist = valid ? dist : min_dist;
            // ------------ the question stems from the line below ---------------
            min_index = valid ? cp_base_5 + idx : min_index;
        }
    }
    return min_dist;
}

The logic, for the sake of clarity, is fairly straightforward: the program traverses all the triangles and finds the closest hit. If the hit is closer than the currently recorded minimal hit distance (so valid is true), we record the triangle index in min_index; otherwise, min_index remains unchanged. Yet in the SASS, the line min_index = valid ? cp_base_5 + idx : min_index; generates the following instructions:

SEL R24, R25, R24, P0      % Select Source with Predicate
BSYNC B8                   % Synchronize Threads on a Convergence Barrier

The profiler tells me that the min_index selection causes a considerable amount of warp stalling (77.51% being "No instructions" and ~14% being "Branch resolving"). The for loop also introduces a lot of warp stalling (50% being "Barrier"). So I wonder:

  • Why would the compiler add BSYNC here? Isn't the code branchless, with no explicit synchronization point? What is the point of this op?
  • Why does the for loop spend so much time stalling on some barrier? What is this barrier, and why is it added? Is it because, under SIMT, all threads in a warp must execute the same instruction, so some threads have to wait?

The full code can be found here, if anyone cares to look into it. It compiles easily with CMake on an NVIDIA GPU.
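For reference, here is a minimal standalone kernel (the names are made up; it is not part of the renderer) with the same shape as the hot loop: a data-dependent if inside a for, with conditional selects in the body. Compiling it (e.g. nvcc -arch=sm_75 -cubin demo.cu) and dumping the SASS with cuobjdump -sass demo.cubin should, as far as I understand the Volta+ convergence-barrier scheme, show the same SEL plus BSYNC pattern:

__global__ void divergence_demo(const float* dist, int n,
                                int* min_index, float* min_dist_out) {
    float min_dist = 1e30f;
    int best = -1;
    for (int idx = 0; idx < n; idx++) {
        float d = dist[idx * blockDim.x + threadIdx.x];
        // data-dependent branch: plays the role of the AABB test above
        if (d < min_dist) {
            bool valid = d > 1e-5f;
            min_dist = valid ? d : min_dist;   // conditional select
            best     = valid ? idx : best;     // conditional select, as in the question
        }
    }
    min_index[threadIdx.x]    = best;
    min_dist_out[threadIdx.x] = min_dist;
}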


Solution

  • I've done the profiling, and what surprises me is that on a 3060 laptop the performance gets even worse than the original implementation, yet on a 2060 workstation the performance is boosted by almost 2.5x. I really don't understand why this is so, and one of my friends suggested that for the 3060 (and its related arch) the compiler might generate less efficient code (not verified), since the same program needs only 64 registers on the 2060 machine while 72 registers are needed on the 3060. This is indeed weird, so I will only break down what boosts the speed on the 2060 workstation.

    As pointed out by @Homer512 and @Johan, BSYNC here is related to branch divergence. The paper recommended by @Homer512 illustrates this; I put a figure from that paper below:

    [figure from the paper illustrating branch divergence and reconvergence within a warp]
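    As a side note (this is my addition, not from the original discussion): the partial-warp execution that the figure describes can also be observed at run time by printing __activemask() inside the divergent branch. A debug-only sketch:

        #include <cstdio>

        // Debug sketch: inside a data-dependent branch, log which lanes of the warp are
        // actually active. Only the lowest active lane prints for its warp.
        __global__ void show_divergence(const int* hit, int n) {
            for (int idx = 0; idx < n; idx++) {
                if (hit[idx * blockDim.x + threadIdx.x]) {      // divergent branch
                    unsigned mask = __activemask();             // lanes on this path right now
                    if ((threadIdx.x & 31) == __ffs(mask) - 1)
                        printf("iter %d: active mask 0x%08x (%d/32 lanes)\n",
                               idx, mask, __popc(mask));
                }
            }
        }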

    I found the paper inspiring, and based on its idea of loop merging, I first refactored the code into the following:

        float aabb_tmin = 0;
        int8_t tasks[32] = {0}, cnt = 0;          // 32 bytes
        #pragma unroll
        for (int idx = 0; idx < remain_prims; idx++) {
            // if the current ray intersects the primitive at [idx], tasks will store its index
            int valid_intr = s_aabbs[idx].intersect(ray, aabb_tmin);       // valid intersection
            tasks[cnt] = valid_intr ? (int8_t)idx : (int8_t)0;
            cnt += valid_intr;
            // note that __any_sync here won't work well
        }
        #pragma unroll
        for (int i = 0; i < cnt; i++) {
            int idx = tasks[i];
            shape_visitor.set_index(idx);
            float dist = variant::apply_visitor(shape_visitor, shapes[cp_base_5 + idx]);
            bool valid = dist > EPSILON && dist < min_dist;
            min_dist = valid ? dist : min_dist;
            min_index = valid ? cp_base_5 + idx : min_index;
        }
        return min_dist;

    This implementation avoids the convergence synchronization that used to happen on every iteration, where only part of the threads in a warp might be active. It should work as follows: for the first N iterations (N being the minimum triangle hit count within the warp), all threads are active, and by iteration M + 1 (M being the maximum hit count), the loop has exited for every thread. So the threads are grouped, to some extent (N and M can even be measured; see the small sketch after the timing note below).

    • This implementation only takes 61ms per frame (no BVH, 531 triangles), compared to the old code (111ms per frame).
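    For illustration (my addition, not part of the refactor): the N and M above, i.e. how well the hits are grouped inside a warp, can be measured with a small warp reduction over cnt using butterfly shuffles. It assumes all 32 lanes of the warp reach the call:

        // Illustration only: warp-wide min/max of 'cnt'. n_min is the number of iterations
        // during which the whole warp stays active; m_max is the iteration after which
        // every lane of the warp has left the second loop.
        __device__ void warp_hit_stats(int cnt, int& n_min, int& m_max) {
            n_min = cnt;
            m_max = cnt;
            #pragma unroll
            for (int offset = 16; offset > 0; offset >>= 1) {
                n_min = min(n_min, __shfl_xor_sync(0xffffffffu, n_min, offset));
                m_max = max(m_max, __shfl_xor_sync(0xffffffffu, m_max, offset));
            }
        }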

    Later, @Homer512 suggested replacing int8_t tasks[32] with a bit mask. So here is the latest version:

        float aabb_tmin = 0;
        unsigned int tasks = 0;          // 32-bit mask, one bit per primitive

        #pragma unroll
        for (int idx = 0; idx < remain_prims; idx++) {
            // if the current ray intersects the primitive at [idx], set the corresponding bit
            int valid_intr = s_aabbs[idx].aabb.intersect(ray, aabb_tmin) && (aabb_tmin < min_dist);
            tasks |= (valid_intr << idx);
            // note that __any_sync here won't work well
        }
        #pragma unroll
        while (tasks) {
            int idx = __ffs(tasks) - 1;         // find the first set bit; __ffs is 1-based
            tasks &= ~((unsigned int)1 << idx); // clear the bit so it is not found again
            shape_visitor.set_index(idx);
            float dist = variant::apply_visitor(shape_visitor, shapes[cp_base_5 + idx]);
            bool valid = dist > EPSILON && dist < min_dist;
            min_dist = valid ? dist : min_dist;
            min_index = valid ? cp_base_5 + idx : min_index;
        }
        return min_dist;
    • This bit operation further boosts the rendering speed (from 60ms to 45ms per frame). My guess is that, since I already use too many registers (86) and force spills (to local memory, cached in L1) with --maxrregcount=64 in order to keep enough warps resident on the SMs, the int8_t tasks[32] array caused additional unexpected register spilling, so replacing it with a single 32-bit mask yields a further improvement.
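    One caveat worth noting (my addition): a single unsigned int caps this scheme at 32 candidate primitives per call. If remain_prims could ever exceed 32, the same trick should work with a 64-bit mask and __ffsll, roughly like this (a sketch I have not benchmarked):

        // Sketch only (not benchmarked): 64-bit mask, so up to 64 primitives per call.
        // __ffsll is the 64-bit counterpart of __ffs and is also 1-based (0 if no bit is set).
        unsigned long long tasks = 0ull;
        #pragma unroll
        for (int idx = 0; idx < remain_prims; idx++) {
            int valid_intr = s_aabbs[idx].aabb.intersect(ray, aabb_tmin) && (aabb_tmin < min_dist);
            tasks |= ((unsigned long long)valid_intr << idx);
        }
        while (tasks) {
            int idx = __ffsll(tasks) - 1;       // position of the lowest set bit
            tasks &= ~(1ull << idx);            // clear it
            shape_visitor.set_index(idx);
            float dist = variant::apply_visitor(shape_visitor, shapes[cp_base_5 + idx]);
            bool valid = dist > EPSILON && dist < min_dist;
            min_dist = valid ? dist : min_dist;
            min_index = valid ? cp_base_5 + idx : min_index;
        }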

    So, have I solved the BSYNC problem? Maybe not: there are still some barrier stalls, and the "L2 Theoretical Sectors Global Excessive" issue in the profiler still looks pretty severe.

    But anyway, the convergence synchronization for the if is no longer the most significant bottleneck, and I've really learned a lot. Thanks again, @Homer512, for the in-depth follow-up!