
Efficient way to use the int8 AVX512-VNNI instruction, especially for loading the data into a zmm register


I want to optimize my matrix multiplication using the AVX512-VNNI instruction on int8 data.

I understand how vpdpbusd works, but I don't know how to use it efficiently.

Concretely, I partition the big (MxN) matrices into smaller 4x4 matrices, a size that a single vpdpbusd can process in one instruction.

Note also that my matrix multiplication algorithm is based on the one used in GotoBLAS.

This way, I think I can compute the whole matrix multiplication as many 4x4 × 4x4 matrix multiplications (a sort of "divide and conquer").

The steps are:

  1. Load one 4x4 matrix (let's say A) four times into one zmm register.
  2. Load four different 4x4 matrices (let's say B, C, D, E) into another zmm register.
  3. vpdpbusd does four multiplications at a time: AxB, AxC, AxD, AxE.
  4. Store the results back.
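For concreteness, here is a scalar C model of what one 512-bit vpdpbusd computes. It is a sketch, not intrinsics code, and the function names are hypothetical. It assumes A supplies the unsigned bytes and B the signed bytes, because the intrinsic `_mm512_dpbusd_epi32(src, a, b)` treats `a` as u8 and `b` as s8. Each of the 16 dword lanes is a 4-element dot product, so one instruction performs 64 multiply-adds: exactly the work of one 4x4 × 4x4 product, provided lane 4*i+j holds row i of A in one register and column j of B in the other.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model of one 512-bit vpdpbusd (_mm512_dpbusd_epi32).
   The 64-byte operands are viewed as 16 dword lanes. Lane L multiplies
   bytes 4L..4L+3 of `a` (unsigned) by the same bytes of `b` (signed)
   and adds the four products to acc[L]. The real instruction wraps on
   32-bit overflow (vpdpbusds saturates); this model also wraps. */
static void dpbusd_model(int32_t acc[16], const uint8_t a[64], const int8_t b[64])
{
    for (int lane = 0; lane < 16; ++lane) {
        int32_t dot = 0; /* |dot| <= 4 * 255 * 128, so no overflow here */
        for (int j = 0; j < 4; ++j)
            dot += (int32_t)a[4 * lane + j] * (int32_t)b[4 * lane + j];
        acc[lane] = (int32_t)((uint32_t)acc[lane] + (uint32_t)dot);
    }
}

/* One 4x4 (u8) x 4x4 (s8) product as a single dpbusd.
   Lane 4*i+j needs row i of A in `a` and column j of B in `b`:
   each row of A is repeated 4 times and B is effectively transposed.
   Building these two registers is the awkward "loading" step. */
static void matmul4x4_via_dpbusd(int32_t C[4][4], const uint8_t A[4][4], const int8_t B[4][4])
{
    uint8_t a[64];
    int8_t b[64];
    int32_t acc[16] = {0};
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j)
            for (int k = 0; k < 4; ++k) {
                a[4 * (4 * i + j) + k] = A[i][k]; /* row i of A */
                b[4 * (4 * i + j) + k] = B[k][j]; /* column j of B */
            }
    dpbusd_model(acc, a, b);
    for (int i = 0; i < 4; ++i)
        for (int j = 0; j < 4; ++j)
            C[i][j] = acc[4 * i + j];
}
```

In a real kernel, repeating A's rows and transposing B like this inside the inner loop would cost shuffles on every step, which is the overhead the packed layout in the answer avoids.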

The problem is loading the data into the zmm register: a 4x4 matrix is awkward to lay out in a wide zmm register.

Also, vpdpbusd produces 32-bit results, so I need to convert them back to 8-bit data, which I suspect adds further overhead.
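A scalar model of the narrowing step, assuming plain signed saturation to int8 as performed by vpmovsdb (intrinsic `_mm512_cvtsepi32_epi8`, AVX-512F), which narrows 16 dwords to 16 bytes. If your quantization scheme rescales the 32-bit sums first, that step would come before this one and is not shown. The function name is hypothetical.

```c
#include <assert.h>
#include <stdint.h>

/* Scalar model of vpmovsdb (_mm512_cvtsepi32_epi8): narrow 16 int32
   accumulators to 16 int8 values with signed saturation. */
static void narrow_s32_to_s8(int8_t out[16], const int32_t in[16])
{
    for (int i = 0; i < 16; ++i) {
        int32_t v = in[i];
        if (v > INT8_MAX) v = INT8_MAX; /* clamp to [-128, 127] */
        if (v < INT8_MIN) v = INT8_MIN;
        out[i] = (int8_t)v;
    }
}
```

This runs once per output element, after summing over the whole inner dimension K, so for an MxN output it is O(M·N) work against O(M·N·K) multiply-adds, typically a small fraction of the total.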

I don't know how to deal with this problem... Help!


Solution

  • If your matrices are big enough, repacking tiles (also just called "packing") is worth doing for TLB reasons: tiling without repacking produces tiles that span many different pages, which can cost many TLB misses to access. Once you are repacking anyway, you can reorder the elements however you like at little extra cost; it is not significantly more expensive than the repacking itself. Repacking does have a cost, of course, but the packed tiles are reused, so that cost is amortized over several uses. Here's one way to arrange the computation:

    • Broadcast-load 4 chunks of 4 elements (a tiny row) from matrix A (4x vpbroadcastd). Make sure it is the variant with a memory operand, not a separate load plus broadcast, nor 4 loads coalesced into one followed by 4 shuffles. This is not critical, since the throughput of vpdpbusds is not that high and vpdpbusds goes to port 0 while shuffles go to port 5, but it's silly to pay for unnecessary shuffle µops, and it may well matter on future CPUs.
    • Load 4 chunks from a repacked tile of matrix B, where the i'th dword in every chunk is a tiny column of 4 elements from the i'th column. The 4 chunks together span 16 rows. That sounds bad, but no complicated shuffle is needed here; that already happened during repacking.
    • Form all 16 products, summing them into 16 independent accumulators (overkill for the current latency-throughput product of vpdpbusds, but that product is only likely to grow). In each dword, vpdpbusds computes the dot product of a tiny row from A and a tiny column from B. The 16 products computed this way all reuse the same tiny row from A but use different columns from B.

    The results don't form 4x4 sub-matrices of the output matrix; they form 16-wide rows that can be summed together and then added into the matrix (and, in your case, narrowed to 8-bit, but as you can see that happens after the inner loop, so its impact is minor, if any).

    If your matrices are so small that repacking costs more than it saves, then I'm not sure what a good way to arrange the computation would be.
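The packed layout and inner loop described above can be sketched in portable scalar C. This is a simplified, single-row reading of the scheme under stated assumptions: B is row-major int8, A supplies the unsigned bytes, K is a multiple of 4, and only one 16-column panel of B is handled. `pack_b_panel`, `row_kernel`, `NR` and `NACC` are hypothetical names, and `NACC = 4` is an arbitrary choice rather than the 16 accumulators suggested above. Each inner `lane` loop stands in for one vpdpbusd, which an intrinsics version would issue as `_mm512_dpbusd_epi32(acc, _mm512_set1_epi32(a_dword), chunk)` with `chunk` loaded from the packed buffer.

```c
#include <assert.h>
#include <stdint.h>

#define NR 16  /* columns per packed panel = dword lanes in a zmm */
#define NACC 4 /* hypothetical number of independent accumulators */

/* Pack a K x 16 panel of row-major B (leading dimension ldb) so that
   64-byte chunk g holds, in dword i, B[4g..4g+3][i]: a tiny column of
   4 elements from column i. Done once; reused for every row of A. */
static void pack_b_panel(int8_t *packed, const int8_t *B, int K, int ldb)
{
    assert(K % 4 == 0);
    for (int g = 0; g < K / 4; ++g)
        for (int i = 0; i < NR; ++i)
            for (int k = 0; k < 4; ++k)
                packed[64 * g + 4 * i + k] = B[(4 * g + k) * ldb + i];
}

/* One 16-wide row of C = A * B for a single row of A (K unsigned bytes).
   Step g models: broadcast the dword a_row[4g..4g+3] to all 16 lanes
   (vpbroadcastd from memory), then one vpdpbusd against packed chunk g.
   Sums are assumed not to overflow int32 (the real vpdpbusd wraps). */
static void row_kernel(int32_t c_row[NR], const uint8_t *a_row, const int8_t *packed, int K)
{
    int32_t acc[NACC][NR] = {{0}};
    for (int g = 0; g < K / 4; ++g) {
        int32_t *dst = acc[g % NACC]; /* rotate accumulators to hide latency */
        const int8_t *chunk = packed + 64 * g;
        for (int lane = 0; lane < NR; ++lane) /* one dpbusd: 16 lanes */
            for (int k = 0; k < 4; ++k)
                dst[lane] += (int32_t)a_row[4 * g + k] * chunk[4 * lane + k];
    }
    /* Sum the independent accumulators into one 16-wide row (after the
       inner loop; narrowing to 8-bit would also happen here). */
    for (int lane = 0; lane < NR; ++lane) {
        int32_t s = 0;
        for (int j = 0; j < NACC; ++j)
            s += acc[j][lane];
        c_row[lane] = s;
    }
}
```

No shuffles appear in the per-step work: the only data movement is a broadcast of one dword from A and a plain 64-byte load from the packed B buffer, because the rearrangement was paid for once in `pack_b_panel`.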