
How to Load and Store data for the new AVX-VNNI and Arm Neon MMLA instructions efficiently?


What is the appropriate way to load data for the recent AVX-VNNI and Arm Neon MMLA instructions?

For example, the description of SMMLA is:

Signed 8-bit integer matrix multiply-accumulate. This instruction multiplies the 2x8 matrix of signed 8-bit integer values in the first source vector by the 8x2 matrix of signed 8-bit integer values in the second source vector. The resulting 2x2 32-bit integer matrix [...]

Similarly, the description for _mm256_dpbusd_epi32 is:

Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding signed 8-bit integers in b, producing 4 intermediate signed 16-bit results. Sum these 4 results with the corresponding 32-bit integer in src, and store the packed 32-bit results in dst.
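
For reference, a minimal sketch of how this intrinsic is typically driven (the helper and buffer names here are my own, purely for illustration): the accumulator holds eight 32-bit sums, and each call adds one 4-element dot product per 32-bit lane. Note that plain _mm256_dpbusd_epi32 requires AVX512-VNNI with AVX512VL; recent compilers spell the AVX-VNNI form _mm256_dpbusd_avx_epi32.

#include <immintrin.h>
#include <cstdint>

// Hypothetical helper: for each 32-bit lane j,
// acc[j] += a[4j+0]*b[4j+0] + ... + a[4j+3]*b[4j+3]  (a unsigned, b signed).
// Needs -mavx512vnni -mavx512vl (or the _avx_ variant with -mavxvnni).
static inline __m256i dot4_accumulate(__m256i acc, const uint8_t* a, const int8_t* b) {
    __m256i va = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a));
    __m256i vb = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b));
    return _mm256_dpbusd_epi32(acc, va, vb);
}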

It seems that they all require inputs of the form 2[4]x8 and 8x2[4], and produce outputs of the form 2[4]x2[4]. How can I efficiently load and store data for these functions?
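
To make these shapes concrete, here is a minimal, self-contained sketch of a single SMMLA call (my own example, not from the documentation; it needs an i8mm-capable target, e.g. -march=armv8.6-a+i8mm). The first operand holds two rows of A back to back, the second holds two columns of B back to back, and the accumulator ends up holding the 2x2 result in row-major order.

#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

int main() {
    // A is a 2x8 matrix, stored row after row (16 bytes total).
    int8_t a[16] = {1, 1, 1, 1, 1, 1, 1, 1,   // row 0
                    2, 2, 2, 2, 2, 2, 2, 2};  // row 1
    // B is an 8x2 matrix, stored column after column (16 bytes total).
    int8_t b[16] = {1, 1, 1, 1, 1, 1, 1, 1,   // column 0
                    3, 3, 3, 3, 3, 3, 3, 3};  // column 1

    int32x4_t acc = vdupq_n_s32(0);
    // acc becomes { A0.B0, A0.B1, A1.B0, A1.B1 }, i.e. the 2x2 result row-major.
    acc = vmmlaq_s32(acc, vld1q_s8(a), vld1q_s8(b));

    int32_t c[4];
    vst1q_s32(c, acc);
    printf("%d %d\n%d %d\n", c[0], c[1], c[2], c[3]); // prints: 8 24 / 16 48
}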

I see three broad possibilities to use these instructions, none appealing:

  • [Split and Combine] I load two consecutive 128-bit vectors and then split them. Similarly, for AVX, I would load four 128-bit or 256-bit vectors and then split them. Storing is equally "complicated", since I need to extract the relevant parts of the 2[4]x2[4] matrix before storing it. My code ends up cluttered with splitting/merging instructions.
  • [Smaller Vectors] Alternatively, I could load smaller portions, but that seems inefficient too.
  • [Reorder Input Data] Of course, I could reorder the input data so that the vectorized loads already span multiple rows or columns. Should that be the intended use?

Example code for the inner loop (the reduction over K) with a small 4xK input matrix A (row-major) and a Kx4 matrix B (column-major) looks as follows:

// Requires <arm_neon.h> and a target with the i8mm extension (vmmlaq_u32).
for (size_t k = 0; k < 64; k += 8) {
    // Pack a 2x8 tile of A: 8 bytes each from rows 0 and 1 into one 128-bit register.
    uint8x8_t low = vld1_u8(row0);
    uint8x8_t high = vld1_u8(row1);
    uint8x16_t row01x01234567 = vcombine_u8(low, high);
    row0 += 8;
    row1 += 8;
    // Same for rows 2 and 3.
    low = vld1_u8(row2);
    high = vld1_u8(row3);
    uint8x16_t row23x01234567 = vcombine_u8(low, high);
    row2 += 8;
    row3 += 8;
    // Pack an 8x2 tile of B: columns 0 and 1 (column-major, so 8 contiguous bytes each).
    low = vld1_u8(col0);
    high = vld1_u8(col1);
    uint8x16_t col01x01234567 = vcombine_u8(low, high);
    col0 += 8;
    col1 += 8;
    // Same for columns 2 and 3.
    low = vld1_u8(col2);
    high = vld1_u8(col3);
    uint8x16_t col23x01234567 = vcombine_u8(low, high);
    col2 += 8;
    col3 += 8;
    // Each UMMLA accumulates one 2x2 block of the 4x4 output.
    out01x01 = vmmlaq_u32(out01x01, row01x01234567, col01x01234567);
    out01x23 = vmmlaq_u32(out01x23, row01x01234567, col23x01234567);
    out23x01 = vmmlaq_u32(out23x01, row23x01234567, col01x01234567);
    out23x23 = vmmlaq_u32(out23x23, row23x01234567, col23x01234567);
}


The result is correct, but this seems terribly inefficient. The code above is just an example; in practice I would use larger tile sizes to maximize register usage.


Solution

  • Packing the matrices A and B is indeed necessary.

    For a short outline, consider the PowerPC documentation (Redbook), https://www.redbooks.ibm.com/abstracts/redp5612.html (page 35). PowerPC has blocked matrix-multiply instructions similar to VNNI and the Arm Neon MMLA family.

    I have written such a packing function within the matrix-multiply code; the packing did not hurt the throughput of the code. A minimal sketch of such a packing step follows below.
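
    Here is a minimal sketch of what such packing can look like for the 4xK * Kx4 example above (all function and buffer names are my own, and K is assumed to be a multiple of 8): A is rearranged so that each 2x8 tile is contiguous, B so that each 8x2 tile is contiguous, and the inner loop then reduces to plain vld1q_u8 loads plus the four vmmlaq_u32 accumulations, with no vcombine shuffles.

#include <arm_neon.h>
#include <cstddef>
#include <cstdint>

// Pack a 4xK row-major A into 32-byte blocks per 8-wide k-step:
// row0|row1 (the 2x8 tile for rows 0-1), then row2|row3. (Hypothetical helper.)
void pack_a_4xK(const uint8_t* A, size_t K, uint8_t* packed) {
    for (size_t k = 0; k < K; k += 8)
        for (size_t r = 0; r < 4; ++r)
            for (size_t i = 0; i < 8; ++i)
                *packed++ = A[r * K + k + i];
}

// Pack a Kx4 column-major B into 32-byte blocks per 8-wide k-step:
// col0|col1 (the 8x2 tile for columns 0-1), then col2|col3. (Hypothetical helper.)
void pack_b_Kx4(const uint8_t* B, size_t K, uint8_t* packed) {
    for (size_t k = 0; k < K; k += 8)
        for (size_t c = 0; c < 4; ++c)
            for (size_t i = 0; i < 8; ++i)
                *packed++ = B[c * K + k + i];
}

// Inner loop over the packed buffers: two 16-byte loads per operand,
// four UMMLA accumulations, no shuffles. out[] must be zero-initialized by the caller.
void kernel_4x4(const uint8_t* pa, const uint8_t* pb, size_t K, uint32x4_t out[4]) {
    for (size_t k = 0; k < K; k += 8) {
        uint8x16_t row01 = vld1q_u8(pa);       // rows 0-1 x 8 columns of A
        uint8x16_t row23 = vld1q_u8(pa + 16);  // rows 2-3 x 8 columns of A
        uint8x16_t col01 = vld1q_u8(pb);       // 8 rows x columns 0-1 of B
        uint8x16_t col23 = vld1q_u8(pb + 16);  // 8 rows x columns 2-3 of B
        pa += 32;
        pb += 32;
        out[0] = vmmlaq_u32(out[0], row01, col01); // C[0..1][0..1]
        out[1] = vmmlaq_u32(out[1], row01, col23); // C[0..1][2..3]
        out[2] = vmmlaq_u32(out[2], row23, col01); // C[2..3][0..1]
        out[3] = vmmlaq_u32(out[3], row23, col23); // C[2..3][2..3]
    }
}

    The packing touches O(M*K + K*N) bytes while the multiply does O(M*N*K) work, so its cost is amortized over the K dimension, which is consistent with the observation above that it did not hurt throughput.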