Tags: c, matrix, optimization, sse, simd

A better 8x8 bytes matrix transpose with SSE?


I found this post that explains how to transpose an 8x8 byte matrix with 24 operations, and a few scrolls later there is the code that implements the transpose. However, this method does not exploit the fact that we can block the 8x8 transpose into four 4x4 transposes, each of which can be done with a single shuffle instruction (this post is the reference). So I came up with this solution:

__m128i transpose4x4mask = _mm_set_epi8(15, 11, 7, 3, 14, 10, 6, 2, 13,  9, 5, 1, 12,  8, 4, 0);
__m128i shuffle8x8Mask = _mm_setr_epi8(0, 1, 2, 3, 8, 9, 10, 11, 4,  5, 6, 7, 12,  13, 14, 15);

void TransposeBlock8x8(uint8_t *src, uint8_t *dst, int srcStride, int dstStride) {
    __m128i load0 = _mm_set_epi64x(*(uint64_t*)(src + 1 * srcStride), *(uint64_t*)(src + 0 * srcStride));
    __m128i load1 = _mm_set_epi64x(*(uint64_t*)(src + 3 * srcStride), *(uint64_t*)(src + 2 * srcStride));
    __m128i load2 = _mm_set_epi64x(*(uint64_t*)(src + 5 * srcStride), *(uint64_t*)(src + 4 * srcStride));
    __m128i load3 = _mm_set_epi64x(*(uint64_t*)(src + 7 * srcStride), *(uint64_t*)(src + 6 * srcStride));

    __m128i shuffle0 = _mm_shuffle_epi8(load0, shuffle8x8Mask);
    __m128i shuffle1 = _mm_shuffle_epi8(load1, shuffle8x8Mask);
    __m128i shuffle2 = _mm_shuffle_epi8(load2, shuffle8x8Mask);
    __m128i shuffle3 = _mm_shuffle_epi8(load3, shuffle8x8Mask);

    __m128i block0 = _mm_unpacklo_epi64(shuffle0, shuffle1);
    __m128i block1 = _mm_unpackhi_epi64(shuffle0, shuffle1);
    __m128i block2 = _mm_unpacklo_epi64(shuffle2, shuffle3);
    __m128i block3 = _mm_unpackhi_epi64(shuffle2, shuffle3);

    __m128i transposed0 = _mm_shuffle_epi8(block0, transpose4x4mask);   
    __m128i transposed1 = _mm_shuffle_epi8(block1, transpose4x4mask);   
    __m128i transposed2 = _mm_shuffle_epi8(block2, transpose4x4mask);   
    __m128i transposed3 = _mm_shuffle_epi8(block3, transpose4x4mask);   

    __m128i store0 = _mm_unpacklo_epi32(transposed0, transposed2);
    __m128i store1 = _mm_unpackhi_epi32(transposed0, transposed2);
    __m128i store2 = _mm_unpacklo_epi32(transposed1, transposed3);
    __m128i store3 = _mm_unpackhi_epi32(transposed1, transposed3);

    *((uint64_t*)(dst + 0 * dstStride)) = _mm_extract_epi64(store0, 0);
    *((uint64_t*)(dst + 1 * dstStride)) = _mm_extract_epi64(store0, 1);
    *((uint64_t*)(dst + 2 * dstStride)) = _mm_extract_epi64(store1, 0);
    *((uint64_t*)(dst + 3 * dstStride)) = _mm_extract_epi64(store1, 1);
    *((uint64_t*)(dst + 4 * dstStride)) = _mm_extract_epi64(store2, 0);
    *((uint64_t*)(dst + 5 * dstStride)) = _mm_extract_epi64(store2, 1);
    *((uint64_t*)(dst + 6 * dstStride)) = _mm_extract_epi64(store3, 0);
    *((uint64_t*)(dst + 7 * dstStride)) = _mm_extract_epi64(store3, 1);
}
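
For reference, a small harness along the following lines (not part of the original question; buffer names are illustrative) exercises the routine on a contiguous 8x8 matrix, i.e. with both strides equal to 8. Note that the file-scope _mm_set_epi8 initializers above generally require compiling as C++, or moving the masks into the function, since they are not constant expressions in C.

#include <stdint.h>
#include <stdio.h>
#include <smmintrin.h>   /* SSSE3/SSE4.1 intrinsics used above (pshufb, pextrq) */

int main(void) {
    uint8_t src[64], dst[64];
    for (int i = 0; i < 64; i++) src[i] = (uint8_t)i;   /* src[r][c] = 8*r + c */
    TransposeBlock8x8(src, dst, 8, 8);
    /* Expect the transpose: dst[8*r + c] == src[8*c + r] == 8*c + r. */
    for (int r = 0; r < 8; r++) {
        for (int c = 0; c < 8; c++) printf("%3d ", dst[8*r + c]);
        printf("\n");
    }
    return 0;
}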

Excluding the load and store operations, this procedure consists of only 16 instructions instead of 24.

What am I missing?


Solution

  • Apart from the loads, stores and pinsrq-s needed to read from and write to memory, possibly with a stride not equal to 8 bytes, the transpose can be done with only 12 instructions (this code can easily be used in combination with Z boson's test code):

    void tran8x8b_SSE_v2(char *A, char *B) {
      __m128i pshufbcnst = _mm_set_epi8(15,11,7,3, 14,10,6,2, 13,9,5,1, 12,8,4,0);
    
      __m128i B0, B1, B2, B3, T0, T1, T2, T3;
      /* Load the 8x8 matrix: two 8-byte rows per register.                     */
      B0 = _mm_loadu_si128((__m128i*)&A[ 0]);
      B1 = _mm_loadu_si128((__m128i*)&A[16]);
      B2 = _mm_loadu_si128((__m128i*)&A[32]);
      B3 = _mm_loadu_si128((__m128i*)&A[48]);
    
      /* Gather the four 4x4 sub-blocks: T0/T1 collect the left halves of rows
         0-3 and 4-7 (even 32-bit lanes), T2/T3 the right halves (odd lanes).   */
      T0 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(B0),_mm_castsi128_ps(B1),0b10001000));
      T1 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(B2),_mm_castsi128_ps(B3),0b10001000));
      T2 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(B0),_mm_castsi128_ps(B1),0b11011101));
      T3 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(B2),_mm_castsi128_ps(B3),0b11011101));
    
      /* Transpose each 4x4 byte block with a single pshufb.                    */
      B0 = _mm_shuffle_epi8(T0,pshufbcnst);
      B1 = _mm_shuffle_epi8(T1,pshufbcnst);
      B2 = _mm_shuffle_epi8(T2,pshufbcnst);
      B3 = _mm_shuffle_epi8(T3,pshufbcnst);
    
      /* Interleave 32-bit lanes of the transposed blocks to form the rows of
         the transposed 8x8 matrix.                                             */
      T0 = _mm_unpacklo_epi32(B0,B1);
      T1 = _mm_unpackhi_epi32(B0,B1);
      T2 = _mm_unpacklo_epi32(B2,B3);
      T3 = _mm_unpackhi_epi32(B2,B3);
    
      _mm_storeu_si128((__m128i*)&B[ 0], T0);
      _mm_storeu_si128((__m128i*)&B[16], T1);
      _mm_storeu_si128((__m128i*)&B[32], T2);
      _mm_storeu_si128((__m128i*)&B[48], T3);
    }
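
    As noted above, when the row stride of the source is not equal to 8 bytes, the eight rows can first be gathered with 64-bit loads and pinsrq-s and then fed to the same 12-instruction core. A minimal sketch of such a gather (the names load_row_pair, src and srcStride are illustrative, not part of the original answer):

      #include <stdint.h>
      #include <smmintrin.h>   /* SSE4.1: _mm_insert_epi64 compiles to pinsrq */

      /* Sketch only: pack two 8-byte rows, srcStride bytes apart, into one
         128-bit register; repeat four times to load the full 8x8 block.     */
      static inline __m128i load_row_pair(const uint8_t *src, int srcStride) {
          __m128i r = _mm_loadl_epi64((const __m128i*)src);   /* row 0 -> low 64 bits             */
          return _mm_insert_epi64(r,                          /* row 1 -> high 64 bits via pinsrq */
                                  *(const long long*)(src + srcStride), 1);
      }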
    


    In tran8x8b_SSE_v2 the 32-bit floating-point shuffle _mm_shuffle_ps is used because it is more flexible than the epi32 shuffle. The casts between the integer and floating-point vector types do not generate extra instructions (code generated with gcc 5.4):

    tran8x8b_SSE_v2:
    .LFB4885:
        .cfi_startproc
        vmovdqu 48(%rdi), %xmm5
        vmovdqu 32(%rdi), %xmm2
        vmovdqu 16(%rdi), %xmm0
        vmovdqu (%rdi), %xmm1
        vshufps $136, %xmm5, %xmm2, %xmm4
        vshufps $221, %xmm5, %xmm2, %xmm2
        vmovdqa .LC6(%rip), %xmm5
        vshufps $136, %xmm0, %xmm1, %xmm3
        vshufps $221, %xmm0, %xmm1, %xmm1
        vpshufb %xmm5, %xmm3, %xmm3
        vpshufb %xmm5, %xmm1, %xmm0
        vpshufb %xmm5, %xmm4, %xmm4
        vpshufb %xmm5, %xmm2, %xmm1
        vpunpckldq  %xmm4, %xmm3, %xmm5
        vpunpckldq  %xmm1, %xmm0, %xmm2
        vpunpckhdq  %xmm4, %xmm3, %xmm3
        vpunpckhdq  %xmm1, %xmm0, %xmm0
        vmovups %xmm5, (%rsi)
        vmovups %xmm3, 16(%rsi)
        vmovups %xmm2, 32(%rsi)
        vmovups %xmm0, 48(%rsi)
        ret
        .cfi_endproc
    



    On some, but not all, older cpus there might be a small bypass delay (between 0 and 2 cycles) for moving data between the integer and the floating point units. This increases the latency of the function, but it does not necessarily affect the throughput of the code.

    A simple latency test with 1e9 transpositions:

      for (int i=0;i<500000000;i++){
         tran8x8b_SSE(A,C);
         tran8x8b_SSE(C,A);
      }
      print8x8b(A);
    

    This takes about 5.5 seconds (19.7e9 cycles) with tran8x8b_SSE and 4.5 seconds (16.0e9 cycles) with tran8x8b_SSE_v2 (Intel Core i5-6500). Note that the loads and stores were not eliminated by the compiler, even though the functions were inlined into the for loop.
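
    Since the bypass delay affects latency rather than throughput, a throughput-oriented variant of this test can be sketched by breaking the dependency chain between successive calls (this variant is not from the original answer; C and D are two independent destination buffers):

      /* Sketch: A is never overwritten and the two destinations are independent,
         so consecutive transposes are not serialized through their data.  Note
         that an aggressive compiler might hoist the repeated work out of the
         loop, so the results may need to be consumed (e.g. printed) afterwards. */
      for (int i = 0; i < 500000000; i++) {
         tran8x8b_SSE_v2(A, C);
         tran8x8b_SSE_v2(A, D);
      }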


    Update: AVX2-128 / SSE 4.1 solution with blends.

    The 'shuffles' (unpack, shuffle) are handled by port 5, with 1 instruction per cpu cycle on modern cpus. Sometimes it pays off to replace one 'shuffle' with two blends. On Skylake the 32 bit blend instructions can run on either port 0, 1 or 5.

    Unfortunately, _mm_blend_epi32 requires AVX2 (it is used here in its 128-bit, AVX2-128, form). An efficient SSE 4.1 alternative is _mm_blend_ps in combination with a few casts (which are usually free). The 12 'shuffles' are replaced by 8 shuffles in combination with 8 blends.
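
    The substitution can be illustrated with a minimal sketch (the helper names below are illustrative, not from the original answer); both forms select the same 32-bit lanes, and the casts are pure reinterpretations:

      #include <immintrin.h>

      /* Take 32-bit lanes 0 and 2 from a and lanes 1 and 3 from b. */
      static inline __m128i blend_1010_avx2(__m128i a, __m128i b) {
          return _mm_blend_epi32(a, b, 0b1010);                          /* AVX2-128 */
      }
      static inline __m128i blend_1010_sse41(__m128i a, __m128i b) {
          return _mm_castps_si128(_mm_blend_ps(_mm_castsi128_ps(a),
                                               _mm_castsi128_ps(b), 0b1010));  /* SSE 4.1 */
      }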

    The simple latency test now runs in about 3.6 seconds (13e9 CPU cycles), which is about 18% faster than the result with tran8x8b_SSE_v2.

    Code:

    /* AVX2-128 version */
    void tran8x8b_AVX2_128(char *A, char *B) {
      __m128i pshufbcnst_0 = _mm_set_epi8(15, 7,11, 3,  13, 5, 9, 1,  14, 6,10, 2,  12, 4, 8, 0);
      __m128i pshufbcnst_1 = _mm_set_epi8(13, 5, 9, 1,  15, 7,11, 3,  12, 4, 8, 0,  14, 6,10, 2);
      __m128i pshufbcnst_2 = _mm_set_epi8(11, 3,15, 7,   9, 1,13, 5,  10, 2,14, 6,   8, 0,12, 4);
      __m128i pshufbcnst_3 = _mm_set_epi8( 9, 1,13, 5,  11, 3,15, 7,   8, 0,12, 4,  10, 2,14, 6);
      __m128i B0, B1, B2, B3, T0, T1, T2, T3;

      B0 = _mm_loadu_si128((__m128i*)&A[ 0]);
      B1 = _mm_loadu_si128((__m128i*)&A[16]);
      B2 = _mm_loadu_si128((__m128i*)&A[32]);
      B3 = _mm_loadu_si128((__m128i*)&A[48]);

      B1 = _mm_shuffle_epi32(B1,0b10110001);
      B3 = _mm_shuffle_epi32(B3,0b10110001);
      T0 = _mm_blend_epi32(B0,B1,0b1010);
      T1 = _mm_blend_epi32(B2,B3,0b1010);
      T2 = _mm_blend_epi32(B0,B1,0b0101);
      T3 = _mm_blend_epi32(B2,B3,0b0101);

      B0 = _mm_shuffle_epi8(T0,pshufbcnst_0);
      B1 = _mm_shuffle_epi8(T1,pshufbcnst_1);
      B2 = _mm_shuffle_epi8(T2,pshufbcnst_2);
      B3 = _mm_shuffle_epi8(T3,pshufbcnst_3);

      T0 = _mm_blend_epi32(B0,B1,0b1010);
      T1 = _mm_blend_epi32(B0,B1,0b0101);
      T2 = _mm_blend_epi32(B2,B3,0b1010);
      T3 = _mm_blend_epi32(B2,B3,0b0101);
      T1 = _mm_shuffle_epi32(T1,0b10110001);
      T3 = _mm_shuffle_epi32(T3,0b10110001);

      _mm_storeu_si128((__m128i*)&B[ 0], T0);
      _mm_storeu_si128((__m128i*)&B[16], T1);
      _mm_storeu_si128((__m128i*)&B[32], T2);
      _mm_storeu_si128((__m128i*)&B[48], T3);
    }


    /* SSE 4.1 version of tran8x8b_AVX2_128() */
    void tran8x8b_SSE4_1(char *A, char *B) {
      __m128i pshufbcnst_0 = _mm_set_epi8(15, 7,11, 3,  13, 5, 9, 1,  14, 6,10, 2,  12, 4, 8, 0);
      __m128i pshufbcnst_1 = _mm_set_epi8(13, 5, 9, 1,  15, 7,11, 3,  12, 4, 8, 0,  14, 6,10, 2);
      __m128i pshufbcnst_2 = _mm_set_epi8(11, 3,15, 7,   9, 1,13, 5,  10, 2,14, 6,   8, 0,12, 4);
      __m128i pshufbcnst_3 = _mm_set_epi8( 9, 1,13, 5,  11, 3,15, 7,   8, 0,12, 4,  10, 2,14, 6);
      __m128 B0, B1, B2, B3, T0, T1, T2, T3;

      B0 = _mm_loadu_ps((float*)&A[ 0]);
      B1 = _mm_loadu_ps((float*)&A[16]);
      B2 = _mm_loadu_ps((float*)&A[32]);
      B3 = _mm_loadu_ps((float*)&A[48]);

      B1 = _mm_shuffle_ps(B1,B1,0b10110001);
      B3 = _mm_shuffle_ps(B3,B3,0b10110001);
      T0 = _mm_blend_ps(B0,B1,0b1010);
      T1 = _mm_blend_ps(B2,B3,0b1010);
      T2 = _mm_blend_ps(B0,B1,0b0101);
      T3 = _mm_blend_ps(B2,B3,0b0101);

      B0 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(T0),pshufbcnst_0));
      B1 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(T1),pshufbcnst_1));
      B2 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(T2),pshufbcnst_2));
      B3 = _mm_castsi128_ps(_mm_shuffle_epi8(_mm_castps_si128(T3),pshufbcnst_3));

      T0 = _mm_blend_ps(B0,B1,0b1010);
      T1 = _mm_blend_ps(B0,B1,0b0101);
      T2 = _mm_blend_ps(B2,B3,0b1010);
      T3 = _mm_blend_ps(B2,B3,0b0101);
      T1 = _mm_shuffle_ps(T1,T1,0b10110001);
      T3 = _mm_shuffle_ps(T3,T3,0b10110001);

      _mm_storeu_ps((float*)&B[ 0], T0);
      _mm_storeu_ps((float*)&B[16], T1);
      _mm_storeu_ps((float*)&B[32], T2);
      _mm_storeu_ps((float*)&B[48], T3);
    }