Tags: c++, x86, simd, intrinsics, avx2

How to implement an efficient _mm256_madd_epi8: dot products of groups of four i8 elements?


Intel provides a C-style intrinsic named _mm256_madd_epi16, which basically does the following:

__m256i _mm256_madd_epi16 (__m256i a, __m256i b)

Multiply packed signed 16-bit integers in a and b, producing intermediate signed 32-bit integers. Horizontally add adjacent pairs of intermediate 32-bit integers, and pack the results in dst.
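To make that description concrete, here is a scalar model of what _mm256_madd_epi16 computes, next to the real intrinsic for comparison. This is a sketch; the function names are mine, and `__attribute__((target("avx2")))` is a GCC/Clang extension used only so the file builds without -mavx2 (it still needs an AVX2 CPU at runtime).

```cpp
#include <immintrin.h>
#include <cstdint>

// Scalar model: out[i] = a[2i]*b[2i] + a[2i+1]*b[2i+1], each product computed in 32 bits.
void madd_epi16_scalar(const int16_t a[16], const int16_t b[16], int32_t out[8]) {
    for (int i = 0; i < 8; ++i) {
        // Each product fits in int32_t; only the sum can overflow, and only when
        // all four inputs are -32768. Adding as uint32_t wraps like the instruction.
        uint32_t p0 = uint32_t(int32_t(a[2 * i]) * int32_t(b[2 * i]));
        uint32_t p1 = uint32_t(int32_t(a[2 * i + 1]) * int32_t(b[2 * i + 1]));
        out[i] = int32_t(p0 + p1);  // two's-complement wrap (GCC/Clang; guaranteed in C++20)
    }
}

// The real instruction (vpmaddwd) on the same data.
__attribute__((target("avx2")))
void madd_epi16_avx2(const int16_t a[16], const int16_t b[16], int32_t out[8]) {
    __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a));
    __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), _mm256_madd_epi16(va, vb));
}
```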

Now I have two __m256i variables, each holding 32 8-bit integers.

I want to implement the same functionality as _mm256_madd_epi16, except that each int32_t element in the result __m256i is the sum of four products of signed char instead of two products of signed int16_t: a dot product of the four int8_t elements within each 32-bit chunk.

I can do that in a scalar loop:

  alignas(32) uint32_t res[8] = {0};
  for (int i = 0; i < 32; ++i)
      res[i / 4] += _mm256_extract_epi8(a, i) * _mm256_extract_epi8(b, i);
  return _mm256_load_si256((__m256i*)res);
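For testing any of the SIMD versions, a portable scalar reference that works on plain arrays may be handier than the extract-based loop, since it doesn't depend on the extract helper's signedness. This is a sketch with a name of my choosing; if the data starts out in __m256i variables, storing them with _mm256_storeu_si256 first gives the same byte order.

```cpp
#include <cstdint>

// out[i] = a[4i]*b[4i] + ... + a[4i+3]*b[4i+3], each int8 product sign-extended to int.
// Four products of magnitude at most 128*128 = 16384 can't overflow int32_t.
void dot4_i8_scalar(const int8_t a[32], const int8_t b[32], int32_t out[8]) {
    for (int i = 0; i < 8; ++i) {
        int32_t sum = 0;
        for (int k = 0; k < 4; ++k)
            sum += int32_t(a[4 * i + k]) * int32_t(b[4 * i + k]);
        out[i] = sum;
    }
}
```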

Note that each product is sign-extended to int before adding, and that the _mm256_extract_epi8 helper function (see footnote 1) returns a signed __int8. Never mind that the total is uint32_t instead of int32_t; it can't overflow anyway, with only four 8x8 => 16-bit products to add.

It looks very ugly, and it doesn't run efficiently unless the compiler works some magic to do it with SIMD instead of compiling it as written, to scalar extractions.


Footnote 1: _mm256_extract_epi8 is not a single-instruction intrinsic. vpextrb only works on 128-bit vectors (the low lane of a 256-bit vector), and this helper function may allow an index that isn't a compile-time constant.


Solution

  • pmaddubsw: usable if at least one input is non-negative (and thus can be treated as unsigned)

    If one of your inputs is known to always be non-negative, you can use it as the unsigned input to pmaddubsw, the 8->16-bit equivalent of pmaddwd. It adds pairs of u8*i8 -> i16 products, with signed saturation to 16 bits. But saturation is impossible when the unsigned input is at most 127 instead of 255. (127*-128 = -0x3f80, so twice that still fits in i16.)

    After pmaddubsw, use pmaddwd against _mm256_set1_epi16(1) to hsum pairs of elements with correct handling of the signs. (This is usually more efficient than manually sign-extending the 16-bit elements to 32 to add them.)

    __m256i sum16 = _mm256_maddubs_epi16(a, b);   // pmaddubsw: a is the unsigned (non-negative) operand, b is signed
    __m256i sum32 = _mm256_madd_epi16(sum16, _mm256_set1_epi16(1)); // pmaddwd
    

    (pmaddwd for horizontal 16=>32-bit sums of pairs within 4-byte elements has higher latency on some CPUs than shift / and / add, but it treats both inputs as signed, so it sign-extends correctly to 32-bit. And it's only a single uop, so it's good for throughput, especially if the surrounding code doesn't bottleneck on the same execution ports.)
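A self-contained sketch of that two-instruction version, assuming every byte of `a` is in 0..127 (it is the first, unsigned, operand of _mm256_maddubs_epi16). The names are mine, and the GCC/Clang target attribute only avoids needing -mavx2.

```cpp
#include <immintrin.h>
#include <cstdint>

// Requires every byte of a to be in [0, 127]; b may be any int8.
__attribute__((target("avx2")))
__m256i dotprod_nonneg_i8_to_i32(__m256i a, __m256i b) {
    __m256i sum16 = _mm256_maddubs_epi16(a, b);             // u8*i8 pairs -> i16, cannot saturate here
    return _mm256_madd_epi16(sum16, _mm256_set1_epi16(1));  // i16 pairs -> i32, sign handled correctly
}

// Array wrapper so the result can be compared with a scalar loop.
__attribute__((target("avx2")))
void dot4_nonneg_avx2(const int8_t a[32], const int8_t b[32], int32_t out[8]) {
    __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a));
    __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), dotprod_nonneg_i8_to_i32(va, vb));
}
```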


    General case (both inputs may be negative)

    A recent answer on AVX-512BW emulation of the AVX-512VNNI instruction _mm512_dpbusd_epi32 came up with a good trick: splitting one input into its MSB and its low 7 bits, so vpmaddubsw (_mm256_maddubs_epi16) can be used without overflow. We can borrow that trick and negate while hsumming, because the place value of the MSB is -2^7 rather than the 2^7 that the unsigned input of vpmaddubsw treats it as.

    // Untested.  __m128i version would need SSSE3
    __m256i dotprod_i8_to_i32(__m256i v1, __m256i v2)
    {
        const __m256i highest_bit = _mm256_set1_epi8(0x80);
    
        __m256i msb = _mm256_maddubs_epi16(_mm256_and_si256(v1, highest_bit), v2);     // 0 or 2^7
        __m256i low7 = _mm256_maddubs_epi16(_mm256_andnot_si256(highest_bit, v1), v2);
    
        low7 = _mm256_madd_epi16(low7, _mm256_set1_epi16(1));  // hsum i16 pairs to i32
        msb  = _mm256_madd_epi16(msb,  _mm256_set1_epi16(1));
        return _mm256_sub_epi32(low7, msb);  // place value of the MSB was negative
    
       // equivalent to the below, but that needs an extra constant
    //    msb = _mm256_madd_epi16(msb,  _mm256_set1_epi16(-1));   // the place-value was actually - 2^7
    //    return _mm256_add_epi32(low7, msb);
    
       // also equivalent to vpmaddwd with -1 for both parts
       // return sub(msb, low7)
       // which is cheaper because set1(-1) is just vpcmpeqd not a load.
    }
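Since the function above is marked untested, one way to check it is to compare it against a scalar loop for every pair of byte values, plus the all -128 case that produces the extreme i16 sums. This is a sketch: the function body is copied from above (with -128 written instead of 0x80 to avoid a narrowing warning), the helper names are mine, and the GCC/Clang target attribute only avoids needing -mavx2.

```cpp
#include <immintrin.h>
#include <cstdint>

__attribute__((target("avx2")))
__m256i dotprod_i8_to_i32(__m256i v1, __m256i v2) {
    const __m256i highest_bit = _mm256_set1_epi8(-128);  // 0x80 in every byte
    __m256i msb  = _mm256_maddubs_epi16(_mm256_and_si256(v1, highest_bit), v2);
    __m256i low7 = _mm256_maddubs_epi16(_mm256_andnot_si256(highest_bit, v1), v2);
    low7 = _mm256_madd_epi16(low7, _mm256_set1_epi16(1));
    msb  = _mm256_madd_epi16(msb,  _mm256_set1_epi16(1));
    return _mm256_sub_epi32(low7, msb);  // MSB place value is -2^7
}

// Compares one vector's worth of results with a scalar loop.
__attribute__((target("avx2")))
bool matches_scalar(const int8_t a[32], const int8_t b[32]) {
    __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a));
    __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b));
    int32_t got[8];
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(got), dotprod_i8_to_i32(va, vb));
    for (int i = 0; i < 8; ++i) {
        int32_t want = 0;
        for (int k = 0; k < 4; ++k) want += int32_t(a[4 * i + k]) * int32_t(b[4 * i + k]);
        if (got[i] != want) return false;
    }
    return true;
}

// Every (a, b) byte-value pair, then the all -128 worst case.
bool exhaustive_pair_test() {
    int8_t a[32], b[32];
    for (int x = -128; x < 128; ++x)
        for (int base = 0; base < 256; base += 32) {
            for (int i = 0; i < 32; ++i) { a[i] = int8_t(x); b[i] = int8_t(base + i - 128); }
            if (!matches_scalar(a, b)) return false;
        }
    for (int i = 0; i < 32; ++i) { a[i] = -128; b[i] = -128; }
    return matches_scalar(a, b);
}
```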
    

    This avoids signed saturation: the max multiplier for one side is 128 (the MSB being set and treated as unsigned). 128 * -128 = -16384, twice that is -32768 = -0x8000 = bit-pattern 0x8000. Or 128 * 127 * 2 = 0x7f00 as the highest positive result.
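The same worst-case arithmetic (and the 127*-128 bound from the first section) can be restated as compile-time checks. This is only a restatement of the numbers above; the constant names are mine.

```cpp
#include <cstdint>

// Worst-case vpmaddubsw results: each i16 is the sum of two u8*i8 products.
// Unsigned operand: 0 or 128 for the MSB part, 0..127 for the low-7-bit part.
// Signed operand: any int8, -128..127.
constexpr int msb_min  = 2 * 128 * -128;  // -32768
constexpr int msb_max  = 2 * 128 * 127;   //  32512
constexpr int low7_min = 2 * 127 * -128;  // -32512
constexpr int low7_max = 2 * 127 * 127;   //  32258

static_assert(127 * -128 == -0x3f80, "one product with a 0..127 unsigned input");
static_assert(msb_min == INT16_MIN, "reaches, but does not pass, the bottom of i16");
static_assert(msb_max == 0x7f00 && msb_max <= INT16_MAX, "no positive saturation");
static_assert(low7_min >= INT16_MIN && low7_max <= INT16_MAX, "low-7-bit part never saturates");
// A full u8 input (up to 255) is different: 2 * 255 * -128 = -65280 would saturate.
static_assert(2 * 255 * -128 < INT16_MIN, "why arbitrary u8 data can saturate");
```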

    This is 7 uops (4 of them for the multiply units) vs. 9 uops (4 arithmetic shifts + 2 byte shifts + 2 multiplies + 1 add) for the version below.

    AVX-512VNNI _mm256_dpbusd_epi32 (or the 512-bit version), or AVX_VNNI _mm256_dpbusd_avx_epi32 (both are VPDPBUSD), is like vpmaddubsw (u8*i8 products) but adds to an existing sum, summing the 4 products within each 32-bit element in a single instruction (i32 += four u8 * i8). The same split trick works, _mm256_sub_epi32(low7_prods, msb_prods), but we can skip the madd_epi16 (vpmaddwd) i16 to i32 horizontal-sum steps.

    (Other VNNI instructions include vpdpbusds, the same as vpdpbusd but with signed saturation instead of wrapping. Either way the saturation is to i32, not i16 like vpmaddubsw, so it can only saturate if the accumulator input is non-zero. If one input is non-negative and so can be treated as unsigned, this does the whole job in one instruction without splitting. And there's vpdpwssd[s], a MAC of signed words with or without saturation: like vpmaddwd but with an accumulator operand.)

    // Ice Lake (AVX-512 version only) or Alder Lake (AVX_VNNI), or Zen 4
    __m256i dotprod_i8_to_i32_vnni(__m256i v1, __m256i v2)
    {
        const __m256i highest_bit = _mm256_set1_epi8(0x80);
        __m256i msb = _mm256_and_si256(v1, highest_bit);
        __m256i low7 = _mm256_andnot_si256(highest_bit, v1);
    
       // or just _mm256_dpbusd_epi32 for the EVEX version
        msb = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), msb, v2);     // 0 or 2^7
        low7 = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), low7, v2);
    
        return _mm256_sub_epi32(low7, msb);  // place value of the MSB was negative
    }
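Without VNNI hardware, the split trick can still be checked against a scalar model of what non-saturating vpdpbusd does in one 32-bit element. This is a sketch: `dpbusd_model` is a hypothetical helper of mine that models the instruction's documented behaviour (accumulator plus four u8*i8 products, wrapping), not a real API.

```cpp
#include <cstdint>

// Hypothetical model of one 32-bit element of vpdpbusd (no saturation):
// acc + u[0]*s[0] + ... + u[3]*s[3], with wrapping 32-bit addition.
int32_t dpbusd_model(int32_t acc, const uint8_t u[4], const int8_t s[4]) {
    uint32_t sum = uint32_t(acc);
    for (int k = 0; k < 4; ++k)
        sum += uint32_t(int32_t(u[k]) * int32_t(s[k]));
    return int32_t(sum);
}

// Signed x signed dot product of 4 bytes via the MSB / low-7 split used above.
int32_t dot4_i8_via_split(const int8_t v1[4], const int8_t v2[4]) {
    uint8_t msb[4], low7[4];
    for (int k = 0; k < 4; ++k) {
        uint8_t bits = uint8_t(v1[k]);
        msb[k]  = bits & 0x80;  // 0 or 128: the sign bit, read as +2^7 by the unsigned operand
        low7[k] = bits & 0x7f;  // 0..127
    }
    // The MSB's real place value is -2^7, so its products are subtracted.
    return dpbusd_model(0, low7, v2) - dpbusd_model(0, msb, v2);
}
```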
    

    AVX-512 without AVX-512VNNI can use the AVX2 version unchanged, or widened to 512. Or it might be possible to apply the sign bit by turning it into a mask (vptestmb) and zeroing some bytes of the input (zero-masked vpmovdqu8) for horizontal sums of 4-byte chunks into 32-bit elements (vdbpsadbw against zero with an identity shuffle-control). But no, that doesn't sign-extend the 8-bit inputs before adding them, since it sums unsigned absolute differences. Perhaps with a range-shift to unsigned first (e.g. a zero-masked xor with 0x80), then adding 4*128? Anyway, then msb = _mm256_slli_epi32(dword_hsums_of_input_b, 7) would be used the same way the code above uses its msb variable. If this even works, I don't know whether it saves uops. Feedback welcome, or post an AVX-512BW answer.


    The other way: unpacking and sign-extending to 16-bit

    The obvious solution would be to unpack your input bytes to 16-bit elements with zero or sign-extension. Then you can use pmaddwd twice, and add the results.

    If your inputs are coming from memory, loading them with vpmovsxbw might make sense. e.g.

    __m256i a = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)&arr1[i]));
    __m256i b = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)&arr2[i]));
    

    But now you have the 4 bytes that you want spread out across two dwords, so you'd have to shuffle the result of one _mm256_madd_epi16(a,b). You could maybe use vphaddd to shuffle and add two 256-bit vectors of products into one 256-bit vector of results you want, but that's a lot of shuffling.
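For comparison, a sketch of that vpmovsxbw + vpmaddwd + vphaddd route. Because vphaddd works within 128-bit lanes, the groups come out as 0,1,4,5 | 2,3,6,7, so an extra vpermq is needed to restore their order, which is part of why this is a lot of shuffling. The function name is mine, and the GCC/Clang target attribute only avoids needing -mavx2.

```cpp
#include <immintrin.h>
#include <cstdint>

__attribute__((target("avx2")))
void dot4_i8_sx16_hadd(const int8_t* pa, const int8_t* pb, int32_t* out) {
    // vpmovsxbw: sign-extend bytes 0..15 and 16..31 of each input to 16-bit words.
    __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128(reinterpret_cast<const __m128i*>(pa)));
    __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128(reinterpret_cast<const __m128i*>(pa + 16)));
    __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128(reinterpret_cast<const __m128i*>(pb)));
    __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128(reinterpret_cast<const __m128i*>(pb + 16)));

    // vpmaddwd: each dword holds 2 products, i.e. half of one 4-byte group.
    __m256i p_lo = _mm256_madd_epi16(a_lo, b_lo);  // halves of groups 0..3
    __m256i p_hi = _mm256_madd_epi16(a_hi, b_hi);  // halves of groups 4..7

    // vphaddd adds adjacent dwords within each 128-bit lane: groups 0,1,4,5 | 2,3,6,7.
    __m256i h = _mm256_hadd_epi32(p_lo, p_hi);
    // vpermq swaps the middle two qwords, giving groups 0..7 in order.
    h = _mm256_permute4x64_epi64(h, _MM_SHUFFLE(3, 1, 2, 0));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), h);
}
```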

    So instead, I think we want to generate two 256-bit vectors from each 256-bit input vector: one with the high byte of each word sign-extended to 16 bits, and the other with the low byte sign-extended. We can do that with 3 shifts for each input:

     __m256i a = _mm256_loadu_si256((const __m256i*)&arr1[i]);
     __m256i b = _mm256_loadu_si256((const __m256i*)&arr2[i]);
    
     __m256i a_high = _mm256_srai_epi16(a, 8);     // arithmetic right shift sign extends
         // some compilers may only know the less-descriptive _mm256_slli_si256 name for vpslldq
     __m256i a_low =  _mm256_bslli_epi128(a, 1);   // left 1 byte = low to high in each 16-bit element
             a_low =  _mm256_srai_epi16(a_low, 8); // arithmetic right shift sign extends
    
        // then the same for b_high / b_low
     __m256i b_high = _mm256_srai_epi16(b, 8);
     __m256i b_low =  _mm256_bslli_epi128(b, 1);
             b_low =  _mm256_srai_epi16(b_low, 8);
    
     __m256i prod_hi = _mm256_madd_epi16(a_high, b_high);
     __m256i prod_lo = _mm256_madd_epi16(a_low, b_low);
    
     __m256i quadsum = _mm256_add_epi32(prod_lo, prod_hi);
    

    As an alternative to vpslldq by 1 byte, vpsllw by 8 bits (__m256i a_low = _mm256_slli_epi16(a, 8);) is the more "obvious" way to shift low to high within each word, and may be better if the surrounding code bottlenecks on shuffles. But normally it's worse, because this code heavily bottlenecks on shifts + vec-int multiplies.
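Put together as one self-contained function, the shift-based version might look like the sketch below. It spells the byte shift as _mm256_slli_si256 (the same vpslldq instruction as _mm256_bslli_epi128) since that name is the more widely supported one; the function name is mine, and the GCC/Clang target attribute only avoids needing -mavx2.

```cpp
#include <immintrin.h>
#include <cstdint>

__attribute__((target("avx2")))
void dot4_i8_shifts(const int8_t* pa, const int8_t* pb, int32_t* out) {
    __m256i a = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(pa));
    __m256i b = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(pb));

    // High byte of each word, sign-extended: arithmetic right shift by 8.
    __m256i a_high = _mm256_srai_epi16(a, 8);
    __m256i b_high = _mm256_srai_epi16(b, 8);
    // Low byte: vpslldq by 1 moves it into the high half of its word (within each
    // 128-bit lane), then the arithmetic shift sign-extends it back down.
    __m256i a_low = _mm256_srai_epi16(_mm256_slli_si256(a, 1), 8);
    __m256i b_low = _mm256_srai_epi16(_mm256_slli_si256(b, 1), 8);

    __m256i prod_hi = _mm256_madd_epi16(a_high, b_high);  // bytes 4i+1 and 4i+3
    __m256i prod_lo = _mm256_madd_epi16(a_low, b_low);    // bytes 4i and 4i+2
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), _mm256_add_epi32(prod_lo, prod_hi));
}
```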

    On KNL, you could use AVX512 vprold z,z,i (Agner Fog doesn't show a timing for AVX512 vpslld z,z,i) because it doesn't matter what you shift or shuffle into the low byte of each word; this is just setup for an arithmetic right shift.

    Execution port bottlenecks:

    Haswell runs vector shifts and vector-integer multiply only on port 0, so this badly bottlenecks on that. (Skylake is better: p0/p1). http://agner.org/optimize/.

    We can use a shuffle (port 5) instead of the left shift as setup for an arithmetic right shift. This improves throughput and even reduces latency by reducing resource conflicts.

    But we can avoid the shuffle control vector by using vpslldq to do a vector byte shift. It's still an in-lane shuffle (shifting in zeros at the end of each lane), so it still has single-cycle latency. (My first idea was vpshufb with a control vector like 14,14, 12,12, 10,10, ..., then vpalignr, then I remembered that simple old pslldq has an AVX2 version. There are two names for the same instruction. I like _mm256_bslli_epi128 because the b for byte-shift distinguishes it as a shuffle, unlike the within-element bit-shifts. I didn't check which compiler supports what name for the 128-bit or 256-bit versions of the intrinsic.)

    This also helps on AMD Zen 1. Vector shifts only run on one execution unit (P2), but shuffles can run on P1 or P2.

    I haven't looked at AMD Ryzen execution port conflicts, but I'm pretty sure this won't be worse on any CPU (except KNL Xeon Phi, where AVX2 ops on elements smaller than a dword are all super slow). Shifts and in-lane shuffles are the same number of uops and same latency.

    If any elements are known non-negative, sign-extend = zero-extend

    (Or better, use pmaddubsw as shown in the first section.)

    Zero-extending is cheaper than manually sign-extending, and avoids port bottlenecks. a_low and/or b_low can be created with _mm256_and_si256(a, _mm256_set1_epi16(0x00ff)).

    a_high and/or b_high can be created with a shuffle instead of shift. (pshufb zeros the element when the shuffle-control vector has its high bit set).

     const __m256i pshufb_emulate_srl8 = _mm256_set_epi8(
                   0x80,15, 0x80,13, 0x80,11, ...,
                   0x80,15, 0x80,13, 0x80,11, ...);
    
     __m256i a_high = _mm256_shuffle_epi8(a, pshufb_emulate_srl8);  // zero-extend
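A complete sketch of the zero-extension path, assuming for simplicity that every byte of both inputs is in 0..127 (if only one input is, the other would still need the shift-based sign-extension). Here I wrote the shuffle control out in full with _mm256_setr_epi8, which lists bytes from lowest to highest; the function name is mine, and the GCC/Clang target attribute only avoids needing -mavx2.

```cpp
#include <immintrin.h>
#include <cstdint>

// Requires all bytes of both inputs to be in [0, 127], so zero-extension == sign-extension.
__attribute__((target("avx2")))
void dot4_nonneg_zext(const int8_t* pa, const int8_t* pb, int32_t* out) {
    const char Z = -128;  // control byte with the high bit set: vpshufb writes 0
    // Per 128-bit lane: result byte 2k = source byte 2k+1, byte 2k+1 = 0,
    // i.e. a logical right shift of each 16-bit word by 8.
    const __m256i pshufb_emulate_srl8 = _mm256_setr_epi8(
        1, Z, 3, Z, 5, Z, 7, Z, 9, Z, 11, Z, 13, Z, 15, Z,
        1, Z, 3, Z, 5, Z, 7, Z, 9, Z, 11, Z, 13, Z, 15, Z);
    const __m256i low_byte_mask = _mm256_set1_epi16(0x00ff);

    __m256i a = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(pa));
    __m256i b = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(pb));

    __m256i a_low  = _mm256_and_si256(a, low_byte_mask);         // vpand: zero-extend low bytes
    __m256i b_low  = _mm256_and_si256(b, low_byte_mask);
    __m256i a_high = _mm256_shuffle_epi8(a, pshufb_emulate_srl8);  // vpshufb: zero-extend high bytes
    __m256i b_high = _mm256_shuffle_epi8(b, pshufb_emulate_srl8);

    __m256i sum = _mm256_add_epi32(_mm256_madd_epi16(a_low, b_low),
                                   _mm256_madd_epi16(a_high, b_high));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), sum);
}
```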
    

    Shuffle throughput is also limited to 1 per clock on mainstream Intel, so you could bottleneck on shuffles if you go overboard. But at least it's not the same port as the multiply. If only the high bytes are known non-negative, replacing vpsraw with vpshufb could help. Unaligned loads offset by one byte, so that those high bytes become low bytes, could be even more helpful, setting up for vpand to create a_low and/or b_low.