Tags: bit-manipulation, simd, avx, avx2, lz77

AVX2 code to find the first longest match of 4-byte string among 8 4-byte targets


I need the fastest (i.e. branchless, fewest uops) AVX2 code equivalent to this pseudocode:

prevlen = 0
for i=0..7:
  len = matched_bytes(target[i], src)
  if len > prevlen:
    prevlen = len
    index = i

where target[i] and src are 4-byte strings, and matched_bytes returns 0..4, the number of equal low-order bytes (counting up from byte 0 until the first mismatch):

def matched_bytes(target, src):
  return tzcnt(target ^ src) / 8   # tzcnt(0) = 32, so equal strings give 4
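A scalar C reference for the loop above can serve as a test oracle for any SIMD version. This is a minimal sketch. The function name `first_longest_match` and the convention of returning -1 when nothing matches are assumptions, since the pseudocode leaves `index` unset in that case.

```c
#include <stdint.h>

/* Number of equal low-order bytes (0..4) between target and src.
   Equivalent to tzcnt(target ^ src) / 8, where tzcnt(0) == 32 gives 4. */
static int matched_bytes(uint32_t target, uint32_t src)
{
    uint32_t x = target ^ src;
    int n = 0;
    while (n < 4 && (x & 0xFFu) == 0) {  /* stop at the first mismatching byte */
        x >>= 8;
        n++;
    }
    return n;
}

/* Scalar version of the loop. The strict '>' keeps the first target among
   those with the longest match. Returns -1 (assumed convention) if no target
   matches even its lowest byte; *len_out receives the best length. */
static int first_longest_match(const uint32_t target[8], uint32_t src, int *len_out)
{
    int prevlen = 0;
    int index = -1;
    for (int i = 0; i < 8; i++) {
        int len = matched_bytes(target[i], src);
        if (len > prevlen) {
            prevlen = len;
            index = i;
        }
    }
    *len_out = prevlen;
    return index;
}
```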

The code below takes 15 instructions. I can live without the best length; for some of my use cases the index alone is enough.

Can it be done in fewer instructions? I care less about latency or uneven ALU port usage, since it's part of a larger piece of code.

// bit 4*i+k is set if byte k (0-based) of target i equals byte k of src
flags = pmovmskb( pcmpeqb( broadcast(src), targets))

// in byte_eqK, bit 4*i is set if byte K-1 of target i is equal
byte_eq1 = flags
byte_eq2 = flags >> 1
byte_eq3 = flags >> 2
byte_eq4 = flags >> 3

// in lenK, bit 4*i is set if at least K leading bytes of target i are equal
len1 = byte_eq1 & 0x11111111
len2 = len1 & byte_eq2
len3 = len2 & byte_eq3
len4 = len3 & byte_eq4

// One CMOV each, placed right after the corresponding lenK assignment (interleaved with the block above)
if(len2==0) len2 = len1
if(len3==0) len3 = len2
if(len4==0) len4 = len3

index = tzcnt(len4) / 4   // lowest set bit = first target with the longest match
// if len4==0 then no match was found
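This sequence can be modeled in portable C to check its bit logic without AVX2 hardware. A sketch, assuming the hypothetical helper `build_flags` reproduces the `pmovmskb(pcmpeqb(broadcast(src), targets))` mask bit for bit. GCC/Clang's `__builtin_ctz` stands in for `tzcnt`.

```c
#include <stdint.h>

/* Hypothetical scalar stand-in for pmovmskb(pcmpeqb(broadcast(src), targets)):
   bit 4*i+k is set when byte k of target i equals byte k of src. */
static uint32_t build_flags(const uint32_t target[8], uint32_t src)
{
    uint32_t flags = 0;
    for (int i = 0; i < 8; i++)
        for (int k = 0; k < 4; k++)
            if ((((target[i] ^ src) >> (8 * k)) & 0xFFu) == 0)
                flags |= 1u << (4 * i + k);
    return flags;
}

/* Returns the index of the first longest match, or -1 if no target matches
   even its lowest byte. */
static int mask_cmov_index(const uint32_t target[8], uint32_t src)
{
    uint32_t flags = build_flags(target, src);
    uint32_t len1 = flags & 0x11111111u;   /* bit 4*i: byte 0 equal      */
    uint32_t len2 = len1 & (flags >> 1);   /* bit 4*i: bytes 0..1 equal  */
    uint32_t len3 = len2 & (flags >> 2);   /* bit 4*i: bytes 0..2 equal  */
    uint32_t len4 = len3 & (flags >> 3);   /* bit 4*i: all 4 bytes equal */

    /* The CMOV chain: fall back to the longest non-empty level. */
    if (len2 == 0) len2 = len1;
    if (len3 == 0) len3 = len2;
    if (len4 == 0) len4 = len3;

    if (len4 == 0)
        return -1;                         /* __builtin_ctz(0) is undefined */
    return __builtin_ctz(len4) / 4;        /* lowest set bit = lowest index */
}
```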

Solution

  • Here are two strategies:
    A. Pack the compare mask down to 16 bits, then use phminposuw.
    B. Transpose bits from pmovmskb such that tzcnt yields the index and length of the best match.

    Method B is probably better. However, it requires extra loads from memory for the shuffle control indices.

    Both methods use 'trailing bit manipulation' on the comparison result (the byte-compare mask in A, the XOR difference in B) to ignore everything after the first mismatch.
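    The two trailing-bit identities can be written as scalar C. The names follow AMD's TBM mnemonics TZMSK and T1MSKC, but these are plain C helpers written for illustration, not those instructions. Method B applies the first to each 32-bit `target ^ src`; Method A applies the second to each 16-bit packed lane.

```c
#include <stdint.h>

/* tzmsk(x) = ~x & (x - 1): ones exactly in the trailing-zero positions of x.
   tzmsk(0) is all ones. */
static uint32_t tzmsk32(uint32_t x) { return ~x & (x - 1u); }

/* t1mskc(x) = ~x | (x + 1): zeros exactly in the trailing-one positions of x,
   ones everywhere else, computed in 16 bits as Method A does per lane. */
static uint16_t t1mskc16(uint16_t x) { return (uint16_t)(~x | (x + 1u)); }
```

    For a 32-bit XOR difference, bit 8k+7 of tzmsk is set exactly when bytes 0..k all matched, so the sign bit of each byte already encodes "at least k+1 bytes equal".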


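    #include <smmintrin.h> // SSE4.1 intrinsics (compile with -msse4.1)
    #include <stdint.h>
    #include <stdio.h>
    
    void method_A (const uint32_t* arr, uint32_t val) {
        const __m128i neg1 = _mm_set1_epi32(-1);
        __m128i search_value = _mm_set1_epi32(val);
        __m128i row0 = _mm_loadu_si128((const __m128i*)&arr[0]);  // targets 0..3
        __m128i row1 = _mm_loadu_si128((const __m128i*)&arr[4]);  // targets 4..7
        // 0xff for each equal byte, 0x00 otherwise
        __m128i matched_bytes0 = _mm_cmpeq_epi8(row0, search_value);
        __m128i matched_bytes1 = _mm_cmpeq_epi8(row1, search_value);
        // Saturate each byte pair to one byte: 16-bit lane i now describes target i
        __m128i packed = _mm_packs_epi16(matched_bytes0, matched_bytes1);
        // t1mskc = ~x | (x + 1) clears the trailing ones: longer match -> smaller value
        __m128i t1mskc = _mm_or_si128(_mm_xor_si128(packed, neg1), _mm_sub_epi16(packed, neg1));
        // Minimum value in bits 0..15, index of its first occurrence in bits 16..18
        __m128i best_match = _mm_minpos_epu16(t1mskc);
        uint32_t match_desc = (uint32_t)_mm_cvtsi128_si32(best_match);
        uint32_t match_index = match_desc >> 16;
        // 0x0000, 0x8000, 0xff00, 0xff80, 0xffff have 16, 15, 8, 7, 0 trailing zeros
        // (the OR caps the count at 16), giving lengths 4, 3, 2, 1, 0
        uint32_t match_length = __builtin_ctz(match_desc | 0x10000) >> 2;
    
        printf("index: %u, length: %u\n", match_index, match_length);
    }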
    

    Each 16-bit input element of the packing step is a pair of byte compare results, each 0 or -1. The pair is interpreted as a signed 16-bit integer and saturated to the signed 8-bit range, -128 (0x80) to +127 (0x7f).

    input    vpacksswb result
    0xffff      0xff   (-1)
    0xff00      0x80   (large negative: saturates)
    0x00ff      0x7f   (large positive: saturates)
    0x0000      0x00   (0)
    

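    This step preserves the ordering when the result byte is interpreted as unsigned. What matters for the next step is the run of trailing one bits: the result byte has 8 of them when both bytes matched (0xff), 7 when only the lower byte matched (0x7f), and none when the lower byte mismatched (0x80 or 0x00).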
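    A scalar model of the saturation can check the table and the trailing-ones property. A sketch: `packs16to8` and `trailing_ones8` are hypothetical helpers that mimic what packsswb does to one 16-bit element.

```c
#include <stdint.h>

/* Signed saturation of one 16-bit element to 8 bits, as packsswb does. */
static uint8_t packs16to8(uint16_t in)
{
    int32_t v = (int16_t)in;   /* reinterpret the pair as a signed 16-bit value */
    if (v > 127) v = 127;
    if (v < -128) v = -128;
    return (uint8_t)v;         /* two's-complement byte: -128 -> 0x80 */
}

/* Number of trailing one bits in a byte. */
static int trailing_ones8(uint8_t b)
{
    int n = 0;
    while (n < 8 && ((b >> n) & 1u))
        n++;
    return n;
}
```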

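    The next step computes t1mskc, i.e. ~x | (x + 1), in each 16-bit lane. This clears the trailing one bits and sets every other bit. The result prepares the input for phminposuw: a 4-byte match maps to the smallest unsigned 16-bit value (0x0000) and no match maps to the largest (0xffff). The other three lengths fall in order between them (0x8000, 0xff00, 0xff80 for 3, 2, 1 matched bytes). On ties, phminposuw reports the lowest index, which gives the first longest match.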
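    A scalar model can verify the resulting ordering exhaustively. In this sketch, the hypothetical helper `lane_key` mirrors one lane of the cmpeq, packsswb and t1mskc steps for a given set of equal bytes. `lane_length` repeats Method A's `ctz(... | 0x10000) >> 2` decode using GCC/Clang's `__builtin_ctz`.

```c
#include <stdint.h>

/* Signed saturation of one 16-bit element to a byte, as packsswb does. */
static uint8_t sat8(uint16_t in)
{
    int32_t v = (int16_t)in;          /* reinterpret as signed 16-bit */
    if (v > 127) v = 127;
    if (v < -128) v = -128;
    return (uint8_t)v;
}

/* Hypothetical helper: the phminposuw key for one target, given which of its
   4 bytes compared equal (bit k of eq set = byte k equal). */
static uint16_t lane_key(unsigned eq)
{
    /* The two 16-bit compare pairs: bytes 0-1 and bytes 2-3. */
    uint16_t lo = (uint16_t)(((eq & 1u) ? 0x00FFu : 0u) | ((eq & 2u) ? 0xFF00u : 0u));
    uint16_t hi = (uint16_t)(((eq & 4u) ? 0x00FFu : 0u) | ((eq & 8u) ? 0xFF00u : 0u));
    uint16_t x = (uint16_t)(sat8(lo) | (sat8(hi) << 8));   /* packed lane */
    return (uint16_t)(~x | (x + 1u));                      /* t1mskc      */
}

/* Method A's length decode; the OR caps the count at 16 for key 0. */
static int lane_length(unsigned eq)
{
    return __builtin_ctz((uint32_t)lane_key(eq) | 0x10000u) >> 2;
}
```

    Bytes after the first mismatch do not affect the key, so the decode is correct for all 16 equality patterns, not only the five "clean" prefixes.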


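    #include <immintrin.h> // AVX2 + BMI1 intrinsics (compile with -mavx2 -mbmi)
    #include <stdint.h>
    #include <stdio.h>
    
    void method_B (const uint32_t* arr, uint32_t val) {
        // Per 128-bit lane: gather byte 3 of the 4 dwords into dword 0,
        // byte 2 into dword 1, byte 1 into dword 2 and byte 0 into dword 3
        const __m256i shuf_bytes = _mm256_set_epi8(
            12,8,4,0, 13,9,5,1, 14,10,6,2, 15,11,7,3,
            12,8,4,0, 13,9,5,1, 14,10,6,2, 15,11,7,3
        );
        // Interleave the two lanes so equal byte positions of all 8 targets are adjacent
        const __m256i shuf_dwords = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
    
        __m256i vec = _mm256_loadu_si256((const __m256i*)arr);
        __m256i diff = _mm256_xor_si256(vec, _mm256_set1_epi32(val));
        // tzmsk = ~diff & (diff - 1): the sign bit of byte k is set iff bytes 0..k all match
        __m256i tzmsk = _mm256_andnot_si256(diff, _mm256_add_epi32(_mm256_set1_epi32(-1), diff));
    
        __m256i t0 = _mm256_shuffle_epi8(tzmsk, shuf_bytes);
        __m256i t1 = _mm256_permutevar8x32_epi32(t0, shuf_dwords);
        // bits 0..7: 4 bytes equal, 8..15: >= 3, 16..23: >= 2, 24..31: >= 1 (bit & 7 = target)
        uint32_t mask = (uint32_t)_mm256_movemask_epi8(t1);
    
        uint32_t n = (uint32_t)_tzcnt_u32(mask);  // 32 if mask == 0, giving length 0
        size_t len = 4 - (n >> 3);
        size_t idx = n & 7;
    
        printf("index: %d, length: %d\n", (int)idx, (int)len);
    }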
    

    The bytes are shuffled so that a single tzcnt yields both the index and the length of the best match. Writing dword_i for dword i of tzmsk, the movemask bits are:

    bit_0 = dword_0 : bit_31
    bit_1 = dword_1 : bit_31
    bit_2 = dword_2 : bit_31
    ...
    bit_7 = dword_7 : bit_31
    bit_8 = dword_0 : bit_23
    bit_9 = dword_1 : bit_23
    ...
    bit_30 = dword_6: bit_7
    bit_31 = dword_7: bit_7
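    The same layout can be modeled bit by bit in scalar C. In this sketch, `method_b_model` is a hypothetical name and `tz[i]` is assumed to hold tzmsk(target[i] ^ src). The explicit zero check makes `__builtin_ctz` return 32 for an empty mask, matching tzcnt.

```c
#include <stdint.h>

/* Ones in the trailing-zero positions of x (all ones for x == 0). */
static uint32_t tzmsk32(uint32_t x) { return ~x & (x - 1u); }

/* Hypothetical scalar model of Method B after the load/xor/tzmsk steps.
   Mask bit 8*r + i takes bit (31 - 8*r) of tz[i], so rows r = 0..3 hold
   "matched at least 4, 3, 2, 1 bytes" for targets i = 0..7. */
static void method_b_model(const uint32_t tz[8], int *idx, int *len)
{
    uint32_t mask = 0;
    for (int r = 0; r < 4; r++)
        for (int i = 0; i < 8; i++)
            if ((tz[i] >> (31 - 8 * r)) & 1u)
                mask |= 1u << (8 * r + i);

    /* tzcnt semantics: 32 for an all-zero mask (__builtin_ctz(0) is undefined) */
    unsigned n = mask ? (unsigned)__builtin_ctz(mask) : 32u;
    *len = 4 - (int)(n >> 3);
    *idx = (int)(n & 7u);
}
```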