I was tasked with implementing an optimised matrix multiplication micro-kernel in C++ that computes C = A*B, starting from the following snippet of code. I am getting some counter-intuitive behaviour and would like some help to better understand what is going on.
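```cpp
// A is M x K, B is K x N, C is M x N, all column-major.
// N, M and K are defined elsewhere (presumably as compile-time constants).
// The kernel accumulates into C, so C must start at zero to get C = A*B.
void mat_mul(double* A, double* B, double* C) {
    for (int n = 0; n < N; ++n) {
        for (int k = 0; k < K; ++k) {
            for (int m = 0; m < M; ++m) {
                C[m + n*M] += A[m + k*M] * B[k + n*K];
            }
        }
    }
}
```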
The conditions for the challenge are as follows:
Looking at the order of the loops, it seems that the memory access pattern already minimises cache misses, since we iterate over the buffers linearly.
My first thought was to try and vectorise the code. This is what I came up with:
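```cpp
#include <immintrin.h>

// Requires M to be a multiple of 4 and A and C to be 32-byte aligned
// (_mm256_load_pd / _mm256_store_pd fault on unaligned addresses).
void mat_mul(double* A, double* B, double* C) {
    for (int n = 0; n < N; ++n) {
        for (int k = 0; k < K; ++k) {
            // The same B element multiplies every row of this column of C
            __m256d B_v = _mm256_broadcast_sd(B + k + n*K);
            for (int m = 0; m < M; m+=4) {
                __m256d A_v = _mm256_load_pd(A + m + k*M);
                __m256d C_v = _mm256_load_pd(C + m + n*M);
                __m256d rax = _mm256_fmadd_pd(A_v, B_v, C_v);
                _mm256_store_pd(C + m + n*M, rax);
            }
        }
    }
}
```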
This reaches a maximum of around 23 GFLOPS with M = N = 32. When N and M are decreased, performance drops.
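For reference, a GFLOPS figure like this is usually computed by counting one multiply and one add per (m, n, k) triple, i.e. 2·M·N·K floating-point operations per kernel call. The sketch below assumes that convention; `flop_count` and `measure_gflops` are hypothetical helper names, and the kernel is assumed to have the `mat_mul(double*, double*, double*)` signature used above.

```cpp
#include <chrono>
#include <cstdint>

// FLOP count for C += A*B with A (M x K), B (K x N):
// one multiply and one add per (m, n, k) triple.
constexpr std::int64_t flop_count(std::int64_t M, std::int64_t N, std::int64_t K) {
    return 2 * M * N * K;
}

// Times `reps` calls of `kernel` and returns the achieved GFLOPS.
// Assumption: the kernel has the signature used in the question.
double measure_gflops(void (*kernel)(double*, double*, double*),
                      double* A, double* B, double* C,
                      std::int64_t M, std::int64_t N, std::int64_t K,
                      int reps) {
    auto start = std::chrono::steady_clock::now();
    for (int r = 0; r < reps; ++r) {
        kernel(A, B, C);
    }
    auto stop = std::chrono::steady_clock::now();
    double seconds = std::chrono::duration<double>(stop - start).count();
    return static_cast<double>(flop_count(M, N, K)) * reps / seconds / 1e9;
}
```

Calling the kernel many times keeps the measured interval well above the clock resolution; with M = N = 32 and K = 128, for instance, each call is 262,144 FLOP.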
After thinking about it for some time, I realised that the L1d cache of the EPYC 7742 is 32 KiB per core, which corresponds to 4096 doubles. Ideally I want all three matrices to be in the L1 cache.
This yields the following optimisation problem (A, B and C together must fit in 4096 doubles, with K = 128):
Maximise N > 0, M > 0 such that N*M + 128*N + 128*M ≤ 4096.
I noticed that M = 4, N = 8 is a good enough solution that uses power-of-two values.
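The constraint can be checked mechanically. The sketch below counts the doubles needed for A (M×K), B (K×N) and C (M×N), with K = 128 inferred from the 128·N and 128·M terms, and lists the power-of-two shapes that fit in a 32 KiB L1d. It deliberately ignores associativity, conflict misses and everything else (stack, other data) that also lives in L1d; the function names are hypothetical.

```cpp
#include <cstdio>

// Doubles needed to hold A (M x K), B (K x N) and C (M x N).
constexpr long footprint_doubles(long M, long N, long K) {
    return M * K + K * N + M * N;
}

// True if the three matrices fit in an L1d of `l1_bytes` bytes.
// Ignores associativity/conflict misses and other data sharing the cache.
constexpr bool fits_in_l1(long M, long N, long K, long l1_bytes) {
    return footprint_doubles(M, N, K) * static_cast<long>(sizeof(double)) <= l1_bytes;
}

// Prints every power-of-two (M, N) pair that fits for a given K.
// M starts at 4 so that a column of C is at least one __m256d.
void list_power_of_two_fits(long K, long l1_bytes) {
    for (long M = 4; M <= 128; M *= 2) {
        for (long N = 1; N <= 128; N *= 2) {
            if (fits_in_l1(M, N, K, l1_bytes)) {
                std::printf("M=%ld N=%ld uses %ld doubles\n",
                            M, N, footprint_doubles(M, N, K));
            }
        }
    }
}
```

For example, with K = 128 and a 32 KiB cache, M = 4, N = 8 uses 1568 doubles (about 38% of the cache), and 16 × 16 (4352 doubles) is the first square power-of-two shape that no longer fits.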
Given that M=4, I can eliminate the m-loop by keeping a single vector as an accumulation variable.
This idea yielded the following code:
```cpp
#include <immintrin.h>

// M == 4: a single __m256d holds a whole column of C.
void mat_mul(double* A, double* B, double* C) {
    __m256d A_v, B_v;
    register __m256d C_v;  // 'register' is only a hint (deprecated in C++11)
    double* aPtr;
    double* bPtr;
    double* cPtr;
    for (int n = 0; n < N; ++n) {
        cPtr = C + n*M;
        C_v = _mm256_load_pd(cPtr);
        for (int k = 0; k < K; ++k) {
            aPtr = A + k*M;
            bPtr = B + k + n*K;
            B_v = _mm256_broadcast_sd(bPtr);
            A_v = _mm256_load_pd(aPtr);
            // Each iteration reads the C_v written by the previous one
            C_v = _mm256_fmadd_pd(A_v, B_v, C_v);
        }
        _mm256_store_pd(cPtr, C_v);
    }
}
```
I thought this was going to be much faster; however, the performance I get is around 4 GFLOPS, which is identical to running the first code with the suboptimal M=4, N=8.
Given that in the first case the matrices are too big to fit entirely in L1d cache, this seems to suggest that the second version of the code is fetching and writing data to L2 cache, even though the matrices should fit in L1d cache.
Is my suspicion correct? If so, is this behaviour caused by some mistake in my thought process, or is there some reason I am missing behind this?
Please give some explanation, rather than just posting a better optimised version of the code, as I would really like to better understand what is going on conceptually. Ideally it would be nice if you guys could give me a few hints on things I could look into myself.
Thanks :)
I tried following some of the tips that were suggested in the comments by @doug and @ElderBug.
@doug in particular suggested transposing B; however, since the matrices are stored in column-major format, I could not find a way to implement their idea. Instead, I transposed A and accumulated the products in a tmp variable.
Here is what I came up with:
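```cpp
#include <immintrin.h>

void mat_mul(double* A, double* B, double* C) {
    __m256d B_v, A_v, rax;
    // Transpose A (M x K, column-major) so that each row of A is contiguous.
    // alignas(32) is needed: _mm256_load_pd faults on unaligned addresses and
    // a plain local array is only guaranteed alignof(double).
    alignas(32) double AT[K*M];
    for(int k=0; k < K; ++k){
        for(int m=0; m<M; ++m){
            AT[m*K + k] = A[m + k*M];
        }
    }
    // Compute product
    for (int n = 0; n < N; ++n) {
        for (int m = 0; m < M; ++m) {
            double tmp = 0.;
            for (int k = 0; k < K; k+=4) {
                B_v = _mm256_load_pd(B + n*K + k);
                A_v = _mm256_load_pd(AT + m*K + k);
                rax = _mm256_mul_pd(A_v, B_v);
                // Horizontal reduction of the 4 lanes on every iteration
                double* pv = (double*)&rax;
                tmp += pv[0] + pv[1] + pv[2] + pv[3];
            }
            C[n*M + m] = tmp;
        }
    }
}
```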
This still runs at around 4 GFLOPS with M=4, N=8. The elephant in the room seems to be reducing the rax vector. I was not able to find a more efficient way to do that.
If I remove `tmp += pv[0] + pv[1] + pv[2] + pv[3];` the performance doubles, but I reach a peak of 14 GFLOPS with M=N=32, so this is still worse than my naive vectorisation approach.
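A common way to avoid the per-iteration reduction (not something proposed in the thread itself) is to keep the partial sums in a `__m256d` for the whole k loop and reduce the four lanes only once at the end, using the usual 256 → 128 → 64-bit shuffle sequence. `dot_vec_accum` is a hypothetical name; unaligned loads are used so the inputs need no particular alignment, and the intrinsic path is only compiled when the compiler defines `__AVX__` and `__FMA__` (GCC and Clang do with `-mavx2 -mfma` or a suitable `-march`; this is an assumption worth checking for icc), with a portable fallback otherwise.

```cpp
#include <cstddef>
#if defined(__AVX__) && defined(__FMA__)
#include <immintrin.h>
#endif

// Dot product of two length-K arrays, K a multiple of 4.
// Partial sums stay in a vector register; the horizontal reduction
// runs once after the loop instead of once per iteration.
double dot_vec_accum(const double* a, const double* b, std::size_t K) {
#if defined(__AVX__) && defined(__FMA__)
    __m256d acc = _mm256_setzero_pd();
    for (std::size_t k = 0; k < K; k += 4) {
        acc = _mm256_fmadd_pd(_mm256_loadu_pd(a + k), _mm256_loadu_pd(b + k), acc);
    }
    // Horizontal sum of the 4 lanes: 256 -> 128 -> 64 bits.
    __m128d lo = _mm256_castpd256_pd128(acc);   // lanes 0, 1
    __m128d hi = _mm256_extractf128_pd(acc, 1); // lanes 2, 3
    __m128d s = _mm_add_pd(lo, hi);             // {l0 + l2, l1 + l3}
    s = _mm_add_sd(s, _mm_unpackhi_pd(s, s));   // lane 0: l0 + l2 + l1 + l3
    return _mm_cvtsd_f64(s);
#else
    // Portable fallback: four partial sums, reduced once at the end.
    double p[4] = {0.0, 0.0, 0.0, 0.0};
    for (std::size_t k = 0; k < K; k += 4) {
        for (int i = 0; i < 4; ++i) p[i] += a[k + i] * b[k + i];
    }
    return (p[0] + p[2]) + (p[1] + p[3]);
#endif
}
```

Note that `acc` is still a single dependency chain, so this loop is bounded by FMA latency rather than throughput; splitting it into two or more independent accumulators is the natural next step.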
If anyone has any further suggestion/comments they would be very much appreciated.
I forgot to mention that I am compiling the code using `icc (ICC) 2021.5.0 20211109` with the following optimisation flags:

```
-O3 -funroll-loops -std=c++11 -mavx2 -march=native -fma -ftz -fomit-frame-pointer
```
The goal is to implement this serial micro-kernel in the best possible way so that I can then reuse it for blocked parallel matrix multiplication. According to my calculations, the theoretical peak should be 46 GFLOPS, so I am getting about 50%. OpenBLAS gets around 40 GFLOPS, so 32-35 would be nice.
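The peak figure follows from per-core vector throughput. A minimal sketch of that arithmetic, assuming 2 FMA units per core, 256-bit vectors (4 doubles) and counting each FMA as 2 FLOP; the clock is left as a parameter because the sustained clock under AVX load depends on the system. Under these assumptions 46 GFLOPS corresponds to 16 FLOP/cycle at about 2.9 GHz. `peak_gflops` is a hypothetical name.

```cpp
// Peak double-precision GFLOPS for one core:
// FMA units * lanes per vector * 2 (multiply + add) * clock in GHz.
constexpr double peak_gflops(int fma_units, int lanes, double ghz) {
    return fma_units * lanes * 2 * ghz;
}

// Assumed Zen 2 core: 2 FMA units, 256-bit vectors = 4 doubles,
// i.e. 16 FLOP per cycle.
constexpr int kZen2FmaUnits = 2;
constexpr int kDoublesPerYmm = 4;
```

Dividing a measured figure by `peak_gflops(kZen2FmaUnits, kDoublesPerYmm, clock)` gives the fraction of peak, e.g. 23 / 46 = 50%.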
It would be really nice if someone could give some insight on why some of the options I tried do not work so that I can think about how to fix them.
Thanks to @ElderBug's comment, I put some extra thought into how the store operations in my first vectorised implementation were being handled, and by doing some clever unrolling I was able to get to 40 GFLOPS, which is comparable to OpenBLAS!
Reading this also helped: https://github.com/flame/blis/blob/master/docs/KernelsHowTo.md#gemm-microkernel
Modern CPUs are huge, mind-boggling superscalar and out-of-order beasts. Your Zen 2 CPU can have 100+ instructions in flight. This means that when your first FMA finishes, the CPU might already be 10 loop iterations ahead, issuing loads. This is enough to hide L1 and even L2 latency. Of course, there are conditions for this to work: for example, there must be no dependency that prevents the loads from being issued early, and branch prediction must have a high hit rate. Matrix multiplication is a "nice" algorithm, and everything is mostly fine.
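To put a number on the dependency point: in the M = 4 kernel from the question, every FMA reads the `C_v` produced by the previous one, so at most one FMA completes per FMA latency. A back-of-the-envelope sketch, assuming a 5-cycle FMA latency and 2 FMA units (commonly published figures for Zen 2; verify them against Agner Fog's instruction tables or AMD's optimisation guide) and the ~2.875 GHz implied by the 46 GFLOPS peak mentioned in the question. `fma_bound_gflops` is a hypothetical helper.

```cpp
#include <algorithm>

// Throughput bound when the FMAs form `chains` independent dependency chains:
// each chain can start a new FMA only every `latency_cycles` cycles, and the
// core cannot exceed `units` FMAs per cycle. Each FMA is lanes * 2 FLOP.
double fma_bound_gflops(int chains, double latency_cycles, double units,
                        int lanes, double ghz) {
    double fma_per_cycle = std::min(chains / latency_cycles, units);
    return fma_per_cycle * lanes * 2.0 * ghz;
}
```

With one chain this gives about 4.6 GFLOPS, the same ballpark as the ~4 GFLOPS reported for the M = 4 kernel. About latency × units = 10 independent accumulators are needed before the FMA units, rather than the dependency chain, become the limit.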
Of course, that doesn't mean cache should be neglected. Locality matters most, to ensure that you use all the data that is fetched into cache. Reuse is nice to have, to reduce bandwidth usage. Predictable load patterns are desirable so that the hardware prefetcher can populate the cache before you even need the data.
In your case, cache prefetching is perfect since your inner accesses are sequential. This is probably why it doesn't matter whether your matrices fit in L1 or not: most of the accesses will hit because they were predicted. Other accesses might miss, but since there are far fewer of them it matters less (and you can't rule out that they are still prefetched).
TL;DR: while you might still find some small improvements for caching, it is unlikely that it will be a problem no matter how big the matrices are.
What now? This means the main way to improve performance is to keep the pipeline fed with as many independent FMAs as possible, which means unrolling loops and making clever use of registers. This is very boring, but that's how it is. Your first version is unrolled into 8 parallel FMAs that depend only on one broadcast. My first suggestion would be to take your first version and move the C matrix accesses out of the loop, keeping them in a `__m256d[8]`, so that the useless code still in this loop is removed.
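A minimal sketch of that suggestion, assuming M = 32 (so one column of C is eight `__m256d`), column-major storage as in the question, and K and the number of columns passed at run time; `mat_mul_m32` and `kM` are hypothetical names. It relies on the compiler fully unrolling the fixed-count `i` loops (GCC and Clang normally do at -O3) so that `c[]` stays in registers. Unaligned loads and stores are used to avoid an alignment precondition, and a plain scalar loop is compiled when `__AVX__`/`__FMA__` are not defined.

```cpp
#if defined(__AVX__) && defined(__FMA__)
#include <immintrin.h>
#endif

constexpr int kM = 32;  // rows of A and C: 8 vectors of 4 doubles

// Hypothetical kernel: C (kM x n_cols) += A (kM x K) * B (K x n_cols),
// all column-major as in the question.
void mat_mul_m32(const double* A, const double* B, double* C, int K, int n_cols) {
    for (int n = 0; n < n_cols; ++n) {
        double* c_col = C + n * kM;
#if defined(__AVX__) && defined(__FMA__)
        // Load the whole column of C once; it stays in 8 registers for all k.
        __m256d c[8];
        for (int i = 0; i < 8; ++i) c[i] = _mm256_loadu_pd(c_col + 4 * i);
        for (int k = 0; k < K; ++k) {
            const __m256d b = _mm256_broadcast_sd(B + k + n * K);
            const double* a_col = A + k * kM;
            // 8 FMAs with no dependency on each other, only on the broadcast.
            for (int i = 0; i < 8; ++i)
                c[i] = _mm256_fmadd_pd(_mm256_loadu_pd(a_col + 4 * i), b, c[i]);
        }
        // Store the column back once, after the k loop.
        for (int i = 0; i < 8; ++i) _mm256_storeu_pd(c_col + 4 * i, c[i]);
#else
        // Scalar fallback with the same structure (C column kept locally).
        double c[kM];
        for (int m = 0; m < kM; ++m) c[m] = c_col[m];
        for (int k = 0; k < K; ++k) {
            const double b = B[k + n * K];
            for (int m = 0; m < kM; ++m) c[m] += A[m + k * kM] * b;
        }
        for (int m = 0; m < kM; ++m) c_col[m] = c[m];
#endif
    }
}
```

The 8 accumulators plus the broadcast and one A vector fit comfortably in the 16 YMM registers available with AVX2 in 64-bit mode, and C is now read and written once per column instead of once per k.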
One important thing to note is that if you assume L1 is irrelevant, it means you are feeding your FMAs from L2. As explained, this is not necessarily a problem, but L2 still has less bandwidth than L1. Is L2 bandwidth enough?
Zen and later CPUs have 32 B per cycle of L2 bandwidth. You have roughly one 256-bit load per FMA, so the FMAs consume about 32 B of load bandwidth each. Zen FMA throughput is 1 per cycle, so this fits L2 perfectly. However, with Zen 2 (your CPU), FMA throughput doubled to 2 per cycle, which means L2 cannot feed the FMAs at their maximum rate.
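The bandwidth argument as arithmetic, as a sketch: `loads_per_fma` counts 32-byte vector loads per FMA, and the 32 B/cycle L2 figure is taken from the paragraph above; the function names are hypothetical.

```cpp
// Bytes per cycle the FMAs pull from the cache level feeding them,
// with `loads_per_fma` 32-byte loads per FMA and `fma_per_cycle` FMAs issued.
constexpr double load_bytes_per_cycle(double loads_per_fma, double fma_per_cycle) {
    return loads_per_fma * 32.0 * fma_per_cycle;
}

// FMAs per cycle that a cache delivering `bytes_per_cycle` can sustain.
constexpr double fma_limit(double bytes_per_cycle, double loads_per_fma) {
    return bytes_per_cycle / (32.0 * loads_per_fma);
}
```

With one load per FMA, two FMAs per cycle would need 64 B/cycle, twice what L2 supplies, so an L2-fed loop tops out at one FMA per cycle (half of Zen 2's peak). Halving the loads per FMA, for example by reusing each A vector for two columns, brings the requirement back to 32 B/cycle.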
So in the end, matrix size does actually matter for getting the last half of the performance. Bigger matrices will tend toward half the maximum FLOPS. In your case they are probably still small enough that you get some L1 hits, even if they don't all fit.
To fix that, you could change your algorithm to guarantee some L1 reuse. Maybe interleave 2 or 4 N columns when iterating over K, reducing M by the same factor so that everything still fits in registers. Of course, having matrices small enough to fit in cache is also a solution.
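A sketch of the two-column variant, assuming M = 16 (four `__m256d` per column, so two columns still need only 8 accumulators), an even number of columns, and the same column-major layout. `mat_mul_m16_n2` and `kM16` are hypothetical names, and the same `__AVX__`/`__FMA__` guard with a scalar fallback is used. Each A vector loaded is consumed by two FMAs, so per k there are 4 vector loads and 2 broadcasts for 8 FMAs, instead of 8 loads and 1 broadcast with M = 32.

```cpp
#if defined(__AVX__) && defined(__FMA__)
#include <immintrin.h>
#endif

constexpr int kM16 = 16;  // rows of A and C: 4 vectors of 4 doubles

// Hypothetical kernel: C (kM16 x n_cols) += A (kM16 x K) * B (K x n_cols),
// column-major, n_cols even. Two columns of C are updated per pass over k.
void mat_mul_m16_n2(const double* A, const double* B, double* C, int K, int n_cols) {
    for (int n = 0; n < n_cols; n += 2) {
        double* c0 = C + n * kM16;
        double* c1 = C + (n + 1) * kM16;
#if defined(__AVX__) && defined(__FMA__)
        __m256d acc0[4], acc1[4];  // 8 accumulators in registers
        for (int i = 0; i < 4; ++i) {
            acc0[i] = _mm256_loadu_pd(c0 + 4 * i);
            acc1[i] = _mm256_loadu_pd(c1 + 4 * i);
        }
        for (int k = 0; k < K; ++k) {
            const __m256d b0 = _mm256_broadcast_sd(B + k + n * K);
            const __m256d b1 = _mm256_broadcast_sd(B + k + (n + 1) * K);
            const double* a_col = A + k * kM16;
            for (int i = 0; i < 4; ++i) {
                // One load of A feeds two independent FMAs.
                const __m256d a = _mm256_loadu_pd(a_col + 4 * i);
                acc0[i] = _mm256_fmadd_pd(a, b0, acc0[i]);
                acc1[i] = _mm256_fmadd_pd(a, b1, acc1[i]);
            }
        }
        for (int i = 0; i < 4; ++i) {
            _mm256_storeu_pd(c0 + 4 * i, acc0[i]);
            _mm256_storeu_pd(c1 + 4 * i, acc1[i]);
        }
#else
        // Scalar fallback with the same reuse pattern.
        for (int k = 0; k < K; ++k) {
            const double b0 = B[k + n * K];
            const double b1 = B[k + (n + 1) * K];
            for (int m = 0; m < kM16; ++m) {
                const double a = A[m + k * kM16];
                c0[m] += a * b0;
                c1[m] += a * b1;
            }
        }
#endif
    }
}
```

Interleaving 4 columns with M = 8 follows the same pattern and halves the A loads per FMA again.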