Search code examples
c++cudaprefix-sum

Parallel prefix sum with multiple elements per thread without using thrust


I'm trying to perform an inclusive scan to find the cumulative sum of an array. Following the advice given by harrism here, I'm using the procedure given here, but following the advice of those authors, I'm trying to write code that has each thread calculate 4 elements instead of one to mask memory latency.

I am staying away from thrust as performance is essential, and I need multi-stream capability. I have only just discovered CUB, and that will be my next effort, but I would like a multi-block solution and would also like to know where I've gone wrong on my existing code, just as an exercise to better understand CUDA.

The code below allocates 4 data elements to each block, where each block must have a multiple of 32 threads. My data will have a multiple of 128 threads so this restriction is acceptable to me. Enough shared memory is allocated to each block for the 4*blockDim.x elements plus an additional 32 elements to sum between warps. scanBlockAnyLength then adds the necessary offset to correct mismatch between warps, saving the final value of each warp to dev_blockSum in device global memory. sumWarp4_32 then scans this array to find the final to correct the mismatch between blocks, which is then added on in kernel_sumBlock

#include<cuda.h>
#include<iostream>
using std::cout;
using std::endl;

#define MAX_THREADS 1024
#define MAX_BLOCKS 65536
#define N 512

__device__ float sumWarp4_128(float* ptr, const int tidx = threadIdx.x) {
    const unsigned int lane = tidx & 31;
    const unsigned int warpid = tidx >> 5; //32 threads per warp

    unsigned int i = warpid*128+lane; //first element of block data set this thread looks at

    if( lane >= 1 ) ptr[i] += ptr[i-1];
    if( lane >= 2 ) ptr[i] += ptr[i-2];
    if( lane >= 4 ) ptr[i] += ptr[i-4];
    if( lane >= 8 ) ptr[i] += ptr[i-8];
    if( lane >= 16 ) ptr[i] += ptr[i-16];

    if( lane==0 ) ptr[i+32] += ptr[i+31];

    if( lane >= 1 ) ptr[i+32] += ptr[i+32-1];
    if( lane >= 2 ) ptr[i+32] += ptr[i+32-2];
    if( lane >= 4 ) ptr[i+32] += ptr[i+32-4];
    if( lane >= 8 ) ptr[i+32] += ptr[i+32-8];
    if( lane >= 16 ) ptr[i+32] += ptr[i+32-16];

    if( lane==0 ) ptr[i+64] += ptr[i+63];

    if( lane >= 1 ) ptr[i+64] += ptr[i+64-1];
    if( lane >= 2 ) ptr[i+64] += ptr[i+64-2];
    if( lane >= 4 ) ptr[i+64] += ptr[i+64-4];
    if( lane >= 8 ) ptr[i+64] += ptr[i+64-8];
    if( lane >= 16 ) ptr[i+64] += ptr[i+64-16];

    if( lane==0 ) ptr[i+96] += ptr[i+95];

    if( lane >= 1 ) ptr[i+96] += ptr[i+96-1];
    if( lane >= 2 ) ptr[i+96] += ptr[i+96-2];
    if( lane >= 4 ) ptr[i+96] += ptr[i+96-4];
    if( lane >= 8 ) ptr[i+96] += ptr[i+96-8];
    if( lane >= 16 ) ptr[i+96] += ptr[i+96-16];

    return ptr[i+96];
}
__host__ __device__ float sumWarp4_32(float* ptr, const int tidx = threadIdx.x) {
    const unsigned int lane = tidx & 31;
    const unsigned int warpid = tidx >> 5; //32 elements per warp

    unsigned int i = warpid*32+lane; //first element of block data set this thread looks at

    if( lane >= 1 ) ptr[i] += ptr[i-1];
    if( lane >= 2 ) ptr[i] += ptr[i-2];
    if( lane >= 4 ) ptr[i] += ptr[i-4];
    if( lane >= 8 ) ptr[i] += ptr[i-8];
    if( lane >= 16 ) ptr[i] += ptr[i-16];

    return ptr[i];
}
__device__ float sumBlock4(float* ptr, const int tidx = threadIdx.x, const int bdimx = blockDim.x ) {
    const unsigned int lane = tidx & 31;
    const unsigned int warpid = tidx >> 5; //32 threads per warp

    float val = sumWarp4_128(ptr);
    __syncthreads();//should be included

    if( tidx==bdimx-1 ) ptr[4*bdimx+warpid] = val;
    __syncthreads();

    if( warpid==0 ) sumWarp4_32((float*)&ptr[4*bdimx]);
    __syncthreads();

    if( warpid>0 ) {
        ptr[warpid*128+lane] += ptr[4*bdimx+warpid-1];
        ptr[warpid*128+lane+32] += ptr[4*bdimx+warpid-1];
        ptr[warpid*128+lane+64] += ptr[4*bdimx+warpid-1];
        ptr[warpid*128+lane+96] += ptr[4*bdimx+warpid-1];
    }
    __syncthreads();
    return ptr[warpid*128+lane+96];
}
__device__ void scanBlockAnyLength4(float *ptr, float* dev_blockSum, const float* dev_input, float* dev_output, const int idx = threadIdx.x, const int bdimx = blockDim.x, const int bidx = blockIdx.x) {

    const unsigned int lane = idx & 31;
    const unsigned int warpid = idx >> 5;

    ptr[lane+warpid*128] = dev_input[lane+warpid*128+bdimx*bidx*4];
    ptr[lane+warpid*128+32] = dev_input[lane+warpid*128+bdimx*bidx*4+32];
    ptr[lane+warpid*128+64] = dev_input[lane+warpid*128+bdimx*bidx*4+64];
    ptr[lane+warpid*128+96] = dev_input[lane+warpid*128+bdimx*bidx*4+96];
    __syncthreads();

    float val = sumBlock4(ptr);
    __syncthreads();
    dev_blockSum[0] = 0.0f;
    if( idx==0 ) dev_blockSum[bidx+1] = ptr[bdimx*4-1];

    dev_output[lane+warpid*128+bdimx*bidx*4] = ptr[lane+warpid*128];
    dev_output[lane+warpid*128+bdimx*bidx*4+32] = ptr[lane+warpid*128+32];
    dev_output[lane+warpid*128+bdimx*bidx*4+64] = ptr[lane+warpid*128+64];
    dev_output[lane+warpid*128+bdimx*bidx*4+96] = ptr[lane+warpid*128+96];
    __syncthreads();
}
__global__ void kernel_sumBlock(float* dev_blockSum, const float* dev_input, float*   dev_output ) {
    extern __shared__ float ptr[];
    scanBlockAnyLength4(ptr,dev_blockSum,dev_input,dev_output);
}
__global__ void kernel_offsetBlocks(float* dev_blockSum, float* dev_arr) {
    const int tidx = threadIdx.x;
    const int bidx = blockIdx.x;
    const int bdimx = blockDim.x;

    const int lane = tidx & 31;
    const int warpid = tidx >> 5;
    if( warpid==0 ) sumWarp4_32(dev_blockSum);
    float val = dev_blockSum[warpid];
    dev_arr[warpid*128+lane] += val;
    dev_arr[warpid*128+lane+32] += val;
    dev_arr[warpid*128+lane+64] += val;
    dev_arr[warpid*128+lane+96] += val;
}
void scan4( const float input[], float output[]) {
    int blocks = 2;
    int threadsPerBlock = 64; //multiple of 32
    int smemsize = (threadsPerBlock*4+32)*sizeof(float);

    float* dev_input, *dev_output;
    cudaMalloc((void**)&dev_input,blocks*threadsPerBlock*4*sizeof(float));
    cudaMalloc((void**)&dev_output,blocks*threadsPerBlock*4*sizeof(float));

    float *dev_blockSum;
    cudaMalloc((void**)&dev_blockSum,blocks*sizeof(float));

    int offset = 0;
    int Nrem = N;
    int chunksize;
    while( Nrem ) {
        chunksize = max(Nrem,blocks*threadsPerBlock*4);
        cudaMemcpy(dev_input,(void**)&input[offset],chunksize*sizeof(float),cudaMemcpyHostToDevice);
        kernel_sumBlock<<<blocks,threadsPerBlock,smemsize>>>(dev_blockSum,dev_input,dev_output);
        kernel_offsetBlocks<<<blocks,threadsPerBlock>>>(dev_blockSum,dev_output);
        cudaMemcpy((void**)&output[offset],dev_output,chunksize*sizeof(float),cudaMemcpyDeviceToHost);
        offset += chunksize;
        Nrem -= chunksize;
    }
    cudaFree(dev_input);
    cudaFree(dev_output);
}

int main() {
    float h_vec[N], sol[N];
    for( int i = 0; i < N; i++ ) h_vec[i] = (float)i+1.0f;

    scan4(h_vec,sol);

    cout << "solution:" << endl;
    for( int i = 0; i < N; i++ ) cout << i << " " << (i+2)*(i+1)/2 << " " << sol[i] << endl;
    return 0;
}

To my eye, the code is throwing errors because the lines in sumWarp4_128 are not executed in order within a warp. I.e, the if( lane==0 ) lines are executing before the other logical blocks that precede it. I thought this was not possible within a warp.

If I __syncthreads() before and after the lane==0 calls, I get some new exotic error that I just can't figure out.

Any help to point out where I've gone wrong would be appreciated


Solution

  • The code you are writing has race conditions due to not synchronizing between threads that are sharing data. While it is true that this can be done on current hardware for communication within a warp (so-called warp-synchronous programming), it is highly discouraged because the race conditions in the code could cause it to fail on possible future hardware.

    While it is true that you will get higher performance by processing multiple items per thread, 4 is not a magic number -- you should make this a tunable parameter if possible. CUDPP uses 8 per thread, for example.

    I would highly recommend that you use CUB for this. You should use cub::BlockLoad() to load multiple items per thread and cub::BlockScan() to scan them. Then you would just need some code to combine multiple blocks. The most bandwidth-efficient way to do this is to use the "Reduce-Scan-Scan" approach that Thrust uses. First reduce each block (cub::BlockReduce) and store the sum from each block to a blockSums array. Then scan that array to get the per-block offset. Then perform a cub::BlockScan on the blocks and add the previously computed per-block offset to each element.