Tags: simd, avx, bitmask, avx2, prefix-sum

AVX2 vectorization for code similar to prefix sum (decrement by count of preceding matches in short fixed-length arrays)


I have some performance-critical code that looks like this:

uint8_t v[128], w[128];

for (int i = 0; i < 128; ++i) {
    // ...
    for (int j = 0; j < 128; ++j) {
        // ...
        if (v[j] == w[i]) {
            // ...
            for (int k = j; k < 128; ++k) {
                v[k]--;
            }
            break; // <-- this is critical
        }
    }
}
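A runnable scalar version of just the loop structure shown above can be handy as a test oracle for any vectorized rewrite. This is a minimal sketch: the elided `// ...` work is left out, and the function name `decrement_after_first_match_ref` is hypothetical.

```c
#include <stdint.h>

/* Scalar reference: for each w[i], find the FIRST j with v[j] == w[i]
   and decrement v[j..127]. The question's elided "// ..." work is omitted. */
void decrement_after_first_match_ref(uint8_t v[128], const uint8_t w[128])
{
    for (int i = 0; i < 128; ++i) {
        for (int j = 0; j < 128; ++j) {
            if (v[j] == w[i]) {
                for (int k = j; k < 128; ++k)
                    v[k]--;          /* uint8_t arithmetic wraps 0 -> 255 */
                break;               /* only the first match counts */
            }
        }
    }
}
```

With v[j] = j and w[0] = w[1] = 5, the first pass decrements v[5..127], so v[5] becomes 4 and v[6] becomes 5; the second pass then matches at j = 6, not j = 5. That is why the break changes results rather than just saving time.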

Note that the break is critical there, not just an early-exit optimization. If I take it out, my tests fail.

I'm trying to speed this code up using AVX2.

I first load v into an array of 4 AVX2 registers. The if (v[j] == w[i]) line is easily replaced by _mm256_cmpeq_epi8, with w[i] broadcast to all lanes.

By breaking up the j-loop into 4 loops of 32 iterations each, I know exactly which element of the AVX2 array I am currently on. Assume, e.g., a hit on the if for some j with 32 <= j < 64, which corresponds to the AVX2 array element with index 1. For the innermost (k) loop, I thought of using _mm256_movemask_epi8, then the bsf instruction to extract the lowest set bit (call it b), followed by one of the techniques here to do an "inverse _mm256_movemask_epi8". Then I can just subtract the result from v[1]; for v[2] and v[3] I just subtract _mm256_set1_epi8(1).
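The plan above (cmpeq, movemask, bsf/tzcnt, inverse movemask, then decrement the later registers) can be sketched for one w[i] value as follows. Assumptions: GCC/Clang, with the hypothetical `AVX2_FN` macro wrapping `__attribute__((target("avx2")))` so no `-mavx2` flag is needed; `__builtin_ctz` stands in for bsf; the inverse movemask is the common broadcast/vpshufb/OR/vpcmpeqb trick, not necessarily the one behind the question's link. Because that expansion yields 0xFF (-1) per selected byte, the sketch adds it instead of subtracting. The names `inverse_movemask_epi8` and `step_one_w_branchy` are illustrative.

```c
#include <immintrin.h>
#include <stdint.h>

/* GCC/Clang: compile these functions for AVX2 (or build with -mavx2 and drop this). */
#define AVX2_FN __attribute__((target("avx2")))

/* "Inverse movemask": byte i of the result is 0xFF if bit i of mask is set, else 0x00. */
static inline AVX2_FN __m256i inverse_movemask_epi8(uint32_t mask)
{
    __m256i vmask = _mm256_set1_epi32((int)mask);
    /* Spread mask byte 0 to bytes 0-7, byte 1 to 8-15, byte 2 to 16-23, byte 3 to 24-31. */
    const __m256i shuf = _mm256_setr_epi64x(0x0000000000000000LL, 0x0101010101010101LL,
                                            0x0202020202020202LL, 0x0303030303030303LL);
    vmask = _mm256_shuffle_epi8(vmask, shuf);
    /* Within each qword, byte k has every bit set except bit k. */
    const __m256i bit_select = _mm256_set1_epi64x(0x7fbfdfeff7fbfdfeLL);
    vmask = _mm256_or_si256(vmask, bit_select);
    /* A byte is all-ones only if its selected bit was set in mask. */
    return _mm256_cmpeq_epi8(vmask, _mm256_set1_epi8(-1));
}

/* One outer iteration for a single value wi: find the first j with v[j] == wi
   and decrement v[j..127]. Branchy, following the "4 loops of 32" plan. */
AVX2_FN void step_one_w_branchy(uint8_t v[128], uint8_t wi)
{
    __m256i vv[4];
    for (int c = 0; c < 4; ++c)
        vv[c] = _mm256_loadu_si256((const __m256i *)(v + 32 * c));

    const __m256i needle = _mm256_set1_epi8((char)wi);
    for (int c = 0; c < 4; ++c) {
        uint32_t m = (uint32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(vv[c], needle));
        if (m != 0) {
            unsigned b = (unsigned)__builtin_ctz(m);   /* lowest set bit, like bsf/tzcnt */
            /* 0xFF (-1) in lanes >= b; adding -1 decrements exactly those lanes */
            vv[c] = _mm256_add_epi8(vv[c], inverse_movemask_epi8(~0u << b));
            /* every lane of the later registers */
            for (int d = c + 1; d < 4; ++d)
                vv[d] = _mm256_sub_epi8(vv[d], _mm256_set1_epi8(1));
            break;   /* only the first match counts */
        }
    }
    for (int c = 0; c < 4; ++c)
        _mm256_storeu_si256((__m256i *)(v + 32 * c), vv[c]);
}
```

For self-containment this loads and stores v[] on every call; in the real loop you would keep the four vectors in registers across all 128 values of w[i].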

Another option is to do a prefix sum over the comparison result. Since multiple lanes may match in the comparison, I could then do _mm256_cmpgt_epi8 against zero after the prefix sum to select every lane at or after the first match.
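For a single 32-byte register, the prefix-sum variant might look like the sketch below. It assumes GCC/Clang (`AVX2_FN` is the same hypothetical wrapper around `target("avx2")`), and the helper names are illustrative. AVX2 byte shifts only work within 128-bit lanes, so the scan is done per lane and the low lane's total is then carried into the high lane.

```c
#include <immintrin.h>
#include <stdint.h>

#define AVX2_FN __attribute__((target("avx2")))

/* Inclusive prefix sum of the 32 bytes of x (mod 256). */
static inline AVX2_FN __m256i prefix_sum_epi8(__m256i x)
{
    /* log-step scan inside each 128-bit lane (byte shifts are per-lane in AVX2) */
    x = _mm256_add_epi8(x, _mm256_slli_si256(x, 1));
    x = _mm256_add_epi8(x, _mm256_slli_si256(x, 2));
    x = _mm256_add_epi8(x, _mm256_slli_si256(x, 4));
    x = _mm256_add_epi8(x, _mm256_slli_si256(x, 8));
    /* broadcast each lane's last byte (its lane total) within that lane */
    __m256i lane_totals = _mm256_shuffle_epi8(x, _mm256_set1_epi8(15));
    /* imm 0x08: zero the low lane, put the low lane's totals in the high lane */
    __m256i carry = _mm256_permute2x128_si256(lane_totals, lane_totals, 0x08);
    return _mm256_add_epi8(x, carry);
}

/* Decrement every byte of chunk at or after the first byte equal to wi. */
AVX2_FN void decrement_from_first_match_32(uint8_t chunk[32], uint8_t wi)
{
    __m256i v = _mm256_loadu_si256((const __m256i *)chunk);
    __m256i eq = _mm256_cmpeq_epi8(v, _mm256_set1_epi8((char)wi));    /* 0 or -1 */
    __m256i ones = _mm256_sub_epi8(_mm256_setzero_si256(), eq);       /* 0 or 1 */
    __m256i psum = prefix_sum_epi8(ones);          /* matches at or before each lane, <= 32 */
    __m256i at_or_after = _mm256_cmpgt_epi8(psum, _mm256_setzero_si256()); /* -1 where > 0 */
    _mm256_storeu_si256((__m256i *)chunk, _mm256_add_epi8(v, at_or_after)); /* += -1 */
}
```

Counts never exceed 32 within one register, so the signed byte compare against zero is safe.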

Although I haven't worked out the details, I intend to avoid the explicit branch (the break in the innermost loop body), since a branchless version is probably cheaper than the inevitable pipeline flushes from mispredictions, given that the branches will be completely random.

Still, I can't shake the feeling that there must be a simpler solution. For instance, for the "inverse _mm256_movemask_epi8" solution, I don't have a general mask, but only 33 possible values out of 2^32: 0x0, 0x1, 0x3, 0x7, 0xF, 0x1F, 0x3F, ..., 0xFFFFFFFF.
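Because the masks listed are only of the form (1 << n) - 1, one common way to build the byte-vector form without any general inverse movemask is to compare a constant index vector against a broadcast count. This is a sketch of that alternative (it is not from the question), assuming GCC/Clang and the hypothetical `AVX2_FN` wrapper; `lanes_below` and `store_lanes_below` are illustrative names.

```c
#include <immintrin.h>
#include <stdint.h>

#define AVX2_FN __attribute__((target("avx2")))

/* Byte i of the result is 0xFF if i < n, else 0x00, for n in 0..32:
   the vector form of the bitmask (1 << n) - 1 (0x0, 0x1, 0x3, ..., 0xFFFFFFFF). */
static inline AVX2_FN __m256i lanes_below(int n)
{
    const __m256i idx = _mm256_setr_epi8(0, 1, 2, 3, 4, 5, 6, 7,
                                         8, 9, 10, 11, 12, 13, 14, 15,
                                         16, 17, 18, 19, 20, 21, 22, 23,
                                         24, 25, 26, 27, 28, 29, 30, 31);
    /* signed byte compare: n and the indices all fit in 0..32 */
    return _mm256_cmpgt_epi8(_mm256_set1_epi8((char)n), idx);
}

/* Store the mask so it can be inspected from scalar code. */
AVX2_FN void store_lanes_below(uint8_t out[32], int n)
{
    _mm256_storeu_si256((__m256i *)out, lanes_below(n));
}
```

For the "lane >= b" mask the decrement actually needs, swap the operands: `_mm256_cmpgt_epi8(idx, _mm256_set1_epi8((char)(b - 1)))`, which is all-ones for b = 0 and all-zeros for b = 32.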

Can anyone think of a different, ideally simpler, strategy?


Solution

  • One w[i] at a time: find first match and decrement all later v[j]

    (This strategy didn't end up being very efficient because of the loop-carried dependency chain through compare / tzcnt / generating vectors of 0 or -1 / vpaddb, even when keeping v[] in four __m256i vectors to avoid store/reload. See the OP's comments below.)

    Your idea of doing it this way avoids correctness problems with needing to change v[j] before comparing, which are a showstopper for most of my earlier ideas.

    Yes, with four vpcmpeqb / vpmovmskb we can get a 128-bit bitmap, which we search with tzcnt. You probably don't want to use a __m128i load of that to find the first non-zero byte of the 128-bit bitmap; instead, just use two 64-bit chunks with a cmov (C ?: ternary) to select between _tzcnt_u64(low_half) and 64 + _tzcnt_u64(high_half). You'll need shift/OR to combine pairs of uint32_t mask0 = _mm256_movemask_epi8(...) results, like uint64_t mask_low = mask0 | ((uint64_t)mask1 << 32);
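The combine-and-select step can be written portably. In this sketch, `__builtin_ctzll` (GCC/Clang) stands in for `_tzcnt_u64`. Unlike tzcnt, it is undefined for 0, so each half is guarded. `first_match_index` is a hypothetical name; its inputs are the four 32-bit movemask results for bytes 0-31, 32-63, 64-95 and 96-127.

```c
#include <stdint.h>

/* Return the index (0..127) of the first set bit across four 32-bit masks,
   or 128 if none is set. */
static inline unsigned first_match_index(uint32_t mask0, uint32_t mask1,
                                          uint32_t mask2, uint32_t mask3)
{
    uint64_t mask_low  = mask0 | ((uint64_t)mask1 << 32);
    uint64_t mask_high = mask2 | ((uint64_t)mask3 << 32);
    /* __builtin_ctzll(0) is undefined; _tzcnt_u64(0) would return 64 */
    unsigned tz_low  = mask_low  ? (unsigned)__builtin_ctzll(mask_low)  : 64;
    unsigned tz_high = mask_high ? (unsigned)__builtin_ctzll(mask_high) : 64;
    return mask_low ? tz_low : 64 + tz_high;   /* cmov-friendly select */
}
```

With BMI1 enabled, the two guarded ternaries can collapse into plain `_tzcnt_u64` calls, leaving only the final select.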

    Making this fully branchless (or branching the same way every time) would require extending v[] to uint8_t v[128 + 96] so it's safe to write 3 vectors past the end; then we can unconditionally do store(add(load, set1(-1))) for the 3 vectors after the one containing the first match.

    To get a vector that's -1 in every byte at or above a given index, we can load from a sliding window onto an array uint8_t maskbuf[] = {0,0,0,0,...,-1,-1,-1,-1}, as in "What is the best way to loop AVX for un-even non-aligned array?" and "Vectorizing with unaligned buffers: using VMASKMOVPS: generating a mask from a misalignment count? Or not using that insn at all".
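Putting the padded array and the sliding-window mask together, one w[i] step might look like this sketch. Assumptions: GCC/Clang with the hypothetical `AVX2_FN` wrapper, and v[] allocated with 128 + 96 bytes. The no-match case isn't spelled out above; here it is mapped to an empty mask in chunk 3, so only padding gets decremented. The ternaries are meant to compile to cmov, but that is up to the compiler. `step_one_w_branchless` is an illustrative name.

```c
#include <immintrin.h>
#include <stdint.h>

#define AVX2_FN __attribute__((target("avx2")))

/* 32 zero bytes then 32 0xFF bytes: a 32-byte load from maskbuf + 32 - b
   is 0xFF exactly in lanes >= b, for b in 0..32. */
static const uint8_t maskbuf[64] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
};

/* v must have room for 128 + 96 bytes: the 3 chunks after the matching
   chunk are decremented unconditionally, even when that runs into padding. */
AVX2_FN void step_one_w_branchless(uint8_t *v, uint8_t wi)
{
    const __m256i needle = _mm256_set1_epi8((char)wi);
    uint32_t m[4];
    for (int c = 0; c < 4; ++c) {
        __m256i x = _mm256_loadu_si256((const __m256i *)(v + 32 * c));
        m[c] = (uint32_t)_mm256_movemask_epi8(_mm256_cmpeq_epi8(x, needle));
    }
    uint64_t lo = m[0] | ((uint64_t)m[1] << 32);
    uint64_t hi = m[2] | ((uint64_t)m[3] << 32);
    unsigned tz_lo = lo ? (unsigned)__builtin_ctzll(lo) : 64;
    unsigned tz_hi = hi ? (unsigned)__builtin_ctzll(hi) : 64;
    unsigned p = lo ? tz_lo : 64 + tz_hi;            /* 128 means no match */

    /* No match: empty mask in chunk 3, so only v[128..223] gets decremented. */
    unsigned c = p < 128 ? p >> 5 : 3;
    unsigned b = p < 128 ? p & 31 : 32;

    uint8_t *chunk = v + 32 * c;
    __m256i cur = _mm256_loadu_si256((const __m256i *)chunk);
    __m256i at_or_above = _mm256_loadu_si256((const __m256i *)(maskbuf + 32 - b));
    _mm256_storeu_si256((__m256i *)chunk, _mm256_add_epi8(cur, at_or_above));

    const __m256i minus1 = _mm256_set1_epi8(-1);
    for (int d = 1; d <= 3; ++d) {                   /* may touch v[128..223] */
        uint8_t *q = chunk + 32 * d;
        __m256i x = _mm256_loadu_si256((const __m256i *)q);
        _mm256_storeu_si256((__m256i *)q, _mm256_add_epi8(x, minus1));
    }
}
```

Every load and store starts at a multiple of 32 bytes from v, so next iteration's loads line up with this iteration's stores and store forwarding can work. The padding bytes end up holding garbage and must be ignored.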

    Another option, instead of making a vector mask, is to just do four unaligned vectors starting at the first match, which needs uint8_t v[128 + 127]. But that will perform terribly, with store-forwarding stalls when our vector loads in the next iteration don't line up with our stores from this iteration. (What are the costs of failed store-to-load forwarding on x86?)


    Ideas that don't account for v[j] maybe changing before we compare

    (I wrote the rest of this answer before realizing that a v[j] can change in one outer iteration before we compare it again in the next outer iteration. I'm posting it anyway in case some of these ideas are salvageable.)

    These are based on the idea that we decrement every element of v[] by the number of elements in w[] that match any element of v[] before or at this element.

    Thinking first about how to work across 32-byte chunks: for every element beyond the first 32, the decrement count for v[j] (j=32..127) due only to the first block is the number of w[0..127] elements that matched any element of v[0..31].

    pcmpestrm can do an any-against-any compare between 16-byte vectors. (https://www.strchr.com/strcmp_and_strlen_using_sse_4.2 has a good summary of what it can do in general, even though the article is about strcmp and strlen, where these instructions are worse than SSE2, let alone AVX2.)

    We can OR together the masks. (Or, if 0 doesn't occur in the arrays, pcmpistrm can be used, which is significantly faster: 3 uops instead of 9 on Intel / 7 on Zen 4, per https://uops.info/ . The throughput ratios aren't that extreme, e.g. 5c vs. 3c on Intel or 3c vs. 2c on Zen 4, since the extra uops can run on other ports. But we do have some other work to do in parallel, so it's not great.)
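For the count described above (how many of w[0..127] match any of v[0..31]), a pcmpestrm sketch might look like this. It uses `_mm_cmpestrm` in "equal any" mode. The first operand is the character set (a 16-byte v chunk), and bit k of the mask refers to byte k of the second operand (a w chunk). Explicit lengths of 16 mean zero bytes are compared like any other value. Assumptions: GCC/Clang `target("sse4.2,popcnt")`; `count_w_matching_v_prefix` and `EQ_ANY_MODE` are illustrative names.

```c
#include <nmmintrin.h>   /* SSE4.2: _mm_cmpestrm, _SIDD_* */
#include <stdint.h>

#define SSE42_FN __attribute__((target("sse4.2,popcnt")))
/* must be a compile-time constant for the intrinsic */
#define EQ_ANY_MODE (_SIDD_UBYTE_OPS | _SIDD_CMP_EQUAL_ANY | _SIDD_BIT_MASK)

/* Count how many of w[0..127] equal at least one of v[0..31]. */
SSE42_FN unsigned count_w_matching_v_prefix(const uint8_t v[32], const uint8_t w[128])
{
    __m128i v0 = _mm_loadu_si128((const __m128i *)v);
    __m128i v1 = _mm_loadu_si128((const __m128i *)(v + 16));
    unsigned count = 0;
    for (int c = 0; c < 8; ++c) {
        __m128i wc = _mm_loadu_si128((const __m128i *)(w + 16 * c));
        /* OR the masks for both 16-byte halves of v[0..31] */
        __m128i m = _mm_or_si128(_mm_cmpestrm(v0, 16, wc, 16, EQ_ANY_MODE),
                                 _mm_cmpestrm(v1, 16, wc, 16, EQ_ANY_MODE));
        /* bit mask lives in the low 16 bits of the result */
        count += (unsigned)__builtin_popcount((unsigned)_mm_cvtsi128_si32(m) & 0xFFFFu);
    }
    return count;
}
```

If 0 never occurs in the data, `_mm_cmpistrm(v0, wc, EQ_ANY_MODE)` is the implicit-length equivalent mentioned above.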

    pcmpestrm might be better than brute-force comparing every rotation of 16-byte lanes (vpshufb) on normal and swapped copies of a __m256i: for example, two chains of vpshufb that rotate 16-byte lanes by 1, one on a vector from memory and one on the same vector with its 16-byte halves swapped, so most of the shuffles are in-lane. So you'd be running a lot of vpshufb / vpcmpeqb / vpor. The result can be prefix-summed with some shuffles; note that vpsadbw against _mm256_setzero_si256() will hsum every chunk of 8 bytes into the bottom of its qword.
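As a small illustration of the vpsadbw trick, this sketch turns a compare result into 0/1 bytes and lets `_mm256_sad_epu8` against zero sum each 8-byte group into its qword. Assumptions: GCC/Clang with the hypothetical `AVX2_FN` wrapper; `match_counts_per_qword` is an illustrative name.

```c
#include <immintrin.h>
#include <stdint.h>

#define AVX2_FN __attribute__((target("avx2")))

/* out[q] = number of bytes in chunk[8q .. 8q+7] equal to wi. */
AVX2_FN void match_counts_per_qword(uint64_t out[4], const uint8_t chunk[32], uint8_t wi)
{
    __m256i v = _mm256_loadu_si256((const __m256i *)chunk);
    __m256i eq = _mm256_cmpeq_epi8(v, _mm256_set1_epi8((char)wi));   /* 0x00 or 0xFF */
    __m256i ones = _mm256_and_si256(eq, _mm256_set1_epi8(1));        /* 0 or 1 */
    /* vpsadbw vs. zero: |x - 0| summed over each 8 bytes, result in each qword */
    __m256i sums = _mm256_sad_epu8(ones, _mm256_setzero_si256());
    _mm256_storeu_si256((__m256i *)out, sums);
}
```

Masking to 0/1 first matters: summing raw 0xFF bytes would give multiples of 255 instead of counts.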

    Algorithmic improvements, not just brute-force?

    (Also not considering that each v[j] can be different on different outer iterations.)

    Non-brute-force options include sort + binary search, or checking a bitmap or bool array, if we're considering something as slow as checking every v[j] separately against w[0..127].
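A 256-bit bitmap of the byte values present in w[] makes "does v[j] match any w[i]?" an O(1) check per element. This is a minimal portable sketch; `byteset`, `byteset_build` and `byteset_contains` are hypothetical names.

```c
#include <stdbool.h>
#include <stdint.h>

/* One bit per possible byte value. */
typedef struct { uint64_t bits[4]; } byteset;

static void byteset_build(byteset *s, const uint8_t *w, int n)
{
    s->bits[0] = s->bits[1] = s->bits[2] = s->bits[3] = 0;
    for (int i = 0; i < n; ++i)
        s->bits[w[i] >> 6] |= (uint64_t)1 << (w[i] & 63);
}

static bool byteset_contains(const byteset *s, uint8_t x)
{
    return (s->bits[x >> 6] >> (x & 63)) & 1;
}
```

Like the other ideas in this part, it assumes v[j] doesn't change between checks, which is exactly what the real problem violates.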

    Or make a histogram, uint8_t wcounts[256];, then do a one-pass prefix sum (which you use as a decrement count), like decrement += wcounts[ v[j] ]; v[j] -= decrement;. Histogramming isn't very SIMD-friendly, and w[] is so short that you might not want to use separate arrays of counts to hide store-forwarding latency bottlenecks on repeated elements. But if you do, SIMD is great for vertically summing count0[i] += count1[i]: just 256 bytes to process, worth it if w[] often contains a lot of the same value.
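The histogram-plus-prefix-sum idea, written out literally as a scalar sketch (`histogram_decrement` is a hypothetical name). Each v[j] is read before it is modified, so the count lookup uses the original value.

```c
#include <stddef.h>
#include <stdint.h>

/* v[j] -= running sum of wcounts[original v[0..j]] */
void histogram_decrement(uint8_t *v, size_t nv, const uint8_t *w, size_t nw)
{
    uint8_t wcounts[256] = {0};          /* fine for nw <= 255 */
    for (size_t i = 0; i < nw; ++i)
        wcounts[w[i]]++;

    uint8_t decrement = 0;
    for (size_t j = 0; j < nv; ++j) {
        decrement += wcounts[v[j]];      /* one-pass prefix sum of counts */
        v[j] -= decrement;
    }
}
```

For v = {5, 6, 7, 8} and w = {6, 6, 8}, this gives {5, 4, 5, 5}. The question's loop gives {5, 5, 5, 6} for the same input, because after the first decrement the second 6 matches a different element and the 8 no longer matches at all. That is the failure described in the next paragraph.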

    But this doesn't work because v[j] might have changed by 1 while looking at w[0] and now match or not match different elements.

    There might still be something useful we can do to reduce the need for SIMD brute-force searching, instead using reverse lookup tables that we update on the fly. But I'm not seeing anything obvious.