Search code examples
c++ctransposematrix-multiplicationsse

How do I more efficiently multiply A*B^T or A^T*B^T (T for transpose) matrices using SSE?


I keep beating myself over the head with this. I have an SSE-based algorithm for multiplying matrix A by matrix B. I also need to implement the operations for the cases where A, B, or both are transposed. I did a naive implementation of it, shown below for the 4x4 tile (pretty standard SSE operations, I think), but the A*B^T operation takes about twice as long as A*B. The ATLAS implementation gives similar timings for A*B and nearly identical timings when multiplying by a transpose, which suggests to me that there is an efficient way to do this.

MM-Multiplication:

// A*B tile kernel from the question: accumulates mat1 * mat2 into this->el_
// one 4x4 tile at a time (each store adds to the value already in the result).
// m1, n2 and n are the loop bounds rounded down to a multiple of 4; leftover
// rows/columns are handled by code omitted from this excerpt.
m1 = (mat1.m_>>2)<<2;
n2 = (mat2.n_>>2)<<2;
n  = (mat1.n_>>2)<<2;

for (k=0; k<n; k+=4) {
  for (i=0; i<m1; i+=4) {
    // fetch: get 4x4 matrix from mat1
    // row-major storage, so get 4 rows
    Float* a0 = mat1.el_[i]+k;
    Float* a1 = mat1.el_[i+1]+k;
    Float* a2 = mat1.el_[i+2]+k;
    Float* a3 = mat1.el_[i+3]+k;

    for (j=0; j<n2; j+=4) {
      // fetch: get 4x4 matrix from mat2
      // row-major storage, so get 4 rows
      Float* b0 = mat2.el_[k]+j;
      Float* b1 = mat2.el_[k+1]+j;
      Float* b2 = mat2.el_[k+2]+j;
      Float* b3 = mat2.el_[k+3]+j;

      // One contiguous (unaligned) vector load per tile row of mat2.
      __m128 b0r = _mm_loadu_ps(b0);
      __m128 b1r = _mm_loadu_ps(b1);
      __m128 b2r = _mm_loadu_ps(b2);
      __m128 b3r = _mm_loadu_ps(b3);

      // Each result row is sum over c of broadcast(aX[c]) * b<c>r: a linear
      // combination of the four mat2 tile rows, weighted by one mat1 row.
      {  // first row of result += first row of mat1 * 4x4 of mat2
        __m128 cX1 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a0+0), b0r), _mm_mul_ps(_mm_load_ps1(a0+1), b1r));
        __m128 cX2 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a0+2), b2r), _mm_mul_ps(_mm_load_ps1(a0+3), b3r));
        Float* c0 = this->el_[i]+j;
        _mm_storeu_ps(c0, _mm_add_ps(_mm_add_ps(cX1, cX2), _mm_loadu_ps(c0)));
      }

      { // second row of result += second row of mat1 * 4x4 of mat2
        __m128 cX1 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a1+0), b0r), _mm_mul_ps(_mm_load_ps1(a1+1), b1r));
        __m128 cX2 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a1+2), b2r), _mm_mul_ps(_mm_load_ps1(a1+3), b3r));
        Float* c1 = this->el_[i+1]+j;
        _mm_storeu_ps(c1, _mm_add_ps(_mm_add_ps(cX1, cX2), _mm_loadu_ps(c1)));
      }

      { // third row of result += third row of mat1 * 4x4 of mat2
        __m128 cX1 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a2+0), b0r), _mm_mul_ps(_mm_load_ps1(a2+1), b1r));
        __m128 cX2 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a2+2), b2r), _mm_mul_ps(_mm_load_ps1(a2+3), b3r));
        Float* c2 = this->el_[i+2]+j;
        _mm_storeu_ps(c2, _mm_add_ps(_mm_add_ps(cX1, cX2), _mm_loadu_ps(c2)));
      }

      { // fourth row of result += fourth row of mat1 * 4x4 of mat2
        __m128 cX1 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a3+0), b0r), _mm_mul_ps(_mm_load_ps1(a3+1), b1r));
        __m128 cX2 = _mm_add_ps(_mm_mul_ps(_mm_load_ps1(a3+2), b2r), _mm_mul_ps(_mm_load_ps1(a3+3), b3r));
        Float* c3 = this->el_[i+3]+j;
        _mm_storeu_ps(c3, _mm_add_ps(_mm_add_ps(cX1, cX2), _mm_loadu_ps(c3)));
      }
    // NOTE(review): the closing brace of the `j` loop appears to be missing
    // from this excerpt; as shown, the braces do not balance.
  }
// Code omitted to handle remaining rows and columns
}

For the MT multiplication (matrix multiplied by a transposed matrix), I instead loaded b0r to b3r with the following commands and changed the loop variables appropriately:

// MT variant: each b<c>r is built from column c of the 4x4 tile, not from a row.
// _mm_set_ps lists lanes from highest to lowest, so b0[c] lands in lane 0.
// Each register is assembled from four scalar reads instead of one vector
// load, which is presumably the extra cost behind the reported slowdown.
__m128 b0r = _mm_set_ps(b3[0], b2[0], b1[0], b0[0]);
__m128 b1r = _mm_set_ps(b3[1], b2[1], b1[1], b0[1]);
__m128 b2r = _mm_set_ps(b3[2], b2[2], b1[2], b0[2]);
__m128 b3r = _mm_set_ps(b3[3], b2[3], b1[3], b0[3]);

I suspect that the slowdown is partly due to the difference between pulling in a row at a time and having to store 4 values each time to get the column, but I feel like the other way of going about this, pulling in rows of B and then multiplying by the column of As, will just shift the cost over to storing 4 columns of results.

I have also tried pulling in B's rows as rows and then using _MM_TRANSPOSE4_PS(b0r, b1r, b2r, b3r); to do the transposition (I thought there might be some additional optimizations in that macro), but there's no real improvement.

On the surface, I feel like this should be faster... the dot products involved would be a row by a row, which seems inherently more efficient, but trying to do the dot products straight up just results in having to do the same thing to store the results.

What am I missing here?

Added: Just to clarify, I'm trying to not transpose the matrices. I'd prefer to iterate along them. The problem, as best I can tell, is that the _mm_set_ps command is much slower than _mm_load_ps.

I also tried a variation where I stored the 4 rows of the A matrix and then replaced the 4 curly-bracketed segments containing 1 load, 4 multiplies, and 2 adds with 4 multiply instructions and 3 hadds, but to little avail. The timing stays the same (and yes, I tried it with a debug statement to verify that the code had changed in my test compile. Said debug statement was removed before profiling, of course):

    // hadd variant of the inner tile update. a0r..a3r hold four rows of mat1,
    // loaded outside this excerpt. b0r..b3r presumably hold the mat2 rows that
    // act as columns of mat2^T (TODO confirm; not shown here).
    // Per result row: four element-wise products are reduced by two levels of
    // _mm_hadd_ps to [dot(a,b0r), dot(a,b1r), dot(a,b2r), dot(a,b3r)], and that
    // vector is added to the current contents of the result row.
    {  // first row of result += first row of mat1 * 4x4 of mat2
      __m128 cX1 = _mm_hadd_ps(_mm_mul_ps(a0r, b0r), _mm_mul_ps(a0r, b1r));
      __m128 cX2 = _mm_hadd_ps(_mm_mul_ps(a0r, b2r), _mm_mul_ps(a0r, b3r));
      Float* c0 = this->el_[i]+j;
      _mm_storeu_ps(c0, _mm_add_ps(_mm_hadd_ps(cX1, cX2), _mm_loadu_ps(c0)));
    }

    { // second row of result += second row of mat1 * 4x4 of mat2
      __m128 cX1 = _mm_hadd_ps(_mm_mul_ps(a1r, b0r), _mm_mul_ps(a1r, b1r));
      __m128 cX2 = _mm_hadd_ps(_mm_mul_ps(a1r, b2r), _mm_mul_ps(a1r, b3r));
      Float* c0 = this->el_[i+1]+j;
      _mm_storeu_ps(c0, _mm_add_ps(_mm_hadd_ps(cX1, cX2), _mm_loadu_ps(c0)));
    }

    { // third row of result += third row of mat1 * 4x4 of mat2
      __m128 cX1 = _mm_hadd_ps(_mm_mul_ps(a2r, b0r), _mm_mul_ps(a2r, b1r));
      __m128 cX2 = _mm_hadd_ps(_mm_mul_ps(a2r, b2r), _mm_mul_ps(a2r, b3r));
      Float* c0 = this->el_[i+2]+j;
      _mm_storeu_ps(c0, _mm_add_ps(_mm_hadd_ps(cX1, cX2), _mm_loadu_ps(c0)));
    }

    { // fourth row of result += fourth row of mat1 * 4x4 of mat2
      __m128 cX1 = _mm_hadd_ps(_mm_mul_ps(a3r, b0r), _mm_mul_ps(a3r, b1r));
      __m128 cX2 = _mm_hadd_ps(_mm_mul_ps(a3r, b2r), _mm_mul_ps(a3r, b3r));
      Float* c0 = this->el_[i+3]+j;
      _mm_storeu_ps(c0, _mm_add_ps(_mm_hadd_ps(cX1, cX2), _mm_loadu_ps(c0)));
    }

Update: Right, and moving the loading of the rows of a0r to a3r into the curly braces in an attempt to avoid register thrashing failed as well.


Solution

  • I think this is one of the few cases where horizontal add is useful. You want C = AB^T, but B is not stored in memory as the transpose. That's the problem: it's stored like an AoS instead of an SoA. In this case, I think transposing B and doing vertical adds is slower than using horizontal adds. This is at least true for matrix-vector products (see "Efficient 4x4 matrix vector multiplication with SSE: horizontal add and dot product - what's the point?"). In the code below, m4x4 is a non-SSE 4x4 matrix product, m4x4_vec uses SSE, m4x4T computes C = AB^T without SSE, and m4x4T_vec computes C = AB^T using SSE. The last one is the one you want, I think.

    Note: for larger matrices I would not use this method. In that case it's faster to take the transpose first and use vertical add (with SSE/AVX you do something more complicated, you transpose strips with the SSE/AVX width). That's because the transpose goes as O(n^2) and the matrix product goes as O(n^3) so for large matrices the transpose is insignificant. However, for 4x4 the transpose is significant so horizontal add wins.

    Edit: I misunderstood what you wanted. You want C = (AB)^T. That should be just as fast as (AB) and the code is nearly the same you basically just swap the roles of A and B.
    We can write the math as follows:

    $C = AB$ in Einstein notation is $C_{ij} = A_{ik} B_{kj}$.
    Since $(AB)^T = B^T A^T$, we can write
    $C = (AB)^T$ in Einstein notation as $C_{ij} = (B^T)_{ik} (A^T)_{kj} = A_{jk} B_{ki}$.
    

    Comparing the two, the only change is that the roles of i and j are swapped. Code that does this is at the end of this answer.

    #include "stdio.h"
    #include <nmmintrin.h>    
    
    /* Scalar reference kernel: C = A * B for 4x4 row-major float matrices.
     * C[r][c] is the dot product of row r of A with column c of B, accumulated
     * in k order starting from 0.0f. C is written one cell at a time in
     * row-major order, so it should not overlap A or B. */
    void m4x4(const float *A, const float *B, float *C) {
        for (int cell = 0; cell < 16; ++cell) {
            const int r = cell / 4;
            const int c = cell % 4;
            const float *a_row = A + 4 * r;
            float acc = 0.0f;
            for (int k = 0; k < 4; ++k) {
                acc += a_row[k] * B[4 * k + c];
            }
            C[cell] = acc;
        }
    }
    
    /* Scalar reference kernel: C = A * B^T for 4x4 row-major float matrices.
     * Column c of B^T is row c of B, so every entry is a row-by-row dot
     * product and B never has to be transposed in memory. C should not
     * overlap A or B. */
    void m4x4T(const float *A, const float *B, float *C) {
        for (int r = 0; r < 4; ++r) {
            const float *a_row = &A[4 * r];
            for (int c = 0; c < 4; ++c) {
                const float *b_row = &B[4 * c];
                float dot = 0.0f;
                for (int k = 0; k < 4; ++k) {
                    dot += a_row[k] * b_row[k];
                }
                C[4 * r + c] = dot;
            }
        }
    }
    
    /* SSE kernel: C = A * B for 4x4 row-major float matrices.
     * Row r of C is sum_k A[r][k] * (row k of B): each A element is broadcast
     * and scales a whole row of B, so only vertical adds are needed.
     * A, B and C must be 16-byte aligned (_mm_load_ps / _mm_store_ps).
     * All of B and A are read before the first store to C. */
    void m4x4_vec(const float *A, const float *B, float *C) {
        /* The four rows of B are reused for every output row; keep them in registers. */
        const __m128 b0 = _mm_load_ps(B + 0);
        const __m128 b1 = _mm_load_ps(B + 4);
        const __m128 b2 = _mm_load_ps(B + 8);
        const __m128 b3 = _mm_load_ps(B + 12);
        __m128 rows[4];

        for (int r = 0; r < 4; ++r) {
            const float *a = A + 4 * r;
            __m128 acc = _mm_set1_ps(0.0f);
            acc = _mm_add_ps(acc, _mm_mul_ps(_mm_set1_ps(a[0]), b0));
            acc = _mm_add_ps(acc, _mm_mul_ps(_mm_set1_ps(a[1]), b1));
            acc = _mm_add_ps(acc, _mm_mul_ps(_mm_set1_ps(a[2]), b2));
            acc = _mm_add_ps(acc, _mm_mul_ps(_mm_set1_ps(a[3]), b3));
            rows[r] = acc;
        }

        for (int r = 0; r < 4; ++r) {
            _mm_store_ps(C + 4 * r, rows[r]);
        }
    }
    
    /* SSE kernel: C = A * B^T for 4x4 row-major float matrices.
     * C[r][c] = dot(row r of A, row c of B). The four element-wise products
     * for a row are collapsed by two levels of _mm_hadd_ps (SSE3) into
     * [dot(a,b0), dot(a,b1), dot(a,b2), dot(a,b3)], so B is never transposed.
     * A, B and C must be 16-byte aligned (_mm_load_ps / _mm_store_ps).
     * All inputs are loaded before the first store to C. */
    void m4x4T_vec(const float *A, const float *B, float *C) {
        __m128 a[4], b[4], out[4];

        for (int r = 0; r < 4; ++r) {
            a[r] = _mm_load_ps(A + 4 * r);
            b[r] = _mm_load_ps(B + 4 * r);
        }

        for (int r = 0; r < 4; ++r) {
            const __m128 p0 = _mm_mul_ps(a[r], b[0]);
            const __m128 p1 = _mm_mul_ps(a[r], b[1]);
            const __m128 p2 = _mm_mul_ps(a[r], b[2]);
            const __m128 p3 = _mm_mul_ps(a[r], b[3]);
            const __m128 lo = _mm_hadd_ps(p0, p1);   /* partial sums for columns 0,1 */
            const __m128 hi = _mm_hadd_ps(p2, p3);   /* partial sums for columns 2,3 */
            out[r] = _mm_hadd_ps(lo, hi);
        }

        for (int r = 0; r < 4; ++r) {
            _mm_store_ps(C + 4 * r, out[r]);
        }
    }
    
    /* Compare two 4x4 row-major float matrices.
     * Prints every element pair and returns the sum of absolute element-wise
     * differences. For finite inputs the result is 0 exactly when all 16 pairs
     * are equal. Absolute values are used so that differences of opposite sign
     * cannot cancel and make two different matrices look identical. */
    float compare_4x4(const float* A, const float*B) {
        float diff = 0.0f;
        for(int i=0; i<4; i++) {
            for(int j=0; j<4; j++) {
                float d = A[i*4 +j] - B[i*4+j];
                diff += (d < 0.0f) ? -d : d;   /* |d| without pulling in <math.h> */
                printf("A %f, B %f\n", A[i*4 +j], B[i*4 +j]);
            }
        }
        return diff;
    }
    
    /* Demo driver: fills A and B with 0..15 and computes C = A * B^T with both
     * the scalar (m4x4T) and the SSE (m4x4T_vec) kernel. It then prints the
     * comparison of the two results. Buffers come from _mm_malloc with 16-byte
     * alignment because the SSE kernels use aligned loads and stores.
     * Returns 0 on success, 1 if any allocation fails. */
    int main(void) {
        float *A  = (float*)_mm_malloc(sizeof(float)*16, 16);
        float *B  = (float*)_mm_malloc(sizeof(float)*16, 16);
        float *C1 = (float*)_mm_malloc(sizeof(float)*16, 16);
        float *C2 = (float*)_mm_malloc(sizeof(float)*16, 16);
        int rc = 1;

        if (A && B && C1 && C2) {
            for(int i=0; i<4; i++) {
                for(int j=0; j<4; j++) {
                    A[i*4 +j] = (float)(i*4+j);
                    B[i*4 +j] = (float)(i*4+j);
                    C1[i*4 +j] = 0.0f;
                    C2[i*4 +j] = 0.0f;
                }
            }
            m4x4T(A, B, C1);
            m4x4T_vec(A, B, C2);
            printf("compare %f\n", compare_4x4(C1,C2));
            rc = 0;
        } else {
            fprintf(stderr, "allocation failed\n");
        }

        /* _mm_free's handling of NULL is implementation-specific, so only
         * release the buffers that were actually allocated. */
        if (A)  _mm_free(A);
        if (B)  _mm_free(B);
        if (C1) _mm_free(C1);
        if (C2) _mm_free(C2);
        return rc;
    }
    

    Edit:

    Here are the scalar and SSE functions that compute C = (AB)^T. They should be just as fast as their AB versions.

    /* Scalar reference kernel: C = (A * B)^T for 4x4 row-major float matrices.
     * C[i][j] = (A*B)[j][i] = dot(row j of A, column i of B), i.e. the plain
     * A*B formula with the roles of i and j swapped. C should not overlap
     * A or B. */
    void m4x4TT(const float *A, const float *B, float *C) {
        for (int i = 0; i < 4; ++i) {
            const float *b_col = B + i;          /* column i of B, stride 4 */
            for (int j = 0; j < 4; ++j) {
                const float *a_row = A + 4 * j;
                float s = 0.0f;
                for (int k = 0; k < 4; ++k) {
                    s += a_row[k] * b_col[4 * k];
                }
                C[4 * i + j] = s;
            }
        }
    }
    
    /* SSE kernel: C = (A * B)^T for 4x4 row-major float matrices, matching
     * the scalar m4x4TT.
     *
     * Row i of A*B is sum_k A[i][k] * (row k of B). That is built with
     * broadcasts and vertical adds, exactly as in the plain A*B kernel.
     * The four result registers are then transposed in place with
     * _MM_TRANSPOSE4_PS, so row i of C holds column i of A*B. The register
     * transpose costs only a few shuffles.
     *
     * (Accumulating B[i][j] * (row j of A) directly, as an earlier version
     * did, yields B*A, which differs from (A*B)^T unless the matrices are
     * symmetric.)
     *
     * A, B and C must be 16-byte aligned (_mm_load_ps / _mm_store_ps).
     * All reads of A and B happen before the first store, so C may alias
     * either input. */
    void m4x4TT_vec(const float *A, const float *B, float *C) {
        __m128 Brow[4], Mrow[4];
        for (int k = 0; k < 4; k++) {
            Brow[k] = _mm_load_ps(&B[4*k]);
        }

        /* Mrow[i] = row i of A*B. */
        for (int i = 0; i < 4; i++) {
            Mrow[i] = _mm_setzero_ps();
            for (int k = 0; k < 4; k++) {
                __m128 a = _mm_set1_ps(A[4*i + k]);
                Mrow[i] = _mm_add_ps(Mrow[i], _mm_mul_ps(a, Brow[k]));
            }
        }

        /* Rows of A*B become the columns of C. */
        _MM_TRANSPOSE4_PS(Mrow[0], Mrow[1], Mrow[2], Mrow[3]);

        for (int i = 0; i < 4; i++) {
            _mm_store_ps(&C[4*i], Mrow[i]);
        }
    }