
Fastest way to implement _mm256_mullo_epi4 using AVX2


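For a research problem I need a very efficient implementation of 4-bit multiplication (only the low 4 bits of each product are needed) using AVX/AVX2 instructions. Each 16-bit lane holds four packed 4-bit values, and the nibbles of a and b are multiplied position by position.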

My current approach is:

#include <immintrin.h>

__m256i _mm256_mullo_epi4(const __m256i a, const __m256i b) {
    // Keeps the low nibble of each 16-bit lane.
    __m256i mask_f_0 = _mm256_set1_epi16(0x000f);
    // Move nibble k of a and b down to bits 0-3, multiply, keep the low 4 bits.
    __m256i tmp_mul_0 = _mm256_and_si256(_mm256_mullo_epi16(a, b), mask_f_0);
    __m256i tmp_mul_1 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a,   4), _mm256_srli_epi16(b,   4)), mask_f_0);
    __m256i tmp_mul_2 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a,   8), _mm256_srli_epi16(b,   8)), mask_f_0);
    __m256i tmp_mul_3 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a,  12), _mm256_srli_epi16(b,  12)), mask_f_0);
    // Shift each 4-bit product back to its position and combine (the fields are disjoint).
    __m256i tmp1 = _mm256_xor_si256(tmp_mul_0, _mm256_slli_epi16(tmp_mul_1, 4));
    __m256i tmp2 = _mm256_xor_si256(tmp1, _mm256_slli_epi16(tmp_mul_2, 8));
    __m256i tmp  = _mm256_xor_si256(tmp2, _mm256_slli_epi16(tmp_mul_3, 12));
    return tmp;
}
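To check any faster variant, it helps to pin down exactly what the function computes per 16-bit lane. The following scalar reference is my own sketch (the names mullo_epi4_ref and mullo_epi4_ref_examples are hypothetical, not from the question). It assumes, as the code above does, that each lane holds four independent nibbles and that each result nibble is the low 4 bits of the product of the corresponding input nibbles.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical scalar reference for one 16-bit lane: multiply the four nibbles
   of a and b position by position and keep the low 4 bits of each product. */
static uint16_t mullo_epi4_ref(uint16_t a, uint16_t b)
{
    uint16_t r = 0;
    for (int shift = 0; shift < 16; shift += 4) {
        unsigned na = (a >> shift) & 0xFu;
        unsigned nb = (b >> shift) & 0xFu;
        r |= (uint16_t)(((na * nb) & 0xFu) << shift);
    }
    return r;
}

/* Worked examples, computed by hand. */
static void mullo_epi4_ref_examples(void)
{
    assert(mullo_epi4_ref(0x1234, 0x1111) == 0x1234); /* x * 1 = x in every nibble */
    assert(mullo_epi4_ref(0xFFFF, 0xFFFF) == 0x1111); /* 15 * 15 = 225 = 0xE1 -> 1 */
    assert(mullo_epi4_ref(0x4321, 0x5678) == 0x42E8); /* 20, 18, 14, 8 -> 4, 2, E, 8 */
    assert(mullo_epi4_ref(0x8888, 0x2222) == 0x0000); /* 8 * 2 = 16 -> 0 */
}
```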

This implementation uses the relatively expensive _mm256_mullo_epi16 instruction four times, computing each 4-bit limb separately. Can this be done faster? More precisely: is it possible to reduce the number of instructions needed?


Solution

    I don't see an obvious way to reduce the number of multiplications, e.g., by masking the inputs so that a single multiplication yields two separate products. Even vpmaddubsw is difficult to exploit, since it treats one operand as signed 8-bit values (and getting the nibbles into the correct positions would require a lot of shifting).
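A quick scalar experiment illustrates why the simplest packing does not work. The sketch below uses hypothetical helper names and models what vpmullw keeps, namely the low 16 bits of each product. It places two nibbles per operand at bits 0-3 and 8-11 and multiplies once. Expanding the product gives a0*b0 + ((a0*b2 + a2*b0) << 8) + (a2*b2 << 16): the cross terms land in bits 8-11, while the wanted a2*b2 term lands in bits 16-23 and is discarded.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical experiment: pack nibbles a0, a2 (and b0, b2) at bits 0-3 and 8-11
   of one 16-bit value, multiply once, truncate to 16 bits as vpmullw does,
   and return the nibble that ends up in bits 8-11. */
static unsigned packed_nibble2(unsigned a0, unsigned a2, unsigned b0, unsigned b2)
{
    uint16_t a = (uint16_t)(a0 | (a2 << 8));
    uint16_t b = (uint16_t)(b0 | (b2 << 8));
    uint16_t p = (uint16_t)((uint32_t)a * b); /* low 16 bits only */
    return (p >> 8) & 0xFu;
}

static void packed_nibble2_examples(void)
{
    /* Wanted a2*b2 = 1, but bits 8-11 hold the cross terms a0*b2 + a2*b0 = 2. */
    assert(packed_nibble2(1, 1, 1, 1) == 2);
    /* Wanted 3*5 = 15, but a2*b2 lands in bits 16-23 and is dropped. */
    assert(packed_nibble2(0, 3, 0, 5) == 0);
}
```

This only rules out the most naive layout; it does not prove that no cleverer packing exists.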

    You can, however, reduce the amount of shifting at the cost of a few more masking operations:

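    Pseudo code (nibbles of each 16-bit lane, most significant first; ab is the low 4 bits of the product of the corresponding nibbles, * marks don't-care bits):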

    (a*b) & 0xf           = 0,0,0,ab
    (a>>4)*(b&0xf0)       = *,*,ab,0
    (a>>8)*(b&0xf00)      = *,ab,0,0
    (a>>12)*(b&0xf000)    = ab,0,0,0
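The pseudo code can be modeled on a single 16-bit lane in plain C, which makes it easy to check against a scalar reference before touching intrinsics. This is a sketch with my own hypothetical names (mul16lo, mullo_epi4_masked_scalar, mullo_epi4_ref, check_masked_against_ref); the uint16_t truncation mimics vpmullw, and the random inputs come from a simple linear congruential generator chosen only for illustration.

```c
#include <assert.h>
#include <stdint.h>

/* Low 16 bits of a 16x16-bit product, like one lane of vpmullw. */
static uint16_t mul16lo(uint16_t x, uint16_t y)
{
    return (uint16_t)((uint32_t)x * y);
}

/* One lane of the masking trick from the pseudo code. */
static uint16_t mullo_epi4_masked_scalar(uint16_t a, uint16_t b)
{
    uint16_t p0 = mul16lo(a, b) & 0x000F;               /* 0,0,0,ab */
    uint16_t p1 = mul16lo(a >> 4, b & 0x00F0) & 0x00F0; /* *,*,ab,0 -> 0,0,ab,0 */
    uint16_t p2 = mul16lo(a >> 8, b & 0x0F00) & 0x0F00; /* *,ab,0,0 -> 0,ab,0,0 */
    uint16_t p3 = mul16lo(a >> 12, b & 0xF000);         /* ab,0,0,0, no mask needed */
    return (uint16_t)(p0 | p1 | p2 | p3);
}

/* Straightforward per-nibble reference. */
static uint16_t mullo_epi4_ref(uint16_t a, uint16_t b)
{
    uint16_t r = 0;
    for (int shift = 0; shift < 16; shift += 4)
        r |= (uint16_t)(((((a >> shift) & 0xFu) * ((b >> shift) & 0xFu)) & 0xFu) << shift);
    return r;
}

/* Compare both versions on pseudo-random lanes. */
static void check_masked_against_ref(void)
{
    uint32_t state = 12345u;
    for (int i = 0; i < 100000; ++i) {
        state = state * 1103515245u + 12345u;
        uint16_t a = (uint16_t)(state >> 16);
        state = state * 1103515245u + 12345u;
        uint16_t b = (uint16_t)(state >> 16);
        assert(mullo_epi4_masked_scalar(a, b) == mullo_epi4_ref(a, b));
    }
}
```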
    

    With intrinsics (untested):

    __m256i _mm256_mullo_epi4(const __m256i a, const __m256i b) {
        __m256i mask_000f = _mm256_set1_epi16(0x000f);
        __m256i mask_00f0 = _mm256_set1_epi16(0x00f0);
        __m256i mask_0f00 = _mm256_set1_epi16(0x0f00);
        __m256i mask_f000 = _mm256_set1_epi16((short)0xf000);  // cast avoids an int-to-short conversion warning
        // For nibble k: shift a right by 4k and keep only nibble k of b in place, so the
        // low 4 bits of the nibble product land directly in nibble k (for k = 0 the final mask suffices).
        __m256i tmp_mul_0 = _mm256_and_si256(_mm256_mullo_epi16(a, b), mask_000f);
        __m256i tmp_mul_1 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a,  4), _mm256_and_si256(b, mask_00f0)), mask_00f0);
        __m256i tmp_mul_2 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a,  8), _mm256_and_si256(b, mask_0f00)), mask_0f00);
        // vpmullw discards bits above 15, so nibble 3 needs no final mask.
        __m256i tmp_mul_3 =                  _mm256_mullo_epi16(_mm256_srli_epi16(a, 12), _mm256_and_si256(b, mask_f000));
        // The partial results occupy disjoint nibbles, so XOR (or OR) merges them.
        __m256i tmp1 = _mm256_xor_si256(tmp_mul_0, tmp_mul_1);
        __m256i tmp2 = _mm256_xor_si256(tmp_mul_2, tmp_mul_3);
        __m256i tmp  = _mm256_xor_si256(tmp1, tmp2);
        return tmp;
    }
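Since the intrinsics version is untested, here is a sketch of a harness that compares it lane by lane against a scalar reference on random vectors. Assumptions: GCC or Clang on x86, relying on __attribute__((target("avx2"))) and __builtin_cpu_supports so that no -mavx2 flag is required; the function body is the one above, renamed to the hypothetical mullo_epi4_avx2. On non-x86 targets or CPUs without AVX2 the check reports -1 and is skipped.

```c
#include <assert.h>
#include <stdint.h>

#if defined(__x86_64__) || defined(__i386__)
#include <immintrin.h>

/* The answer's intrinsics code (renamed), compiled for AVX2 via a target attribute. */
__attribute__((target("avx2")))
static __m256i mullo_epi4_avx2(__m256i a, __m256i b)
{
    const __m256i m000f = _mm256_set1_epi16(0x000f);
    const __m256i m00f0 = _mm256_set1_epi16(0x00f0);
    const __m256i m0f00 = _mm256_set1_epi16(0x0f00);
    const __m256i mf000 = _mm256_set1_epi16((short)0xf000);
    __m256i p0 = _mm256_and_si256(_mm256_mullo_epi16(a, b), m000f);
    __m256i p1 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 4), _mm256_and_si256(b, m00f0)), m00f0);
    __m256i p2 = _mm256_and_si256(_mm256_mullo_epi16(_mm256_srli_epi16(a, 8), _mm256_and_si256(b, m0f00)), m0f00);
    __m256i p3 = _mm256_mullo_epi16(_mm256_srli_epi16(a, 12), _mm256_and_si256(b, mf000));
    return _mm256_xor_si256(_mm256_xor_si256(p0, p1), _mm256_xor_si256(p2, p3));
}

/* Straightforward per-nibble reference for one 16-bit lane. */
static uint16_t mullo_epi4_ref(uint16_t a, uint16_t b)
{
    uint16_t r = 0;
    for (int shift = 0; shift < 16; shift += 4)
        r |= (uint16_t)(((((a >> shift) & 0xFu) * ((b >> shift) & 0xFu)) & 0xFu) << shift);
    return r;
}

/* 1 if all lanes match on 10000 random vector pairs, 0 on the first mismatch. */
__attribute__((target("avx2")))
static int avx2_matches_ref_impl(void)
{
    uint16_t a[16], b[16], r[16];
    uint32_t state = 1u;
    for (int iter = 0; iter < 10000; ++iter) {
        for (int i = 0; i < 16; ++i) {
            state = state * 1103515245u + 12345u;
            a[i] = (uint16_t)(state >> 16);
            state = state * 1103515245u + 12345u;
            b[i] = (uint16_t)(state >> 16);
        }
        __m256i va = _mm256_loadu_si256((const __m256i *)a);
        __m256i vb = _mm256_loadu_si256((const __m256i *)b);
        _mm256_storeu_si256((__m256i *)r, mullo_epi4_avx2(va, vb));
        for (int i = 0; i < 16; ++i)
            if (r[i] != mullo_epi4_ref(a[i], b[i]))
                return 0;
    }
    return 1;
}

/* 1 = match, 0 = mismatch, -1 = AVX2 not available (check skipped). */
static int avx2_matches_ref(void)
{
    if (!__builtin_cpu_supports("avx2"))
        return -1;
    return avx2_matches_ref_impl();
}
#else
static int avx2_matches_ref(void) { return -1; } /* not x86: nothing to test */
#endif

static void check_avx2_against_ref(void)
{
    assert(avx2_matches_ref() != 0);
}
```

The vector code is kept lane-for-lane identical to the answer, so a mismatch here would point at the technique itself rather than at the harness.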
    

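    This requires 4 multiplications, 3 shifts and 9 bit operations (6 ANDs and 3 XORs), compared to 4 multiplications, 9 shifts and 7 bit operations (4 ANDs and 3 XORs) in the question's version. Technically, masking tmp_mul_3 in the question's code is not necessary, because the following left shift by 12 discards those bits anyway, and a compiler might be able to optimize it away.

    So overall 16 uops (4 + 3 + 9) instead of 19 (4 + 9 + 6 without that redundant AND), assuming each of these instructions is a single uop and not counting the loads of the mask constants.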