GEMM kernel implemented using AVX2 is faster than AVX2/FMA on a Zen 2 CPU

I am trying to speed up a toy GEMM implementation. I work on 32x32 blocks of doubles, for which I need an optimized MM kernel. I have access to AVX2 and FMA.

I have two kernels (shown as disassembly below; I apologize for the crude formatting): one uses plain AVX2 multiply and add instructions, the other uses FMA.

Without going into microbenchmarks, I would like to develop a theoretical understanding of why the AVX2 implementation is 1.11x faster than the FMA version, and possibly how to improve both versions.

The code below is for a 3000x3000 MM of doubles, and the kernels implement the classical, naive MM with the two innermost loops interchanged. I'm using a Ryzen 7 3700X (Zen 2) as the development CPU.
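
For readers following along, the loop ordering described above can be sketched in plain Python. This is only an illustration of the i-k-j structure that the disassembly below appears to follow (broadcast one element of A, multiply it by a row of B, accumulate into a row of C). The function name `matmul_ikj` is made up for this example and is not taken from the original code.

```python
def matmul_ikj(a, b, n):
    """Naive n x n matrix multiply with the two inner loops interchanged (i-k-j)."""
    c = [[0.0] * n for _ in range(n)]
    for i in range(n):
        row_c = c[i]              # the asm kernel keeps this row of C in ymm0..ymm7
        for k in range(n):
            a_ik = a[i][k]        # plays the role of the vbroadcastsd load
            row_b = b[k]          # contiguous row of B, read as full vectors in the asm
            for j in range(n):    # in the asm this loop is vectorized and fully unrolled
                row_c[j] += a_ik * row_b[j]
    return c


if __name__ == "__main__":
    print(matmul_ikj([[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]], 2))
```

With this ordering the innermost loop walks B and C contiguously, which is what lets the kernel use eight 256-bit loads per broadcast.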

I have not tried unrolling aggressively, for fear that the CPU might run out of physical registers.

AVX2 32x32 MM kernel:

Block 82:
    imul r12, r15, 0xbb8
    mov rax, r11
    mov r13d, 0x0
    vmovupd ymm0, ymmword ptr [rdi+r12*8]
    vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
    vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
    vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
    vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
    vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
    vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
    vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
    lea r14, ptr [r12+0x4]
    nop dword ptr [rax+rax*1], eax
Block 83:
    vbroadcastsd ymm8, qword ptr [rcx+r13*8]
    inc r13
    vmulpd ymm10, ymm8, ymmword ptr [rax-0xa0]
    vmulpd ymm11, ymm8, ymmword ptr [rax-0x80]
    vmulpd ymm9, ymm8, ymmword ptr [rax-0xe0]
    vmulpd ymm12, ymm8, ymmword ptr [rax-0xc0]
    vaddpd ymm2, ymm10, ymm2    
    vmulpd ymm10, ymm8, ymmword ptr [rax-0x60]
    vaddpd ymm3, ymm11, ymm3    
    vmulpd ymm11, ymm8, ymmword ptr [rax-0x40]
    vaddpd ymm0, ymm9, ymm0   
    vaddpd ymm1, ymm12, ymm1
    vaddpd ymm4, ymm10, ymm4
    vmulpd ymm10, ymm8, ymmword ptr [rax-0x20]
    vmulpd ymm8, ymm8, ymmword ptr [rax]       
    vaddpd ymm5, ymm11, ymm5    
    add rax, 0x5dc0 
    vaddpd ymm6, ymm10, ymm6
    vaddpd ymm7, ymm8, ymm7 
    cmp r13, 0x20
    jnz 0x140004530 <Block 83>
Block 84:
    inc r15
    add rcx, 0x5dc0
    vmovupd ymmword ptr [rdi+r12*8], ymm0
    vmovupd ymmword ptr [rdi+r14*8], ymm1
    vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
    vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
    vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
    vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
    vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
    vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
    cmp r15, 0x20
    jnz 0x1400044d0 <Block 82>

AVX2/FMA 32x32 MM kernel:

Block 80:
    imul r12, r15, 0xbb8
    mov rax, r11
    mov r13d, 0x0
    vmovupd ymm0, ymmword ptr [rdi+r12*8]
    vmovupd ymm1, ymmword ptr [rdi+r12*8+0x20]
    vmovupd ymm2, ymmword ptr [rdi+r12*8+0x40]
    vmovupd ymm3, ymmword ptr [rdi+r12*8+0x60]
    vmovupd ymm4, ymmword ptr [rdi+r12*8+0x80]
    vmovupd ymm5, ymmword ptr [rdi+r12*8+0xa0]
    vmovupd ymm6, ymmword ptr [rdi+r12*8+0xc0]
    vmovupd ymm7, ymmword ptr [rdi+r12*8+0xe0]
    lea r14, ptr [r12+0x4]
    nop dword ptr [rax+rax*1], eax
Block 81:
    vbroadcastsd ymm8, qword ptr [rcx+r13*8]
    inc r13
    vfmadd231pd ymm0, ymm8, ymmword ptr [rax-0xe0]
    vfmadd231pd ymm1, ymm8, ymmword ptr [rax-0xc0]
    vfmadd231pd ymm2, ymm8, ymmword ptr [rax-0xa0]
    vfmadd231pd ymm3, ymm8, ymmword ptr [rax-0x80]
    vfmadd231pd ymm4, ymm8, ymmword ptr [rax-0x60]
    vfmadd231pd ymm5, ymm8, ymmword ptr [rax-0x40]
    vfmadd231pd ymm6, ymm8, ymmword ptr [rax-0x20]
    vfmadd231pd ymm7, ymm8, ymmword ptr [rax]
    add rax, 0x5dc0 
    cmp r13, 0x20   
    jnz 0x140004450 <Block 81>
Block 82:
    inc r15
    add rcx, 0x5dc0
    vmovupd ymmword ptr [rdi+r12*8], ymm0
    vmovupd ymmword ptr [rdi+r14*8], ymm1
    vmovupd ymmword ptr [rdi+r12*8+0x40], ymm2
    vmovupd ymmword ptr [rdi+r12*8+0x60], ymm3
    vmovupd ymmword ptr [rdi+r12*8+0x80], ymm4
    vmovupd ymmword ptr [rdi+r12*8+0xa0], ymm5
    vmovupd ymmword ptr [rdi+r12*8+0xc0], ymm6
    vmovupd ymmword ptr [rdi+r12*8+0xe0], ymm7
    cmp r15, 0x20
    jnz 0x1400043f0 <Block 80>


  • Zen2 has 3-cycle latency for vaddpd and 5-cycle latency for vfma...pd.

    Your code with 8 accumulators has enough ILP that you'd expect about 8 FMAs per 5 clocks (1.6 per clock) if there aren't other bottlenecks, which is a bit less than the theoretical max of 10 per 5 clocks (2 per clock).

    vaddpd and vmulpd actually run on different ports on Zen2 (unlike Intel), ports FP2/3 and FP0/1 respectively, so it can in theory sustain 2/clock of each. Since the loop-carried dependency (which goes through vaddpd only) has shorter latency, 8 accumulators are enough to hide the vaddpd latency, as long as scheduling doesn't let one dep chain fall behind. (And at least the multiplies aren't stealing cycles from the adds.)
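
The latency argument in the last two paragraphs can be put into a back-of-the-envelope model. The sketch below is only a rough bound under the numbers quoted above (latency 5 for FMA, latency 3 for vaddpd, 2 per clock for each on Zen 2). It ignores load throughput, front-end width and scheduling effects, and the function name `ops_per_clock` is made up for this illustration.

```python
def ops_per_clock(accumulators, latency, port_throughput):
    """Upper bound on sustained accumulate ops per clock for independent dep chains."""
    # Each chain can start a new op only once every `latency` cycles.
    latency_bound = accumulators / latency
    # The execution ports cap the rate regardless of how many chains exist.
    return min(latency_bound, port_throughput)


if __name__ == "__main__":
    fma_rate = ops_per_clock(8, 5, 2)   # latency-bound: 8 chains / 5 cycles
    add_rate = ops_per_clock(8, 3, 2)   # port-bound: capped at 2 per clock
    print(fma_rate, add_rate)
```

Under these assumptions the FMA loop is capped at 1.6 accumulate steps per clock while the vaddpd chains can reach 2, a ratio of 1.25, so a measured 1.11x advantage for the non-FMA version is within what this crude bound allows.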

    Zen2's front-end is 5 instructions wide (or 6 uops if there are any multi-uop instructions), and it can decode memory-source instructions as a single uop. So it might well be doing 2/clock each multiply and add with the non-FMA version.

    If you can unroll by 10 or 12, that might hide enough FMA latency to make it equal to the non-FMA version, but with less power consumption and friendlier to code running on the other logical core under SMT. (10 = 5 cycles of latency x 2 per clock would be just barely enough, which means any scheduling imperfection loses progress on a dep chain that is on the critical path. See Why does mulss take only 3 cycles on Haswell, different from Agner's instruction tables? (Unrolling FP loops with multiple accumulators) for some testing on Intel.)
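
The "10 = 5 x 2" rule of thumb, and the question of whether 10 or 12 accumulators still fit in the architectural register file, can be checked with a few lines. This is a sketch under stated assumptions: the 16 YMM registers are the x86-64 AVX2 architectural count, one register is reserved for the broadcast value as in the kernels above, and both helper names are hypothetical. It says nothing about the physical register file, which is a separate limit.

```python
def min_accumulators(latency, per_clock):
    """Independent dep chains needed to keep the execution units busy every cycle."""
    return latency * per_clock


def fits_in_registers(accumulators, scratch=1, arch_regs=16):
    """True if the accumulators plus scratch registers fit in the YMM register set."""
    return accumulators + scratch <= arch_regs


if __name__ == "__main__":
    print(min_accumulators(5, 2))    # FMA on Zen 2, using the numbers quoted above
    print(min_accumulators(3, 2))    # vaddpd on Zen 2
    print(fits_in_registers(12))     # 12 accumulators + 1 broadcast register
```

By this count the FMA kernel needs 10 accumulators and the vaddpd chains need 6, and 12 accumulators plus the broadcast register still fit in 16 YMM registers.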

    (By comparison, Intel Skylake runs vaddpd/vmulpd on the same ports with the same latency as vfma...pd, all with 4c latency, 0.5c throughput.)

    I didn't look at your code super carefully, but 10 YMM vectors might be a tradeoff: 8 vectors touch two pairs of cache lines, while 10 touch 5 lines in total, which might be worse if a spatial prefetcher tries to complete an aligned pair. Or it might be fine. 12 YMM vectors would be three pairs, which should be fine.
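
The cache-line counts in the previous paragraph can be reproduced with a small helper. It assumes 32-byte YMM vectors, 64-byte cache lines, pairs aligned to 128 bytes, and a row that starts on a pair boundary unless `start_byte` says otherwise. All of these are illustrative assumptions, and `lines_and_pairs` is a made-up name.

```python
def lines_and_pairs(n_vectors, start_byte=0, vec_bytes=32, line_bytes=64):
    """Return (lines touched, complete aligned pairs, partially touched pairs)."""
    first = start_byte // line_bytes
    last = (start_byte + n_vectors * vec_bytes - 1) // line_bytes
    pairs = {}
    for line in range(first, last + 1):
        # Lines 2p and 2p+1 form one aligned 128-byte pair.
        pairs.setdefault(line // 2, []).append(line)
    complete = sum(1 for members in pairs.values() if len(members) == 2)
    partial = len(pairs) - complete
    return last - first + 1, complete, partial


if __name__ == "__main__":
    for n in (8, 10, 12):
        print(n, lines_and_pairs(n))
```

With an aligned start, 8 and 12 vectors cover only complete pairs, while 10 vectors leave one pair half-touched, which is the case the answer flags as possibly worse.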

    Depending on matrix size, out-of-order exec may be able to overlap inner loop dep chains between separate iterations of the outer loop, especially if the loop exit condition can execute sooner and resolve the mispredict (if there is one) while FP work is still in flight. That's an advantage to having fewer total uops for the same work, favouring FMA.