performance, x86, sse, shuffle, simd

Efficient SSE shuffle mask generation for left-packing byte elements


What would be an efficient way to optimize the following code with SSE?

uint16_t change1= ... ;
uint8_t* pSrc   = ... ;
uint8_t* pDest  = ... ;

if(change1 & 0x0001) *pDest++ = pSrc[0];
if(change1 & 0x0002) *pDest++ = pSrc[1];
if(change1 & 0x0004) *pDest++ = pSrc[2];
if(change1 & 0x0008) *pDest++ = pSrc[3];

if(change1 & 0x0010) *pDest++ = pSrc[4];
if(change1 & 0x0020) *pDest++ = pSrc[5];
if(change1 & 0x0040) *pDest++ = pSrc[6];
if(change1 & 0x0080) *pDest++ = pSrc[7];

if(change1 & 0x0100) *pDest++ = pSrc[8];
if(change1 & 0x0200) *pDest++ = pSrc[9];
if(change1 & 0x0400) *pDest++ = pSrc[10];
if(change1 & 0x0800) *pDest++ = pSrc[11];

if(change1 & 0x1000) *pDest++ = pSrc[12];
if(change1 & 0x2000) *pDest++ = pSrc[13];
if(change1 & 0x4000) *pDest++ = pSrc[14];
if(change1 & 0x8000) *pDest++ = pSrc[15];

So far I have been using quite a big lookup table for it, but I really want to get rid of it:

SSE3Shuffle::Entry& e0 = SSE3Shuffle::g_Shuffle.m_Entries[change1];
_mm_storeu_si128((__m128i*)pDest, _mm_shuffle_epi8(*(__m128i*)pSrc, e0.mask));
pDest += e0.offset;

Solution

  • Assuming:

    change1 = _mm_movemask_epi8(bytemask);
    offset = popcnt(change1);
    

    On large buffers, using two shuffles and a 1 KiB table is only ~10% slower than using one shuffle and a 1 MiB table. My attempts at generating the shuffle mask via prefix sums and bit twiddling were about half the speed of the table-based methods (solutions using pext/pdep were not explored).
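
    For reference, the unexplored pext/pdep route might look roughly like the untested sketch below: compress16 is a hypothetical helper name, bit i of keep_mask set means pSrc[i] is kept (as in the question), BMI2 and a 64-bit target are required, and pdep/pext are very slow on AMD CPUs before Zen 3.

    #include <stdint.h>
    #include <immintrin.h> // SSSE3, POPCNT, BMI2 (pdep/pext)

    // hypothetical helper: left-pack the bytes of one 16-byte block whose
    // corresponding bit in keep_mask is set, and return the advanced pointer
    static inline uint8_t* compress16(uint8_t* dst, const uint8_t* src, uint16_t keep_mask)
    {
        // replicate each mask bit into a nibble of ones
        uint64_t nibble_mask = _pdep_u64(keep_mask, 0x1111111111111111ULL) * 0xF;
        // gather the 4-bit indices (0..15) of the kept bytes, packed toward the LSB
        uint64_t packed_idx  = _pext_u64(0xFEDCBA9876543210ULL, nibble_mask);
        // spread the 16 packed nibbles out to one byte each for pshufb
        __m128i idx = _mm_cvtsi64_si128((long long)packed_idx);
        idx = _mm_and_si128(_mm_unpacklo_epi8(idx, _mm_srli_epi64(idx, 4)),
                            _mm_set1_epi8(0x0F));

        __m128i v = _mm_loadu_si128((const __m128i*)src);
        // trailing bytes of the store are garbage, but dst only advances by the
        // popcount, so they are overwritten by the next store (same trick as below)
        _mm_storeu_si128((__m128i*)dst, _mm_shuffle_epi8(v, idx));
        return dst + _mm_popcnt_u32(keep_mask);
    }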

    Reducing table size: use two lookups into a 2 KiB table instead of one lookup into a 1 MiB table. Always keep the top-most byte: if that byte is to be discarded, it doesn't matter which byte ends up at that position (this reduces the index to 7 bits, i.e. a 1 KiB table). Further reduce the possible combinations by manually packing the two bytes in each 16-bit lane (down to a 216-byte table).

    The following example strips whitespace from text using SSE4.1. If only SSSE3 is available, blendv can be emulated (a sketch follows the example). The 64-bit halves are re-combined by overlapping writes to memory, but they could be re-combined in the xmm register (as seen in the AVX2 example).

    #include <stdint.h>
    #include <smmintrin.h> // SSE4.1
    
    size_t despacer (void* dst_void, void* src_void, size_t length)
    {
        uint8_t* src = (uint8_t*)src_void;
        uint8_t* dst = (uint8_t*)dst_void;
    
        if (length >= 16) {
            // table of control characters (space, tab, newline, carriage return)
            const __m128i lut_cntrl = _mm_setr_epi8(' ', 0, 0, 0, 0, 0, 0, 0, 0, '\t', '\n', 0, 0, '\r', 0, 0);
    
            // bits[4:0] = index -> ((trit_d * 0) + (trit_c * 9) + (trit_b * 3) + (trit_a * 1))
            // bits[15:7] = popcnt
            const __m128i sadmask = _mm_set1_epi64x(0x8080898983838181);
    
            // adding 8 to each shuffle index is cheaper than extracting the high qword
            const __m128i offset = _mm_cvtsi64_si128(0x0808080808080808);
    
            // shuffle control indices
            static const uint64_t table[27] = {
                0x0000000000000706, 0x0000000000070600, 0x0000000007060100, 0x0000000000070602,
                0x0000000007060200, 0x0000000706020100, 0x0000000007060302, 0x0000000706030200,
                0x0000070603020100, 0x0000000000070604, 0x0000000007060400, 0x0000000706040100,
                0x0000000007060402, 0x0000000706040200, 0x0000070604020100, 0x0000000706040302,
                0x0000070604030200, 0x0007060403020100, 0x0000000007060504, 0x0000000706050400,
                0x0000070605040100, 0x0000000706050402, 0x0000070605040200, 0x0007060504020100,
                0x0000070605040302, 0x0007060504030200, 0x0706050403020100
            };
    
            const uint8_t* end = &src[length & ~15];
            do {
                __m128i v = _mm_loadu_si128((__m128i*)src);
                src += 16;
    
                // detect spaces
                __m128i mask = _mm_cmpeq_epi8(_mm_shuffle_epi8(lut_cntrl, v), v);
    
                // shift w/blend: each word now only has 3 states instead of 4
                // which reduces the possibilities per qword from 128 to 27
                v = _mm_blendv_epi8(v, _mm_srli_epi16(v, 8), mask);
    
                // extract bitfields describing each qword: index, popcnt
                __m128i desc = _mm_sad_epu8(_mm_and_si128(mask, sadmask), sadmask);
                size_t lo_desc = (size_t)_mm_cvtsi128_si32(desc);
                size_t hi_desc = (size_t)_mm_extract_epi16(desc, 4);
    
                // load shuffle control indices from pre-computed table
                __m128i lo_shuf = _mm_loadl_epi64((__m128i*)&table[lo_desc & 0x1F]);
                __m128i hi_shuf = _mm_or_si128(_mm_loadl_epi64((__m128i*)&table[hi_desc & 0x1F]), offset);
    
                // store an entire qword then advance the pointer by however
                // many of those bytes are actually wanted. Any trailing
                // garbage will be overwritten by the next store.
                // note: little endian byte memory order
                _mm_storel_epi64((__m128i*)dst, _mm_shuffle_epi8(v, lo_shuf));
                dst += (lo_desc >> 7);
                _mm_storel_epi64((__m128i*)dst, _mm_shuffle_epi8(v, hi_shuf));
                dst += (hi_desc >> 7);
            } while (src != end);
        }
    
        // tail loop
        length &= 15;
        if (length != 0) {
            const uint64_t bitmap = 0xFFFFFFFEFFFFC1FF; // bit c is 0 for '\t'..'\r' and ' '
            do {
                uint64_t c = *src++;
                *dst = (uint8_t)c;
                // keep the byte (advance dst) unless it is whitespace;
                // the second term is 1 for any c >= 64, which bitmap cannot cover
                dst += ((bitmap >> c) & 1) | ((c + 0xC0) >> 8);
            } while (--length);
        }
    
        // return pointer to the location after the last element in dst
        return (size_t)(dst - ((uint8_t*)dst_void));
    }
    

    Whether the tail loop should be vectorized or use cmov is left as an exercise for the reader. Writing each byte unconditionally/branchlessly is fast when the input is unpredictable.
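
    If only SSSE3 is available, the _mm_blendv_epi8 above can be replaced with the classic and/andnot/or select. A minimal sketch (blendv_emul is just an illustrative name), exact here because the compare produces mask bytes of exactly 0x00 or 0xFF:

    // SSE2-level replacement for _mm_blendv_epi8(a, b, mask),
    // valid whenever every mask byte is exactly 0x00 or 0xFF
    static inline __m128i blendv_emul(__m128i a, __m128i b, __m128i mask)
    {
        return _mm_or_si128(_mm_and_si128(mask, b), _mm_andnot_si128(mask, a));
    }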


    Generating the shuffle control mask with AVX2 and an in-register table (vpermd) is only slightly slower than using large precomputed tables.

    #include <stdint.h>
    #include <immintrin.h>
    
    // probably needs improvement...
    size_t despace_avx2_vpermd(const char* src_void, char* dst_void, size_t length)
    {
        uint8_t* src = (uint8_t*)src_void;
        uint8_t* dst = (uint8_t*)dst_void;
    
        const __m256i lut_cntrl2    = _mm256_broadcastsi128_si256(_mm_setr_epi8(' ', 0, 0, 0, 0, 0, 0, 0, 0, '\t', '\n', 0, 0, '\r', 0, 0));
        const __m256i permutation_mask = _mm256_set1_epi64x( 0x0020100884828180 );
        const __m256i invert_mask = _mm256_set1_epi64x( 0x0020100880808080 ); 
        const __m256i zero = _mm256_setzero_si256();
        const __m256i fixup = _mm256_set_epi32(
            0x08080808, 0x0F0F0F0F, 0x00000000, 0x07070707,
            0x08080808, 0x0F0F0F0F, 0x00000000, 0x07070707
        );
        const __m256i lut = _mm256_set_epi32(
            0x04050607, // 0x03020100', 0x000000'07
            0x04050704, // 0x030200'00, 0x0000'0704
            0x04060705, // 0x030100'00, 0x0000'0705
            0x04070504, // 0x0300'0000, 0x00'070504
            0x05060706, // 0x020100'00, 0x0000'0706
            0x05070604, // 0x0200'0000, 0x00'070604
            0x06070605, // 0x0100'0000, 0x00'070605
            0x07060504  // 0x00'000000, 0x'07060504
        );
    
        // hi bits are ignored by pshufb, used to reject movement of low qword bytes
        const __m256i shuffle_a = _mm256_set_epi8(
            0x7F, 0x7E, 0x7D, 0x7C, 0x7B, 0x7A, 0x79, 0x78, 0x07, 0x16, 0x25, 0x34, 0x43, 0x52, 0x61, 0x70,
            0x7F, 0x7E, 0x7D, 0x7C, 0x7B, 0x7A, 0x79, 0x78, 0x07, 0x16, 0x25, 0x34, 0x43, 0x52, 0x61, 0x70
        );
    
        // broadcast 0x08 then blendd...
        const __m256i shuffle_b = _mm256_set_epi32(
            0x08080808, 0x08080808, 0x00000000, 0x00000000,
            0x08080808, 0x08080808, 0x00000000, 0x00000000
        );
    
        for( uint8_t* end = &src[(length & ~31)]; src != end; src += 32){
            __m256i r0,r1,r2,r3,r4;
            unsigned int s0,s1;
    
            r0 = _mm256_loadu_si256((__m256i *)src); // asrc
    
            // detect spaces
            r1 = _mm256_cmpeq_epi8(_mm256_shuffle_epi8(lut_cntrl2, r0), r0);
    
            r2 = _mm256_sad_epu8(zero, r1);
            s0 = (unsigned)_mm256_movemask_epi8(r1);
            r1 = _mm256_andnot_si256(r1, permutation_mask);
    
            r1 = _mm256_sad_epu8(r1, invert_mask); // index_bitmap[0:5], low32_spaces_count[7:15]
    
            r2 = _mm256_shuffle_epi8(r2, zero);
    
            r2 = _mm256_sub_epi8(shuffle_a, r2); // add space cnt of low qword
            s0 = ~s0;
    
            r3 = _mm256_slli_epi64(r1, 29); // move top part of index_bitmap to high dword
            r4 = _mm256_srli_epi64(r1, 7); // number of spaces in low dword 
    
            r4 = _mm256_shuffle_epi8(r4, shuffle_b);
            r1 = _mm256_or_si256(r1, r3);
    
            r1 = _mm256_permutevar8x32_epi32(lut, r1);
            s1 = _mm_popcnt_u32(s0);
            r4 = _mm256_add_epi8(r4, shuffle_a);
            s0 = s0 & 0xFFFF; // isolate low oword
    
            r2 = _mm256_shuffle_epi8(r4, r2);
            s0 = _mm_popcnt_u32(s0);
    
            r2 = _mm256_max_epu8(r2, r4); // pin low qword bytes
    
            r1 = _mm256_xor_si256(r1, fixup);
    
            r1 = _mm256_shuffle_epi8(r1, r2); // complete shuffle mask
    
            r0 = _mm256_shuffle_epi8(r0, r1); // despace!
    
            _mm_storeu_si128((__m128i*)dst, _mm256_castsi256_si128(r0));
            _mm_storeu_si128((__m128i*)&dst[s0], _mm256_extracti128_si256(r0,1));
            dst += s1;
        }
        // tail loop
        length &= 31;
        if (length != 0) {
            const uint64_t bitmap = 0xFFFFFFFEFFFFC1FF; // same branchless whitespace test as above
            do {
                uint64_t c = *src++;
                *dst = (uint8_t)c;
                dst += ((bitmap >> c) & 1) | ((c + 0xC0) >> 8);
            } while (--length);
        }
        return (size_t)(dst - ((uint8_t*)dst_void));
    }
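
    A minimal driver, for illustration only (despace_avx2_vpermd takes (src, dst, length), the opposite argument order from despacer above):

    #include <stdio.h>
    #include <string.h>

    int main(void)
    {
        const char text[] = "pack\t the \n kept bytes \r to the left";
        char out[sizeof(text)];
        size_t n = despace_avx2_vpermd(text, out, strlen(text));
        fwrite(out, 1, n, stdout); // prints "packthekeptbytestotheleft"
        putchar('\n');
        return 0;
    }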
    

    For posterity, the 1 KiB version (generating the table is left as an exercise for the reader; one possible generator is sketched after the snippet).

    static const uint64_t table[128] __attribute__((aligned(64))) = {
        0x0706050403020100, 0x0007060504030201, ..., 0x0605040302010700, 0x0605040302010007 
    };
    const __m128i mask_01 = _mm_set1_epi8( 0x01 );
    
    __m128i vector0 = _mm_loadu_si128((__m128i*)src);
    __m128i vector1 = _mm_shuffle_epi32( vector0, 0x0E );
    
    __m128i bytemask0 = _mm_cmpeq_epi8( ???, vector0); // detect bytes to omit
    
    uint32_t bitmask0 = _mm_movemask_epi8(bytemask0) & 0x7F7F;
    __m128i hsum = _mm_sad_epu8(_mm_add_epi8(bytemask0, mask_01), _mm_setzero_si128());
    
    vector0 = _mm_shuffle_epi8(vector0, _mm_loadl_epi64((__m128i*) &table[(uint8_t)bitmask0]));
    _mm_storel_epi64((__m128i*)dst, vector0);
    dst += (uint32_t)_mm_cvtsi128_si32(hsum);
    
    vector1 = _mm_shuffle_epi8(vector1, _mm_loadl_epi64((__m128i*) &table[bitmask0 >> 8]));
    _mm_storel_epi64((__m128i*)dst, vector1);
    dst += (uint32_t)_mm_cvtsi128_si32(_mm_unpackhi_epi64(hsum, hsum));
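
    One possible generator for that 128-entry table (a sketch, assuming the layout visible in the entries above: kept low-qword indices first, index 7 always appended since the top byte is always kept, then the omitted indices as don't-care padding; it reproduces the first and last entries shown):

    #include <stdint.h>
    #include <stdio.h>

    int main(void)
    {
        for (unsigned m = 0; m < 128; m++) {       // m = 7-bit "omit" mask
            uint64_t entry = 0;
            int k = 0;
            for (int i = 0; i < 7; i++)            // kept bytes, left-packed
                if (!(m & (1u << i))) entry |= (uint64_t)i << (8 * k++);
            entry |= 7ull << (8 * k++);            // byte 7 is always kept
            for (int i = 0; i < 7; i++)            // don't-care padding
                if (m & (1u << i)) entry |= (uint64_t)i << (8 * k++);
            printf("0x%016llx,%s", (unsigned long long)entry, (m % 4 == 3) ? "\n" : " ");
        }
        return 0;
    }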
    

    https://github.com/InstLatx64/AVX512_VPCOMPRESSB_Emu has some benchmarks.