
Computing partial sums in OpenCL


A 1D dataset is divided into segments, and each work item processes one segment: it reads a number of elements from its segment. The number of elements is not known beforehand and differs for each segment.

For example:

+----+----+----+----+----+----+----+----+----+     <-- segments
  A    BCD  E    FG  HIJK   L    M        N        <-- elements in this segment

After all segments have been processed, the elements should be written contiguously to the output memory, like this:

A B C D E F G H I J K L M N

So the absolute output position of the elements from one segment depends on the number of elements in all previous segments. E is at position 4 (counting from 0) because the first segment contains 1 element (A) and the second segment contains 3 elements (B, C, D).
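In other words, each segment's start position is an exclusive prefix sum of the per-segment counts. The following is a minimal single-threaded C sketch (not OpenCL) of the whole job: each segment is modelled as a string whose characters are its elements, and the count, prefix-sum and copy steps run one after another. `exclusive_prefix_sum` and `gather_segments` are hypothetical names, and reading the diagram as nine segments with an empty eighth one is an assumption about its column alignment.

```c
#include <assert.h>
#include <stddef.h>
#include <string.h>

/* Sequential reference for the position computation: out_pos[i] receives the
   number of elements in segments 0..i-1 (an exclusive prefix sum).
   Returns the total number of elements. */
static size_t exclusive_prefix_sum(const size_t *counts, size_t n, size_t *out_pos)
{
    size_t running = 0;
    for (size_t i = 0; i < n; ++i) {
        out_pos[i] = running;      /* where segment i's first element goes */
        running += counts[i];
    }
    return running;
}

/* Hypothetical model of the whole job: each segment is a string whose
   characters are its elements. `out` must hold the total count plus one. */
static size_t gather_segments(const char *const *segments, size_t n,
                              size_t *counts, size_t *positions, char *out)
{
    for (size_t i = 0; i < n; ++i)
        counts[i] = strlen(segments[i]);                    /* "count_elements" */
    size_t total = exclusive_prefix_sum(counts, n, positions);
    for (size_t i = 0; i < n; ++i) {
        assert(positions[i] + counts[i] <= total);          /* stays in bounds */
        memcpy(out + positions[i], segments[i], counts[i]); /* "read_elements" */
    }
    out[total] = '\0';
    return total;
}
```

With the segments "A", "BCD", "E", "FG", "HIJK", "L", "M", "" and "N", this yields the positions 0, 1, 4, 5, 7, 11, 12, 13, 13 and the output ABCDEFGHIJKLMN; an empty segment simply shares its position with the next one.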


The OpenCL kernel writes the number of elements for each segment into a local/shared memory buffer, and works like this (pseudocode):

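kernel void k(
    constant uchar* input,
    global int* output,
    local int* segment_element_counts,  // one slot per work item
    int segment_size                    // bytes per input segment (assumed; undeclared in the original)
) {
    // One work item per segment; a single work group covers the whole dataset.
    int segment = get_local_id(0);
    // count_elements() and read_elements() are application-specific helpers (not shown).
    int count = count_elements(&input[segment * segment_size]);

    segment_element_counts[segment] = count;

    // Make every work item's count visible to the whole work group.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Exclusive prefix sum: this work item adds up the counts of all earlier segments.
    ptrdiff_t position = 0;
    for(int previous_segment = 0; previous_segment < segment; ++previous_segment)
        position += segment_element_counts[previous_segment];

    global int* output_ptr = &output[position];
    read_elements(&input[segment * segment_size], output_ptr);
}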

So each work item has to calculate its partial sum with a loop, and work items with larger IDs do more iterations: work item i performs i additions, so the last of N work items performs N - 1.
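The cost of that loop can be checked with a small C model in which one outer-loop iteration stands in for each work item (`naive_positions` is a hypothetical name; this is plain sequential C, not OpenCL). On a GPU the work items run concurrently, so the latency is set by the last work item's N - 1 dependent additions, while the total work is N(N - 1)/2 additions.

```c
#include <assert.h>
#include <stddef.h>

/* Models the question's kernel loop on the CPU: the outer loop plays the
   work items, and "work item" seg sums the counts of all earlier segments
   by itself. Returns the total number of additions over all work items and
   stores the largest per-work-item iteration count in *max_iters. */
static size_t naive_positions(const int *counts, size_t n, int *positions,
                              size_t *max_iters)
{
    size_t total_adds = 0;
    *max_iters = 0;
    for (size_t seg = 0; seg < n; ++seg) {
        int position = 0;
        for (size_t prev = 0; prev < seg; ++prev)   /* seg iterations */
            position += counts[prev];
        positions[seg] = position;
        total_adds += seg;
        if (seg > *max_iters)
            *max_iters = seg;
    }
    /* The last work item is the slowest: it needs n - 1 iterations. */
    assert(n == 0 || *max_iters == n - 1);
    return total_adds;
}
```

For the nine example segments this gives 36 additions in total, 8 of them in the last work item alone.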

Is there a more efficient way to implement this in OpenCL 1.2, i.e. to have each work item calculate the partial sum of a sequence up to its own index? OpenCL 2.0 seems to provide work_group_scan_inclusive_add for this (work_group_scan_exclusive_add would give each segment's start position directly).


Solution

  • You can do N partial (prefix) sums in log2(N) iterations using something like this:

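    // offsets: local int array, one slot per work item (the question's
    // segment_element_counts); total_num_segments: number of work items.
    offsets[get_local_id(0)] = count;
    barrier(CLK_LOCAL_MEM_FENCE);

    // Round k (k = 0, 1, 2, ...) uses combine = 2^k: it merges pairs of groups
    // of `combine` items into groups of 2 * combine items.
    for (ushort combine = 1; combine < total_num_segments; combine *= 2)
    {
        // Only items in the upper half of their group do any work this round.
        if (get_local_id(0) & combine)
        {
            // Add the running total of the last item in the lower half.
            offsets[get_local_id(0)] +=
                offsets[(get_local_id(0) & ~(combine * 2u - 1u)) | (combine - 1u)];
        }
        // All items must finish this round before any item starts the next.
        barrier(CLK_LOCAL_MEM_FENCE);
    }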
    

    Given segment element counts of

    a     b     c        d
    

    The successive iterations will produce:

    a     b+a   c        d+c
    

    and

    a     b+a   c+(b+a)  (d+c)+(b+a)
    

    Which is the result we want.

    So in the first iteration, we divide the segment element counts into groups of 2 and sum within each group. Then we merge 2 groups at a time into groups of 4 elements, and propagate the total of the first group into each element of the second. We grow the groups again to 8, and so on.
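The rounds can be checked on the CPU with a single-threaded C sketch that keeps the kernel's index arithmetic but plays the work items with an inner loop. It uses `unsigned` instead of `ushort` for simplicity, and `log_step_inclusive_scan` is a hypothetical name. Because each round only reads indices whose `combine` bit is clear and only writes indices whose bit is set, the order in which the inner loop visits the work items does not matter, which is what lets a sequential loop stand in for the barrier-separated rounds.

```c
#include <assert.h>

/* Single-threaded model of the answer's loop. Each pass of the outer loop
   is one round between barriers; the inner loop plays the work items. */
static void log_step_inclusive_scan(int *offsets, unsigned n)
{
    for (unsigned combine = 1; combine < n; combine *= 2) {
        for (unsigned lid = 0; lid < n; ++lid) {      /* lid = get_local_id(0) */
            if (lid & combine) {
                unsigned src = (lid & ~(combine * 2u - 1u)) | (combine - 1u);
                /* src is never written in this round (its combine bit is clear),
                   so visiting work items in any order gives the same result. */
                assert((src & combine) == 0 && src < lid);
                offsets[lid] += offsets[src];
            }
        }
        /* barrier(CLK_LOCAL_MEM_FENCE) separates the rounds in the kernel */
    }
}
```

For counts {1, 3, 1, 2}, the first round gives {1, 4, 1, 3} and the second {1, 4, 5, 7}: the a, b+a, c+(b+a), (d+c)+(b+a) pattern with a = 1, b = 3, c = 1, d = 2. The loop also works when the number of work items is not a power of two.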

    The key observation is that this pattern also matches the binary representation of the index of each segment:

    0: 0b00  1: 0b01  2: 0b10  3: 0b11
    

    Index 0 performs no sums. Both indices 1 and 3 perform a sum in the first iteration (bit 0/LSB = 1), whereas indices 2 and 3 perform a sum in the second iteration (bit 1 = 1). That explains this line:

        if (get_local_id(0) & combine)
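To see the same bit pattern for more than four work items, this hypothetical `active_mask` helper (plain C, limited to 32 work items) returns a bit set of the local IDs that take the branch in a given round.

```c
#include <assert.h>
#include <stdint.h>

/* Bit set of the local IDs 0..n-1 (n <= 32) that take the
   `if (get_local_id(0) & combine)` branch in the round for `combine`. */
static uint32_t active_mask(unsigned n, unsigned combine)
{
    assert(n <= 32);
    uint32_t mask = 0;
    for (unsigned lid = 0; lid < n; ++lid)
        if (lid & combine)
            mask |= UINT32_C(1) << lid;
    return mask;
}
```

For 8 work items, the rounds with combine = 1, 2 and 4 give 0xAA, 0xCC and 0xF0: in every round half of the work items add while the other half wait at the barrier.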
    

    The other statement that really needs an explanation is of course

            offsets[get_local_id(0)] +=
                offsets[(get_local_id(0) & ~(combine * 2u - 1u)) | (combine - 1u)];
    

    Calculating the index at which we find the previous prefix sum we want to accumulate onto our work item's sum is a little tricky. The subexpression (combine * 2u - 1u) takes the value 2^n - 1 on the n-th iteration (for n starting at 1):

    1 = 0b001
    3 = 0b011
    7 = 0b111
    …
    

    Masking these low bits off the work-item index (i.e. i & ~x) gives you the index of the first item in the current group of 2 * combine items.

    The (combine - 1u) subexpression then gives you the index, within the current group, of the last item of the first half. Because the group's first index has all of those low bits cleared, OR-ing the two together is the same as adding them, and gives you the overall index of the item you want to accumulate into the current segment.
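The index arithmetic can be spelled out one step at a time in a small C function. `source_index` is a hypothetical name, `combine` is assumed to be a power of two as in the kernel, and the assertion checks that OR-ing in `combine - 1u` is the same as adding it.

```c
#include <assert.h>

/* Index whose running total work item `lid` adds in the round for `combine`
   (a power of two), computed exactly as in the kernel but one step at a time. */
static unsigned source_index(unsigned lid, unsigned combine)
{
    unsigned suffix_mask = combine * 2u - 1u;    /* 1, 3, 7, ...: 2^n - 1 in round n */
    unsigned group_first = lid & ~suffix_mask;   /* first item of lid's group of 2*combine */
    unsigned src = group_first | (combine - 1u); /* last item of the group's first half */
    /* group_first has those low bits cleared, so OR-ing is the same as adding. */
    assert(src == group_first + (combine - 1u));
    return src;
}
```

For example, work item 13 (0b1101) in the round with combine = 4 belongs to the group 8..15, whose first half ends at 8 | 3 = 11, so it adds offsets[11].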

    There is one slight ugliness in the result: it is an inclusive scan, so it is shifted to the left by one. Segment 1 needs to use offsets[0], and so on, while segment 0's offset is of course 0. You can either over-allocate the offsets array by 1, perform the prefix sums on the subarray starting at index 1 and initialise index 0 to 0, or use a conditional.
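Both fixes can be compared in a C sketch. A plain sequential `inclusive_scan` stands in for the kernel loop here, on the assumption that any correct inclusive scan produces the same values; all function names are hypothetical. A third variant, not part of the answer above, subtracts each work item's own count from its inclusive sum, which needs neither extra storage nor a branch.

```c
#include <assert.h>
#include <stddef.h>

/* Stand-in for the kernel's log-step loop: any correct inclusive scan yields
   the same values, so a sequential one keeps this sketch short. */
static void inclusive_scan(int *v, size_t n)
{
    for (size_t i = 1; i < n; ++i)
        v[i] += v[i - 1];
}

/* Option 1, a conditional: segment lid starts where segment lid - 1 ends. */
static int position_conditional(const int *inclusive, size_t lid)
{
    return lid == 0 ? 0 : inclusive[lid - 1];
}

/* Option 2, over-allocation: buf has n + 1 slots, buf[0] stays 0 and the
   counts are scanned in buf[1..n], so buf[lid] is segment lid's position. */
static void positions_overallocated(const int *counts, size_t n, int *buf)
{
    buf[0] = 0;
    for (size_t i = 0; i < n; ++i)
        buf[i + 1] = counts[i];
    inclusive_scan(buf + 1, n);
    assert(buf[0] == 0);   /* the scan never touches slot 0 */
}

/* Option 3 (not from the answer): subtract the work item's own count. */
static int position_subtract(const int *inclusive, const int *counts, size_t lid)
{
    return inclusive[lid] - counts[lid];
}
```

All three give the same start positions (0, 1, 4, 5, 7, 11, 12, 13, 13 for the example counts); which is cheapest on a given device is the kind of detail profiling would settle.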

    There are probably profiling-driven micro-optimisations you can make to the above code.