Tags: c++, c, matrix-multiplication, avx

AVX Intrinsic Clarification, 4x4 Matrix Multiplication Oddities


On paper I drew out the long form of this algorithm, and on paper it should work fine. Am I running into a subtlety with register casting (256/128/256), or did I actually mess up the algorithm structure somewhere?

For convenience, I've put the vanilla code and the AVX code up on the Godbolt viewer so you can see the generated assembly at will.

Standard code https://godbolt.org/g/v47RKH

My AVX Attempt 1: https://godbolt.org/g/oH1DpO

My AVX Attempt 2: https://godbolt.org/g/QFtdKr (Shaved 5 cycles and reduced casting needs, easier to read)

The SSE code oddly enough is using scalar operations, which boggles my mind since that can definitely be accelerated with horizontal broadcasts, muls, and adds. What I'm trying to do is take that concept up one level.

RHS never needs to be changed. Essentially, if LHS is {a, b, ..., p} and RHS is {1, 2, ..., 16}, then we just need 2 registers to hold the 2 halves of RHS, and 2 registers to hold a given row of LHS in the forms {a, a, a, a, b, b, b, b} and {c, c, c, c, d, d, d, d}. This is achieved via 2 broadcasts and a 256/128/256 cast.

We get the intermediate results of

{a*1, a*2, a*3, a*4, b*5, b*6, b*7, b*8} => row[0]

and

{c*9, c*10, c*11, c*12, d*13, d*14, d*15, d*16} => row[1]

And this is unrolled once w.r.t LHS so we generate

{e*1, ... f*8}, {g*9, ... h*16} => row[2], row[3]

Next, add row[0]+row[1] and row[2]+row[3] together (keeping row[0] and row[2] as the current intermediates)

Finally, extract the high half of row[0] to the low half of resHalf, insert the low half of row[2] into the high half of resHalf, insert the high half of row[2] into the high half of row[0], and then add row[0] to resHalf.

By all rights, that should leave us with resHalf[0] equaling the following at the end of iteration i = 0

{a*1 + b*2 + c*3 + d*4, a*5 + b*6 + c*7 + d*8,

a*9 + b*10 + c*11 + d*12, a*13 + b*14 + c*15 + d*16,

e*1 + ... + h*4, e*5 + ... + h*8,

e*9 + ... + h*12, e*13 + ... + h*16}

What my algorithm is producing, however, is the following:

2x {a*1 + c*3, a*5 + c*7, a*9 + c*11, a*13 + c*15},

2x {e*1 + g*3, e*5 + g*7, e*9 + g*11, e*13 + g*15}

And what's scarier still is if I swap rhsHolders[0/1] in the ternary conditional, it doesn't change the results at all. It's as though the compiler is ignoring one of the swaps and adds. Both Clang 4 and GCC 7 do this, so where did I screw up?

EDIT: output should be 4 rows of {10, 26, 42, 58}, but I get {4, 12, 20, 28}


Solution

  • The SSE code oddly enough is using scalar operations, which boggles my mind since that can definitely be accelerated with horizontal broadcasts, muls, and adds.

    Do you mean the compiler-generated assembly code? All the AVX instructions in MatMul() in the clang4.0 and gcc7.1 outputs are operating on ymm vectors. The exception is clang's silly broadcast-loads: it does a scalar load and then a separate AVX2 register-broadcast instruction, which is extra bad because Intel CPUs handle broadcast-loads very efficiently. The load port itself can do the broadcast, so a broadcast-load is a single uop with no shuffle uop at all (not even a micro-fused one); with a register source, though, a broadcast needs an ALU uop for the shuffle port.

        vmovss  xmm5, dword ptr [rdi + 24] # xmm5 = mem[0],zero,zero,zero
        vbroadcastss    xmm5, xmm5
    

    clang's actual output (above) is really silly compared to an AVX1 vbroadcastss xmm5, [rdi + 24] like gcc uses.

    In main(), clang does emit scalar operations.

    Since your input matrices are both compile-time constants, the only mystery is why it didn't optimize down to cout << "a long string with the numbers already formatted\n";, or at least optimize away all the math and just have the double results ready for printing. (And yes, they are being converted from float to double in the print loop with vcvtss2sd.)

    It optimizes through some of the intrinsic shuffles and math, doing them at compile time. Presumably clang got lost somewhere in the shuffles and still emitted some math operations. The fact that they're scalar may indicate that very little work was left over after constant-propagation, but that clang didn't reorder what remained to vectorize it.

    Note that some of the constants do not appear in the source, and they're not in ascending order in memory.

    ...
    .LCPI1_5:
            .long   1092616192              # float 10
    .LCPI1_6:
            .long   1101004800              # float 20
    .LCPI1_7:
            .long   1098907648              # float 16
    ...
    

    It's really nice how clang puts the float value in a comment after the integer representation of the bit pattern.


    or did I actually mess up the algorithm structure somewhere?

    Well, this part of the implementation looks totally bogus. You initialize lowerHalf from rows[j], but then overwrite that value in the next statement.

    __m128 lowerHalf = _mm256_castps256_ps128(rows[j]);
        lowerHalf = _mm_broadcast_ss(&lhs[offset+2*j]);
    

    And then you do a 256b multiply with the upper 128b lane of rows[j] undefined.

        rows[j] = _mm256_castps128_ps256(lowerHalf);
        rows[j] = _mm256_mul_ps(rows[j], (chooser) ? rhsHolders[0] : rhsHolders[1]);
    

    In the asm from gcc and clang, the upper lane is all zero (because they make the obvious choice of using the ymm register last written by the scalar -> xmm broadcast, which implicitly zero-extends to the max vector width). Note that zero-extending isn't guaranteed by _mm256_castps128_ps256. It's very likely unless the __m128 was itself the result of an extract/cast from a 256b or wider vector, but it is undefined. See How to clear the upper 128 bits of __m256 value? for cases where you need a zeroed upper lane in a vector.

    Anyway, this means you'd get the same result from a 128b vector multiply (vmulps xmm, xmm, xmm): the upper 4 elements will all be zero (or NaN) after these instructions

        vbroadcastss    xmm0, DWORD PTR [rdi+40]
        vmulps  ymm0, ymm2, ymm0
    

    This kind of asm output (from gcc7.1) is highly unlikely to be part of a correct matmul implementation.

    I didn't look carefully to figure out what exactly you were trying to do in the source, but I assume it wasn't exactly this.
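    Given the question's description ("2 broadcasts and a 256/128/256 cast"), the intent was presumably something like the sketch below: two broadcasts, with the upper lane explicitly inserted rather than left undefined by the 128-to-256 cast. Variable names are hypothetical; compile with -mavx, or rely on the target pragma on GCC/clang.

    ```cpp
    #include <immintrin.h>
    #include <cstdio>

    #pragma GCC target("avx")

    int main() {
        float lhs[4] = {1.0f, 2.0f, 3.0f, 4.0f};

        __m128 lo = _mm_broadcast_ss(&lhs[0]);   // {1,1,1,1}
        __m128 hi = _mm_broadcast_ss(&lhs[1]);   // {2,2,2,2}
        // _mm256_castps128_ps256 leaves the upper lane undefined;
        // _mm256_insertf128_ps defines it explicitly.
        __m256 v = _mm256_insertf128_ps(_mm256_castps128_ps256(lo), hi, 1);

        float out[8];
        _mm256_storeu_ps(out, v);
        for (int i = 0; i < 8; ++i) printf("%g ", out[i]); // 1 1 1 1 2 2 2 2
        printf("\n");
    }
    ```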


    And what's scarier still is if I swap rhsHolders[0/1] in the ternary conditional, it doesn't change the results at all. It's as though the compiler is ignoring one of the swaps and adds.

    When changing something in the source doesn't produce the change you expect in the asm output, that's a hint that you probably got the source wrong, and something is optimizing away. Sometimes I copy/paste an intrinsic and forget to change the input variable in the new line, so my function ignores some of its calculation results and uses another one twice.
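    This failure mode is easy to reproduce in miniature: the pasted-but-unedited line makes the first computation dead, so the compiler deletes it, and editing the dead line cannot change the output. A contrived illustration (not from the question's code):

    ```cpp
    #include <cstdio>

    int main() {
        int a = 3, b = 5;
        int x = a * b;     // dead store: the compiler can delete this entirely
        x = a + b;         // copy/pasted line forgot to target a different variable
        printf("%d\n", x); // prints 8 no matter how the dead line is edited
    }
    ```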