Search code examples
c matrix-multiplication intrinsics avx

Using SIMD To Parallelize Matrix Multiplication For A 4x4, Row-Major Matrix


I am currently having an extremely hard time trying to parallelize a 4x4 matrix-multiplication algorithm. I am trying to make a library to use in a minimal raytracer project for school, so I'm trying to make its conventions as user-friendly as possible, which is why I chose to store matrices in row-major order, as that seemed more intuitive to most people I asked. However, this poses an issue, as a lot of parallelization strategies require column extraction (for dot-product simplifications), and I'm having a very hard time even trying to think in parallel... Here's my current matrix-multiplication code (with some parts missing) for computing the matrix product of two 4x4 row-major matrices.

/// @brief Gathers column `c` of the row-major matrix `in` into one vector.
/// @param in matrix whose `a[row][c]` elements are read.
/// @param c  column index; it is used unchecked to index `in->a[..][c]`.
/// @return a vector whose lanes 0..3 hold `a[0][c]`, `a[1][c]`, `a[2][c]`
///         and `a[3][c]`, in that order.
static inline __m128    _extract_column_ps4(const t_mat4s *in, int c)
{
    const float r0 = in->a[0][c];
    const float r1 = in->a[1][c];
    const float r2 = in->a[2][c];
    const float r3 = in->a[3][c];

    // _mm_setr_ps takes its arguments lowest lane first.
    return (_mm_setr_ps(r0, r1, r2, r3));
}

static inline __m128    compute_row(const __m128 *in1, const t_mat4s *in2,
                            int r)
{
    __m128  col[4];
    __m128  mul[4];
    __m128  res;

    col[0] = _extract_column_ps4(in2, 0);
    col[1] = _extract_column_ps4(in2, 1);
    col[2] = _extract_column_ps4(in2, 2);
    col[3] = _extract_column_ps4(in2, 3);
    mul[0] = _mm_mul_ps(in1[r], col[0]);
    mul[1] = _mm_mul_ps(in1[r], col[1]);
    mul[2] = _mm_mul_ps(in1[r], col[2]);
    mul[3] = _mm_mul_ps(in1[r], col[3]);
    // ..
    // ..
    return (res);
}

/// @brief Computes the matrix product of `in1` with `in2` (in that order)
///        and stores it in the `t_mat4s` pointed-to by `out`.
/// @param in1 left-hand matrix (taken by value).
/// @param in2 right-hand matrix (taken by value).
/// @param out destination; each of its four rows is written once.
static inline void  lag_mat4s_cross_mat4s(const t_mat4s in1,
                            const t_mat4s in2, t_mat4s *out)
{
    int r;

    // One call to compute_row per output row, in row order.
    for (r = 0; r < 4; r++)
    {
        out->simd[r] = compute_row(in1.simd, &in2, r);
    }
}

As you can see, I'm not using any hadds; in fact, I'm trying to avoid them, as they are very expensive. Another approach would be to simply do something like:

// Code for computing the result of multiplying two 4x4 row-major matrices of double-precision floats

/// @brief Fills all sixteen elements of `out`: element (row i, column j) is
///        `lag_dot_product_avx(<row i of in>.simd, col[j])`.
/// @param in  left-hand matrix; its rows are read through `r1`..`r4`.
/// @param col four vectors supplied by the caller, one per output column
///            (presumably the columns of the right-hand matrix — see the
///            caller; `lag_dot_product_avx` itself is not visible here).
/// @param out destination matrix; written element by element.
/// NOTE(review): rows of `in` are re-read after fields of `out` have been
///               written, so `out` presumably must not alias `in` — confirm.
static inline void  unrolled_cross(const t_mat4d *in, const __m256d col[4],
                        t_mat4d *out)
{
    // Output row 1.
    out->r1.x = lag_dot_product_avx(in->r1.simd, col[0]);
    out->r1.y = lag_dot_product_avx(in->r1.simd, col[1]);
    out->r1.z = lag_dot_product_avx(in->r1.simd, col[2]);
    out->r1.w = lag_dot_product_avx(in->r1.simd, col[3]);
    // Output row 2.
    out->r2.x = lag_dot_product_avx(in->r2.simd, col[0]);
    out->r2.y = lag_dot_product_avx(in->r2.simd, col[1]);
    out->r2.z = lag_dot_product_avx(in->r2.simd, col[2]);
    out->r2.w = lag_dot_product_avx(in->r2.simd, col[3]);
    // Output row 3.
    out->r3.x = lag_dot_product_avx(in->r3.simd, col[0]);
    out->r3.y = lag_dot_product_avx(in->r3.simd, col[1]);
    out->r3.z = lag_dot_product_avx(in->r3.simd, col[2]);
    out->r3.w = lag_dot_product_avx(in->r3.simd, col[3]);
    // Output row 4.
    out->r4.x = lag_dot_product_avx(in->r4.simd, col[0]);
    out->r4.y = lag_dot_product_avx(in->r4.simd, col[1]);
    out->r4.z = lag_dot_product_avx(in->r4.simd, col[2]);
    out->r4.w = lag_dot_product_avx(in->r4.simd, col[3]);
}

/// @brief Multiplies `in` by `in2` (in that order) and stores the result
///        in `out`.
/// @param in  left-hand matrix, forwarded to `unrolled_cross`.
/// @param in2 right-hand matrix; `lag_extract_column4_avx` is called on it
///            once for each index 0..3 (presumably pulling out that
///            column — the helper is not visible here).
/// @param out destination matrix, forwarded to `unrolled_cross`.
void    lag_mat4_cross(const t_mat4d *in, const t_mat4d *in2, t_mat4d *out)
{
    __m256d columns[4];
    int     c;

    for (c = 0; c < 4; c++)
    {
        lag_extract_column4_avx(in2, c, &columns[c]);
    }
    unrolled_cross(in, columns, out);
}

Which works perfectly fine to be fair, but I feel like I'm losing out on a lot of parallelization here...

I first tried shuffling, but quickly realised it's not going to work, as my columns aren't contiguous in memory. I also tried using a sequence of _mm_mul_ps's followed by two _mm_add_ps's, which obviously didn't work, because you'd need to add the elements horizontally rather than in parallel (component-wise). I was thinking of using an AVX-256 register to store two sets of 4 packed-singles, but that also gets rather heavy, and my brain got fried trying to even conceptualise it. Any ideas/suggestions? I don't think converting to column-major is an option for me, at this time... Also, could you give me advice on how the order I store my matrices in (column-major vs row-major) affects performance? What would be the performance implications? Is either way better than the other? Is it case-by-case? Why?

Edit: I should probably mention that my structures/unions look like this:

/// Four single-precision floats with three overlapping views of the same
/// 16 bytes: an indexable array, an SSE register, and named components.
typedef union u_vec4s
{
    float       a[4];       // indexed view; a[0] overlays x, a[3] overlays w
    __m128      simd;       // whole-vector view for SSE intrinsics
    struct
    {
        float   x, y, z, w; // named-component view
    };
}__attribute__((aligned(16))) t_vec4s;

/// A 4x4 single-precision matrix with three overlapping views of the same
/// storage: a 2D array, four SSE registers, and four named `t_vec4s`.
typedef union u_mat4s
{
    float   a[4][4];            // element view, indexed as a[i][j]
    __m128  simd[4];            // one SSE register per group of 4 floats
    struct
    {
        t_vec4s r1, r2, r3, r4; // the same four groups as named vectors
    };
}__attribute__((aligned(16))) t_mat4s;

Edit #2: Here's my revised code:

/// @brief Gathers column `c` of the row-major matrix `in` into one vector.
/// @param in matrix whose `a[row][c]` elements are read.
/// @param c  column index; it is used unchecked to index `in->a[..][c]`.
/// @return a vector whose lanes 0..3 hold `a[0][c]`, `a[1][c]`, `a[2][c]`
///         and `a[3][c]`, in that order.
static inline __m128    _extract_column_ps4(const t_mat4s *in, int c)
{
    const float r0 = in->a[0][c];
    const float r1 = in->a[1][c];
    const float r2 = in->a[2][c];
    const float r3 = in->a[3][c];

    // _mm_setr_ps takes its arguments lowest lane first.
    return (_mm_setr_ps(r0, r1, r2, r3));
}

/// @brief Returns the cross product of a `t_mat4s` with a `t_vec4s`
///        (in that order).
static inline t_vec4s   lag_mat4s_cross_vec4s(const t_mat4s m,
                            const t_vec4s v)
{
    t_vec4s ret;

    ret.x = lag_vec4s_dot_ret(m.r1, v);
    ret.y = lag_vec4s_dot_ret(m.r2, v);
    ret.z = lag_vec4s_dot_ret(m.r3, v);
    ret.w = lag_vec4s_dot_ret(m.r4, v);
    return (ret);
}

/// @brief Computes row `r` of the row-major matrix product `in1 * in2`.
/// @param in1 the four rows of the left-hand matrix, one `__m128` per row.
/// @param in2 the right-hand matrix.
/// @param r   index of the row to compute; used unchecked to index `in1`.
/// @return the result of `lag_mat4s_cross_vec4s` applied to the transpose
///         of `in2` and row `r` of `in1`.
static inline __m128    compute_row(const __m128 *in1, const t_mat4s *in2,
                            int r)
{
    t_mat4s cols;
    t_vec4s row;

    // Build the transpose of in2: row k of `cols` is column k of in2, so a
    // matrix-vector product against `cols` pairs the vector with every
    // column of in2.
    cols.simd[0] = _extract_column_ps4(in2, 0);
    cols.simd[1] = _extract_column_ps4(in2, 1);
    cols.simd[2] = _extract_column_ps4(in2, 2);
    cols.simd[3] = _extract_column_ps4(in2, 3);
    row.simd = in1[r];
    return (lag_mat4s_cross_vec4s(cols, row).simd);
}

/// @brief Computes the matrix product of `in1` with `in2` (in that order)
///        and stores it in the `t_mat4s` pointed-to by `out`.
/// @param in1 left-hand matrix (taken by value).
/// @param in2 right-hand matrix (taken by value).
/// @param out destination; each of its four rows is written once.
static inline void  lag_mat4s_cross_mat4s(const t_mat4s in1,
                            const t_mat4s in2, t_mat4s *out)
{
    t_mat4s cols;

    // The transpose of in2 is the same for every output row, so build it
    // once here instead of once per row.
    cols.simd[0] = _extract_column_ps4(&in2, 0);
    cols.simd[1] = _extract_column_ps4(&in2, 1);
    cols.simd[2] = _extract_column_ps4(&in2, 2);
    cols.simd[3] = _extract_column_ps4(&in2, 3);
    // Output row i pairs row i of in1 with every column of in2.
    out->r1 = lag_mat4s_cross_vec4s(cols, in1.r1);
    out->r2 = lag_mat4s_cross_vec4s(cols, in1.r2);
    out->r3 = lag_mat4s_cross_vec4s(cols, in1.r3);
    out->r4 = lag_mat4s_cross_vec4s(cols, in1.r4);
}

I'm going for a different strategy now, wherein I calculate a matrix-vector product, treating each row of the first matrix as its own vector, multiplied by the second matrix transposed (its columns become its rows). This works, but is there a better way to do this?


Solution

  • Alright. I came to the conclusion that using AVX registers for this is the best approach for a 4x4 row-major matrix (no benchmarks to support my claim), simply because broadcast-loads are apparently more efficient. I also decided to avoid spamming _mm_dp_ps in my lag_mat4s_cross_vec4s function. Here's the revised code:

    /// @brief Gathers column `c` of the row-major matrix `in` into one vector.
    /// @param in matrix whose `a[row][c]` elements are read.
    /// @param c  column index; it is used unchecked to index `in->a[..][c]`.
    /// @return a vector whose lanes 0..3 hold `a[0][c]`, `a[1][c]`,
    ///         `a[2][c]` and `a[3][c]`, in that order.
    static inline __m128    _extract_column_ps4(const t_mat4s *in, int c)
    {
        const float r0 = in->a[0][c];
        const float r1 = in->a[1][c];
        const float r2 = in->a[2][c];
        const float r3 = in->a[3][c];

        // _mm_setr_ps takes its arguments lowest lane first.
        return (_mm_setr_ps(r0, r1, r2, r3));
    }
    
    /// @brief Returns the product of the matrix `m` with the vector `v`
    ///        (in that order).
    /// @param m matrix; each of its four rows is multiplied lane-wise by `v`.
    /// @param v vector operand.
    /// @return a vector whose lane `i` is the sum of the four lanes of
    ///         `m.simd[i] * v.simd`.
    static inline t_vec4s   lag_mat4s_cross_vec4s(const t_mat4s m,
                                const t_vec4s v)
    {
        t_vec4s out;
        __m128  prod[4];
        __m128  pairs01;
        __m128  pairs23;
        int     i;

        for (i = 0; i < 4; i++)
        {
            prod[i] = _mm_mul_ps(m.simd[i], v.simd);
        }
        // First level: adjacent-pair sums of rows 0,1 and of rows 2,3.
        pairs01 = _mm_hadd_ps(prod[0], prod[1]);
        pairs23 = _mm_hadd_ps(prod[2], prod[3]);
        // Second level: each lane becomes the full sum of one row's products.
        out.simd = _mm_hadd_ps(pairs01, pairs23);
        return (out);
    }
    
    //static inline void    lag_mat4s_cross_mat4s(const t_mat4s in1,
    //                          const t_mat4s in2, t_mat4s *out)
    //{
    //  t_mat4s cols;
    
    //  cols.simd[0] = _extract_column_ps4(&in2, 0);
    //  cols.simd[1] = _extract_column_ps4(&in2, 1);
    //  cols.simd[2] = _extract_column_ps4(&in2, 2);
    //  cols.simd[3] = _extract_column_ps4(&in2, 3);
    //  out->r1 = lag_mat4s_cross_vec4s(cols, in1.r1);
    //  out->r2 = lag_mat4s_cross_vec4s(cols, in1.r2);
    //  out->r3 = lag_mat4s_cross_vec4s(cols, in1.r3);
    //  out->r4 = lag_mat4s_cross_vec4s(cols, in1.r4);
    //}
    
    /// @brief Computes the matrix product of `in1` with `in2` (in that order)
    ///        and stores it in the `t_mat4s` pointed-to by `out`, working on
    ///        two rows at a time in 256-bit registers.
    /// @param in1 left-hand matrix (taken by value).
    /// @param in2 right-hand matrix (taken by value).
    /// @param out destination; both of its `_ymm` halves are written once.
    /// NOTE(review): `_ymm` is not a member of the `t_mat4s` union shown in
    ///               the question; this assumes a `__m256 _ymm[2]` view in
    ///               which `_ymm[0]` spans rows 0-1 and `_ymm[1]` rows 2-3
    ///               — confirm against the real declaration.
    static inline void  lag_mat4s_cross_mat4s(const t_mat4s in1,
                            const t_mat4s in2, t_mat4s *out)
    {
        const __m256    lhs01 = in1._ymm[0];
        const __m256    lhs23 = in1._ymm[1];
        const __m256    rhs01 = in2._ymm[0];
        const __m256    rhs23 = in2._ymm[1];
        // Duplicate one 128-bit half of the right-hand operand into both
        // halves: 0x00 selects the low half, 0x11 the high half.
        const __m256    rhs_row0 = _mm256_permute2f128_ps(rhs01, rhs01, 0x00);
        const __m256    rhs_row1 = _mm256_permute2f128_ps(rhs01, rhs01, 0x11);
        const __m256    rhs_row2 = _mm256_permute2f128_ps(rhs23, rhs23, 0x00);
        const __m256    rhs_row3 = _mm256_permute2f128_ps(rhs23, rhs23, 0x11);
        // Each shuffle broadcasts element k within each 128-bit half of the
        // left-hand operand; the product is that half's term k of the sum.
        const __m256    top_k0 = _mm256_mul_ps(
                _mm256_shuffle_ps(lhs01, lhs01, _MM_SHUFFLE(0, 0, 0, 0)),
                rhs_row0);
        const __m256    bot_k0 = _mm256_mul_ps(
                _mm256_shuffle_ps(lhs23, lhs23, _MM_SHUFFLE(0, 0, 0, 0)),
                rhs_row0);
        const __m256    top_k1 = _mm256_mul_ps(
                _mm256_shuffle_ps(lhs01, lhs01, _MM_SHUFFLE(1, 1, 1, 1)),
                rhs_row1);
        const __m256    bot_k1 = _mm256_mul_ps(
                _mm256_shuffle_ps(lhs23, lhs23, _MM_SHUFFLE(1, 1, 1, 1)),
                rhs_row1);
        const __m256    top_k2 = _mm256_mul_ps(
                _mm256_shuffle_ps(lhs01, lhs01, _MM_SHUFFLE(2, 2, 2, 2)),
                rhs_row2);
        const __m256    bot_k2 = _mm256_mul_ps(
                _mm256_shuffle_ps(lhs23, lhs23, _MM_SHUFFLE(2, 2, 2, 2)),
                rhs_row2);
        const __m256    top_k3 = _mm256_mul_ps(
                _mm256_shuffle_ps(lhs01, lhs01, _MM_SHUFFLE(3, 3, 3, 3)),
                rhs_row3);
        const __m256    bot_k3 = _mm256_mul_ps(
                _mm256_shuffle_ps(lhs23, lhs23, _MM_SHUFFLE(3, 3, 3, 3)),
                rhs_row3);

        // Sum the four terms pairwise: (k0 + k1) + (k2 + k3).
        out->_ymm[0] = _mm256_add_ps(_mm256_add_ps(top_k0, top_k1),
                _mm256_add_ps(top_k2, top_k3));
        out->_ymm[1] = _mm256_add_ps(_mm256_add_ps(bot_k0, bot_k1),
                _mm256_add_ps(bot_k2, bot_k3));
    }
    

    I stole an implementation from a post that Peter so kindly provided (I didn't just copy-paste. I understand what's going on). I am wondering, however, if performance would be boosted if I align my structs to a 32-byte region instead of 16. Not my concern as of now, but feel free to comment your thoughts; any piece of information and criticism is welcome!