Search code examples
Tags: c++ · cuda · histogram · gpu-warp

Compute per-warp histogram without shared memory


Problem: compute a per-warp histogram of a sorted sequence of numbers, where each thread in the warp holds one number.

Example:

lane: 0123456789...          31
val:  222244455777799999 ..

The result must be held by the N lowest threads of the warp (where N is the number of unique values), e.g.:

lane 0: val=2, num=4 (2 occurs 4 times)
lane 1: val=4, num=3 (4 occurs 3 times)
lane 2: val=5, num=2 ...
lane 3: val=7, num=4
lane 4: val=9, num=5
...

Note that the sequence of 'val' does not actually need to be sorted: it is only necessary that equal numbers are grouped together (contiguous), e.g.: 99955555773333333...

Possible solution: this can be done quite efficiently with shuffle intrinsics. My question, however, is whether it can be done without using shared memory at all, because shared memory is a scarce resource and I need it elsewhere.

For simplicity, I execute this code for a single warp only (so that printf works fine):

// Per-warp histogram of a sequence in which equal values are contiguous
// (question version: compacts the results through shared memory).
// Written for a block consisting of a single warp: 'sh' is block-wide shared
// memory, so several warps calling this concurrently would race on it.
__device__ __inline__ void sorted_seq_histogram()
{
    uint32_t tid = threadIdx.x, lane = tid % 32;
    uint32_t val = (lane + 117)* 23 / 97; // sorted sequence of values to be reduced

    printf("%d: val = %d\n", lane, val);
    uint32_t num = 1;

    uint32_t allmsk = 0xffffffffu, shfl_c = 31;
    // Segmented suffix count: after the step with offset i, 'num' is the length
    // of the run of equal values starting at this lane, capped at 2*i.
    for(int i = 1; i <= 16; i *= 2) {

#if 1
        uint32_t xval = __shfl_down_sync(allmsk, val, i),
                 xnum = __shfl_down_sync(allmsk, num, i);
        if(lane + i < 32) { // past lane 31, shfl_down hands back the lane's own value
            if(val == xval)
                num += xnum;
        }
#else  // this is a (hopefully) optimized version of the code above
        asm(R"({
          .reg .u32 r0,r1;
          .reg .pred p;
          shfl.sync.down.b32 r0|p, %1, %2, %3, %4;
          shfl.sync.down.b32 r1|p, %0, %2, %3, %4;
          @p setp.eq.s32 p, %1, r0;
          @p add.u32 r1, r1, %0;
          @p mov.u32 %0, r1;
        })"
        : "+r"(num) : "r"(val), "r"(i), "r"(shfl_c), "r"(allmsk));
#endif
    }
    // A lane leads a run when its value differs from its left neighbour.
    // shfl.sync wraps around (thread 0 reads thread 31), so lane 0 is forced to
    // be a leader: otherwise a warp holding one single value would have no
    // leader at all and report an empty histogram. The shuffle is executed
    // unconditionally because the full mask requires all 32 lanes to join it.
    uint32_t prev = __shfl_sync(allmsk, val, lane - 1);
    bool leader = (lane == 0) || (val != prev);
    auto OK = __ballot_sync(allmsk, leader); // find delimiter threads
    auto total = __popc(OK); // the total number of unique numbers found

    // Mask of the lanes below this one. The shift is unsigned: with a signed 1,
    // lane 31 would evaluate INT_MIN - 1 (signed overflow, undefined behaviour).
    uint32_t lanelt = (1u << lane) - 1u;
    auto idx = __popc(OK & lanelt); // for a leader: its rank among all leaders

    printf("%d: val = %d; num = %d; total: %d; idx = %d; leader: %d\n", lane, val, num, total, idx, leader);

    __shared__ uint32_t sh[64]; // [0,32): values, [32,64): counts
    if(leader) {   // here we need shared memory :(
        sh[idx] = val;
        sh[idx + 32] = num;
    }
    __syncthreads(); // every thread of the block must reach this barrier

    // Lanes [0, total) each pick up one (value, count) pair; the rest get a sentinel.
    if(lane < total) {
        val = sh[lane], num = sh[lane + 32];
    } else {
        val = 0xDEADBABE, num = 0;
    }
    printf("%d: final val = %d; num = %d\n", lane, val, num);
}

Here is my GPU output:

0: val = 27
1: val = 27
2: val = 28
3: val = 28
4: val = 28
5: val = 28
6: val = 29
7: val = 29
8: val = 29
9: val = 29
10: val = 30
11: val = 30
12: val = 30
13: val = 30
14: val = 31
15: val = 31
16: val = 31
17: val = 31
18: val = 32
19: val = 32
20: val = 32
21: val = 32
22: val = 32
23: val = 33
24: val = 33
25: val = 33
26: val = 33
27: val = 34
28: val = 34
29: val = 34
30: val = 34
31: val = 35
0: val = 27; num = 2; total: 9; idx = 0; leader: 1
1: val = 27; num = 1; total: 9; idx = 1; leader: 0
2: val = 28; num = 4; total: 9; idx = 1; leader: 1
3: val = 28; num = 3; total: 9; idx = 2; leader: 0
4: val = 28; num = 2; total: 9; idx = 2; leader: 0
5: val = 28; num = 1; total: 9; idx = 2; leader: 0
6: val = 29; num = 4; total: 9; idx = 2; leader: 1
7: val = 29; num = 3; total: 9; idx = 3; leader: 0
8: val = 29; num = 2; total: 9; idx = 3; leader: 0
9: val = 29; num = 1; total: 9; idx = 3; leader: 0
10: val = 30; num = 4; total: 9; idx = 3; leader: 1
11: val = 30; num = 3; total: 9; idx = 4; leader: 0
12: val = 30; num = 2; total: 9; idx = 4; leader: 0
13: val = 30; num = 1; total: 9; idx = 4; leader: 0
14: val = 31; num = 4; total: 9; idx = 4; leader: 1
15: val = 31; num = 3; total: 9; idx = 5; leader: 0
16: val = 31; num = 2; total: 9; idx = 5; leader: 0
17: val = 31; num = 1; total: 9; idx = 5; leader: 0
18: val = 32; num = 5; total: 9; idx = 5; leader: 1
19: val = 32; num = 4; total: 9; idx = 6; leader: 0
20: val = 32; num = 3; total: 9; idx = 6; leader: 0
21: val = 32; num = 2; total: 9; idx = 6; leader: 0
22: val = 32; num = 1; total: 9; idx = 6; leader: 0
23: val = 33; num = 4; total: 9; idx = 6; leader: 1
24: val = 33; num = 3; total: 9; idx = 7; leader: 0
25: val = 33; num = 2; total: 9; idx = 7; leader: 0
26: val = 33; num = 1; total: 9; idx = 7; leader: 0
27: val = 34; num = 4; total: 9; idx = 7; leader: 1
28: val = 34; num = 3; total: 9; idx = 8; leader: 0
29: val = 34; num = 2; total: 9; idx = 8; leader: 0
30: val = 34; num = 1; total: 9; idx = 8; leader: 0
31: val = 35; num = 1; total: 9; idx = 8; leader: 1
0: final val = 27; num = 2
1: final val = 28; num = 4
2: final val = 29; num = 4
3: final val = 30; num = 4
4: final val = 31; num = 4
5: final val = 32; num = 5
6: final val = 33; num = 4
7: final val = 34; num = 4
8: final val = 35; num = 1
9: final val = -559039810; num = 0
10: final val = -559039810; num = 0
11: final val = -559039810; num = 0
12: final val = -559039810; num = 0
13: final val = -559039810; num = 0
14: final val = -559039810; num = 0
15: final val = -559039810; num = 0
16: final val = -559039810; num = 0
17: final val = -559039810; num = 0
18: final val = -559039810; num = 0
19: final val = -559039810; num = 0
20: final val = -559039810; num = 0
21: final val = -559039810; num = 0
22: final val = -559039810; num = 0
23: final val = -559039810; num = 0
24: final val = -559039810; num = 0
25: final val = -559039810; num = 0
26: final val = -559039810; num = 0
27: final val = -559039810; num = 0
28: final val = -559039810; num = 0
29: final val = -559039810; num = 0
30: final val = -559039810; num = 0
31: final val = -559039810; num = 0

Question: is it possible to do this without using shared memory? Somehow I cannot figure it out with all these brain-twisting shuffle intrinsics.


Solution

  • I think I found the solution: as paleonix also pointed out, the problem reduces to finding the position of the N-th set bit in the ballot mask.

    There is actually a rather interesting PTX instruction, fns.b32 ("find n-th set"), which does exactly that. However, on my SM30 architecture it compiles to a surprisingly long instruction sequence when I look at the disassembly.

    Anyway, we also have the fast popcount intrinsic on GPU which can be used to compute the position of the Nth bit set in logarithmic time. Below is the complete code which now does not require shared memory at all:

    EDIT: interestingly, unlike NVIDIA, AMD seems to provide a so-called "warp_permute" instruction. It is the opposite of __shfl_sync in the sense that each thread of a warp writes its value to a destination lane, instead of reading from a source lane: AMD warp_permute.

    EDIT: small optimization using the BFE (bit-field extract) instruction.

    // printf wrapper that appends a newline. 'fmt' must be a string literal
    // (it is concatenated with "\n"); '##__VA_ARGS__' is the GNU extension that
    // drops the trailing comma when no arguments follow the format string.
    #define PRINTZ(fmt, ...) printf(fmt"\n", ##__VA_ARGS__)
    
    // Extracts the bit field of length 'width' starting at bit 'startIdx' of
    // 'src', returned in the low bits of the result (PTX bfe.u32; see the PTX ISA
    // for how out-of-range start/width values are handled).
    // Deliberately not 'asm volatile': the instruction has no side effects, so the
    // compiler may CSE, hoist or drop it like any other arithmetic.
    __device__ __forceinline__ uint32_t bfe(uint32_t src, uint32_t startIdx, uint32_t width)
    {
        uint32_t bits;
        asm("bfe.u32 %0, %1, %2, %3;" : "=r"(bits) : "r"(src), "r"(startIdx), "r"(width));
        return bits;
    }
    
    // Per-warp histogram of a sequence in which equal values are contiguous,
    // computed without shared memory. Each lane L locates the (L+1)-th run leader
    // in the ballot mask with a popcount-driven binary search, then shuffles that
    // leader's value and count into itself. Lanes at or above the number of runs
    // get num = 0xDEADBABE (their 'val' is whatever the search landed on).
    __device__ __inline__ void sorted_seq_histogram()
    {
        uint32_t tid = threadIdx.x, lane = tid % 32;
        uint32_t val = (lane + 117)* 23 / 97; // sorted sequence of values to be reduced
    
        PRINTZ("%d: val = %d", lane, val);
        uint32_t num = 1;
    
        const uint32_t allmsk = 0xffffffffu, shfl_c = 31;
    
        // A lane leads a run when its value differs from its left neighbour.
        // shfl.sync wraps around (thread 0 reads thread 31), so lane 0 is forced
        // to be a leader: otherwise a warp holding one single value yields
        // OK == 0 and an empty histogram. The shuffle is executed unconditionally
        // because the full mask requires all 32 lanes to join it.
        uint32_t prev = __shfl_sync(allmsk, val, lane - 1);
        bool leader = (lane == 0) || (val != prev);
        auto OK = __ballot_sync(allmsk, leader); // find delimiter threads
        uint32_t pos = 0, N = lane+1; // each thread searches Nth bit set in 'OK' (1-indexed)
    
        // The bit search and the run-length reduction are independent; they share
        // one loop only so their instructions can be interleaved.
        for(int i = 1; i <= 16; i *= 2) {
    
            // Binary-search step over windows of width j = 16, 8, 4, 2, 1: if the
            // j bits of OK starting at 'pos' contain fewer than N set bits, the N-th
            // set bit lies above them, so skip the window and discount its bits.
            uint32_t j = 16 / i;
            uint32_t mval = bfe(OK, pos, j); // extract j bits starting at pos from OK
            auto dif = N - __popc(mval);
            if((int)dif > 0) {
                N = dif, pos += j;
            }
    
            // Segmented suffix count: after the step with offset i, 'num' is the
            // length of the run of equal values starting at this lane, capped at 2*i.
    #if 0
            uint32_t xval = __shfl_down_sync(allmsk, val, i),
                     xnum = __shfl_down_sync(allmsk, num, i);
            if(lane + i < 32) {
                if(val == xval)
                    num += xnum;
            }
    #else  // this is a (hopefully) optimized version of the code above
            asm(R"({
              .reg .u32 r0,r1;
              .reg .pred p;
              shfl.sync.down.b32 r0|p, %1, %2, %3, %4;
              shfl.sync.down.b32 r1|p, %0, %2, %3, %4;
              @p setp.eq.s32 p, %1, r0;
              @p add.u32 r1, r1, %0;
              @p mov.u32 %0, r1;
            })"
            : "+r"(num) : "r"(val), "r"(i), "r"(shfl_c), "r"(allmsk));
    #endif
        }
        // 'pos' now indexes the (lane+1)-th leader, whose 'num' is its full run length.
        num = __shfl_sync(allmsk, num, pos); // read from pos-th thread
        val = __shfl_sync(allmsk, val, pos); // read from pos-th thread
    
        auto total = __popc(OK); // the total number of unique numbers found
        if(lane >= total) {
            num = 0xDEADBABE;
        }
        PRINTZ("%d: final val = %d; num = %d", lane, val, num);
    }
    

    And the program output:

    0: val = 27
    1: val = 27
    2: val = 28
    3: val = 28
    4: val = 28
    5: val = 28
    6: val = 29
    7: val = 29
    8: val = 29
    9: val = 29
    10: val = 30
    11: val = 30
    12: val = 30
    13: val = 30
    14: val = 31
    15: val = 31
    16: val = 31
    17: val = 31
    18: val = 32
    19: val = 32
    20: val = 32
    21: val = 32
    22: val = 32
    23: val = 33
    24: val = 33
    25: val = 33
    26: val = 33
    27: val = 34
    28: val = 34
    29: val = 34
    30: val = 34
    31: val = 35
    0: final val = 27; num = 2;
    1: final val = 28; num = 4;
    2: final val = 29; num = 4;
    3: final val = 30; num = 4;
    4: final val = 31; num = 4;
    5: final val = 32; num = 5;
    6: final val = 33; num = 4;
    7: final val = 34; num = 4;
    8: final val = 35; num = 1;
    9: final val = 35; num = -559039810;
    10: final val = 35; num = -559039810;
    11: final val = 35; num = -559039810;
    12: final val = 35; num = -559039810;
    13: final val = 35; num = -559039810;
    14: final val = 35; num = -559039810;
    15: final val = 35; num = -559039810;
    16: final val = 35; num = -559039810;
    17: final val = 35; num = -559039810;
    18: final val = 35; num = -559039810;
    19: final val = 35; num = -559039810;
    20: final val = 35; num = -559039810;
    21: final val = 35; num = -559039810;
    22: final val = 35; num = -559039810;
    23: final val = 35; num = -559039810;
    24: final val = 35; num = -559039810;
    25: final val = 35; num = -559039810;
    26: final val = 35; num = -559039810;
    27: final val = 35; num = -559039810;
    28: final val = 35; num = -559039810;
    29: final val = 35; num = -559039810;
    30: final val = 35; num = -559039810;
    31: final val = 35; num = -559039810;