Search code examples
Tags: c++, raspberry-pi, arm, simd, neon

Accelerating matrix vector multiplication with ARM Neon Intrinsics on Raspberry Pi 4


I need to optimize a matrix-vector multiplication. The data looks like the following:

  • The vector has 81 columns
  • The matrix has 90,000 rows and 81 columns and is already transposed, so a row-wise dot product can be used.
  • The output is therefore a vector with 90,000 rows
  • All of them are stored in 1D float arrays

Some non-functional requirements also have to be met for this routine:

  • Use as few standard-library facilities as possible (no std::vector, for example)
  • No third-party libraries may be used (so no Eigen or BLAS for me, either)

This is my code (simplified for readability: it assumes the input dimensions are exact multiples of the block sizes):

// input_height = 90000
// input_width = 81

// Computes output[y] = dot(row y of `input`, `kernel`) for every row.
// Four rows are processed per iteration so each kernel block is loaded once
// and reused four times. Columns are consumed in 16-float blocks only while a
// whole block fits. The remaining columns (81 = 5 * 16 + 1) are handled by a
// scalar tail, and rows beyond the last multiple of 4 by a one-row loop. This
// ensures no load reads past the end of a row or of `kernel`.
const uint32_t width16 = input_width - input_width % 16;
const uint32_t height4 = input_height - input_height % 4;

for (uint32_t y = 0; y < height4; y += 4) {
    auto *row0 = input + y * input_width;
    auto *row1 = row0 + input_width;
    auto *row2 = row1 + input_width;
    auto *row3 = row2 + input_width;

    float32x4_t sum0 = vmovq_n_f32(0);
    float32x4_t sum1 = vmovq_n_f32(0);
    float32x4_t sum2 = vmovq_n_f32(0);
    float32x4_t sum3 = vmovq_n_f32(0);

    uint32_t x = 0;
    for (; x < width16; x += 16) {
        float32x4x4_t A = load_matrix_transpose(kernel + x);

        float32x4x4_t B0 = load_matrix_transpose(row0 + x);
        float32x4x4_t B1 = load_matrix_transpose(row1 + x);
        float32x4x4_t B2 = load_matrix_transpose(row2 + x);
        float32x4x4_t B3 = load_matrix_transpose(row3 + x);

        simd_matrix_element_wise_multiplication(A, B0, sum0);
        simd_matrix_element_wise_multiplication(A, B1, sum1);
        simd_matrix_element_wise_multiplication(A, B2, sum2);
        simd_matrix_element_wise_multiplication(A, B3, sum3);
    }

    // Horizontal reduction of each accumulator.
    float out0 = vaddvq_f32(sum0);
    float out1 = vaddvq_f32(sum1);
    float out2 = vaddvq_f32(sum2);
    float out3 = vaddvq_f32(sum3);

    // Scalar tail: columns that do not fill a whole 16-float block.
    for (; x < input_width; ++x) {
        const float k = kernel[x];
        out0 += row0[x] * k;
        out1 += row1[x] * k;
        out2 += row2[x] * k;
        out3 += row3[x] * k;
    }

    output[y] = out0;
    output[y + 1] = out1;
    output[y + 2] = out2;
    output[y + 3] = out3;
}

// Leftover rows (only present when input_height is not a multiple of 4).
for (uint32_t y = height4; y < input_height; ++y) {
    auto *row = input + y * input_width;
    float32x4_t sum = vmovq_n_f32(0);

    uint32_t x = 0;
    for (; x < width16; x += 16) {
        float32x4x4_t A = load_matrix_transpose(kernel + x);
        float32x4x4_t B = load_matrix_transpose(row + x);
        simd_matrix_element_wise_multiplication(A, B, sum);
    }

    float out = vaddvq_f32(sum);
    for (; x < input_width; ++x) {
        out += row[x] * kernel[x];
    }
    output[y] = out;
}

The helper functions load_matrix_transpose and matrix_element_wise_multiplication (defined below under the name simd_matrix_element_wise_multiplication) are:

// Loads 16 consecutive floats starting at `a` into four 4-lane vectors:
// val[k] receives a[4k .. 4k+3]. Despite the name, no transposition happens;
// the elements keep their memory order.
// NOTE(review): `simd_load` is defined elsewhere; presumably a 4-float load
// such as vld1q_f32 -- confirm.
inline float32x4x4_t load_matrix_transpose(float *a) {
    float32x4x4_t quad;
    for (int k = 0; k < 4; ++k) {
        quad.val[k] = simd_load(a + 4 * k);
    }
    return quad;
}

// Accumulates the lane-wise products of two 4x4 float blocks into C:
//   C += A.val[0]*B.val[0] + A.val[1]*B.val[1] + A.val[2]*B.val[2] + A.val[3]*B.val[3]
// C is a single 4-lane accumulator (float32x4_t): vmlaq_f32 takes and returns
// float32x4_t. Declaring C as float32x4x4_t did not compile and could not bind
// the float32x4_t sums the caller passes in. A and B are only read.
inline void simd_matrix_element_wise_multiplication(const float32x4x4_t &A, const float32x4x4_t &B, float32x4_t &C) {
    C = vmlaq_f32(C, A.val[0], B.val[0]);
    C = vmlaq_f32(C, A.val[1], B.val[1]);
    C = vmlaq_f32(C, A.val[2], B.val[2]);
    C = vmlaq_f32(C, A.val[3], B.val[3]);
}

// Alternate name used by the row loop's call sites; forwards unchanged.
inline void matrix_element_wise_multiplication(const float32x4x4_t &A, const float32x4x4_t &B, float32x4_t &C) {
    simd_matrix_element_wise_multiplication(A, B, C);
}

On my Raspberry Pi 4 (ARMv8, 8 GB RAM, 4 cores), the code takes about 60 ms when compiled with -O3.

Over long runs (many loops), the NEON version is exactly twice as fast as the plain (non-SIMD) code.

My question is: is there any way to optimize the code further? I have tried many things but cannot get any more speed-up over the plain code than this.


Solution

  • Data locality is the highest priority when it comes to optimization, and you should be aware of the register capacity, since registers are BY FAR the fastest and scarcest resource.

    aarch64: 32x128bit neon registers (512 bytes)
    aarch32: 16x128bit neon registers (256 bytes)

    An 81x90000 matrix, when transposed, requires holding 90000 intermediate values during the multiplication. Since 360000 bytes don't fit into a 512-byte register bank, there will be TONS of spilling to memory, which translates into HUGE performance hits.
    On the other hand, the 81-float vector (81*4 = 324 bytes) fits nicely into the 512 bytes.

    /*
     * matVecMult81x90000 - multiply a 90000 x 81 row-major float matrix by an
     * 81-float vector: pDst[r] = sum over c of pMat[r*81 + c] * pVec[c].
     *
     * pDst: output, 90000 floats.
     * pMat: 90000 rows of 81 contiguous floats, read sequentially (324 bytes/row).
     * pVec: 81 floats, loaded once into 21 vector variables that are meant
     *       to stay in NEON registers for the whole loop.
     *
     * Uses vaddvq_f32, which is an AArch64-only intrinsic.
     */
    void matVecMult81x90000(float *pDst, const float *pMat, const float *pVec)
    {
        float32x4_t vec0_3, vec4_7, vec8_11, vec12_15, vec16_19, vec20_23, vec24_27, vec28_31, vec32_35, vec36_39, vec40_43, vec44_47, vec48_51, vec52_55, vec56_59, vec60_63, vec64_67, vec68_71, vec72_75, vec76_79, vec80;
        float32x4_t mat0, mat1, mat2, mat3, mat4, rslt;
        uint32_t nRows = 90000;

        /* Only lane 0 of vec80/mat4 carries column 80. Lanes 1-3 start at zero
         * and are never written, so the last multiply-add contributes exactly
         * pMat[r*81 + 80] * pVec[80]. */
        vec80 = vdupq_n_f32(0.0f);
        mat4 = vdupq_n_f32(0.0f);

        vec0_3 = vld1q_f32(pVec); pVec += 4;
        vec4_7 = vld1q_f32(pVec); pVec += 4;
        vec8_11 = vld1q_f32(pVec); pVec += 4;
        vec12_15 = vld1q_f32(pVec); pVec += 4;
        vec16_19 = vld1q_f32(pVec); pVec += 4;
        vec20_23 = vld1q_f32(pVec); pVec += 4;
        vec24_27 = vld1q_f32(pVec); pVec += 4;
        vec28_31 = vld1q_f32(pVec); pVec += 4;
        vec32_35 = vld1q_f32(pVec); pVec += 4;
        vec36_39 = vld1q_f32(pVec); pVec += 4;
        vec40_43 = vld1q_f32(pVec); pVec += 4;
        vec44_47 = vld1q_f32(pVec); pVec += 4;
        vec48_51 = vld1q_f32(pVec); pVec += 4;
        vec52_55 = vld1q_f32(pVec); pVec += 4;
        vec56_59 = vld1q_f32(pVec); pVec += 4;
        vec60_63 = vld1q_f32(pVec); pVec += 4;
        vec64_67 = vld1q_f32(pVec); pVec += 4;
        vec68_71 = vld1q_f32(pVec); pVec += 4;
        vec72_75 = vld1q_f32(pVec); pVec += 4;
        vec76_79 = vld1q_f32(pVec); pVec += 4;
        /* vld1q_lane_f32 returns the updated vector instead of modifying its
         * argument. If the result is ignored, vec80 stays all-zero and column 80
         * vanishes from every dot product. */
        vec80 = vld1q_lane_f32(pVec, vec80, 0);

        do {
            /* Columns 0-15. vfmaq_f32 is a fused multiply-add (one fmla per
             * step), so the rounding does not depend on whether the compiler
             * contracts a separate multiply and add. */
            mat0 = vld1q_f32(pMat); pMat += 4;
            mat1 = vld1q_f32(pMat); pMat += 4;
            mat2 = vld1q_f32(pMat); pMat += 4;
            mat3 = vld1q_f32(pMat); pMat += 4;
            rslt = vmulq_f32(mat0, vec0_3);
            rslt = vfmaq_f32(rslt, mat1, vec4_7);
            rslt = vfmaq_f32(rslt, mat2, vec8_11);
            rslt = vfmaq_f32(rslt, mat3, vec12_15);

            /* Columns 16-31. */
            mat0 = vld1q_f32(pMat); pMat += 4;
            mat1 = vld1q_f32(pMat); pMat += 4;
            mat2 = vld1q_f32(pMat); pMat += 4;
            mat3 = vld1q_f32(pMat); pMat += 4;
            rslt = vfmaq_f32(rslt, mat0, vec16_19);
            rslt = vfmaq_f32(rslt, mat1, vec20_23);
            rslt = vfmaq_f32(rslt, mat2, vec24_27);
            rslt = vfmaq_f32(rslt, mat3, vec28_31);

            /* Columns 32-47. */
            mat0 = vld1q_f32(pMat); pMat += 4;
            mat1 = vld1q_f32(pMat); pMat += 4;
            mat2 = vld1q_f32(pMat); pMat += 4;
            mat3 = vld1q_f32(pMat); pMat += 4;
            rslt = vfmaq_f32(rslt, mat0, vec32_35);
            rslt = vfmaq_f32(rslt, mat1, vec36_39);
            rslt = vfmaq_f32(rslt, mat2, vec40_43);
            rslt = vfmaq_f32(rslt, mat3, vec44_47);

            /* Columns 48-63. */
            mat0 = vld1q_f32(pMat); pMat += 4;
            mat1 = vld1q_f32(pMat); pMat += 4;
            mat2 = vld1q_f32(pMat); pMat += 4;
            mat3 = vld1q_f32(pMat); pMat += 4;
            rslt = vfmaq_f32(rslt, mat0, vec48_51);
            rslt = vfmaq_f32(rslt, mat1, vec52_55);
            rslt = vfmaq_f32(rslt, mat2, vec56_59);
            rslt = vfmaq_f32(rslt, mat3, vec60_63);

            /* Columns 64-79, then column 80 into lane 0 of mat4 (result assigned). */
            mat0 = vld1q_f32(pMat); pMat += 4;
            mat1 = vld1q_f32(pMat); pMat += 4;
            mat2 = vld1q_f32(pMat); pMat += 4;
            mat3 = vld1q_f32(pMat); pMat += 4;
            mat4 = vld1q_lane_f32(pMat, mat4, 0); pMat += 1;
            rslt = vfmaq_f32(rslt, mat0, vec64_67);
            rslt = vfmaq_f32(rslt, mat1, vec68_71);
            rslt = vfmaq_f32(rslt, mat2, vec72_75);
            rslt = vfmaq_f32(rslt, mat3, vec76_79);
            rslt = vfmaq_f32(rslt, mat4, vec80);

            /* Horizontal sum of the four lanes is the row's dot product. */
            *pDst++ = vaddvq_f32(rslt);
        } while (--nRows);
    }
    

    Unfortunately, compilers (both GCC and Clang) don't play along nicely.
    The generated code spills parts of the vector to the stack inside the loop. Below is the same function in hand-written assembly, without any stack spilling:

        .arch   armv8-a
        .global     matVecMult81x90000_asm
        .text
    
    // matVecMult81x90000_asm: hand-written counterpart of matVecMult81x90000.
    // For each of 90000 rows, stores the dot product of that row (81 floats)
    // with an 81-float vector.
    //
    // Arguments, as used below (same order as the C version):
    //   x0 = output pointer; one float stored per row, post-incremented by 4
    //   x1 = matrix pointer; advanced by 81*4 bytes per row
    //   x2 = pointer to the 81-float vector, read once before the loop
    //
    // Register use: v0-v13 and v16-v22 hold the vector (elements 0-80),
    // v24-v31 receive matrix data, v23 is the accumulator, w3 the row counter.
    .balign 64
    .func
    matVecMult81x90000_asm:
    // init loop counter: w3 = 90000, built from its low and high 16-bit halves
        mov     w3, #90000 & 0xffff
        movk    w3, #90000>>16, lsl #16
    
    // preserve registers: v8-v13 are overwritten by the vector load below, and
    // their low 64 bits (d8-d13) are callee-saved under AAPCS64. Save them in a
    // 48-byte frame (a multiple of 16, keeping sp aligned).
        stp     d8, d9, [sp, #-48]!
        stp     d10, d11, [sp, #1*16]
        stp     d12, d13, [sp, #2*16]
    
    // load vectors: elements 0-79 as 20 q-registers (v14/v15 are skipped,
    // presumably to avoid saving two more callee-saved registers), then
    // element 80 into s22. A scalar "ldr s" zeroes the remaining lanes of v22.
        ldp     q0, q1, [x2, #0*32]
        ldp     q2, q3, [x2, #1*32]
        ldp     q4, q5, [x2, #2*32]
        ldp     q6, q7, [x2, #3*32]
        ldp     q8, q9, [x2, #4*32]
        ldp     q10, q11, [x2, #5*32]
        ldp     q12, q13, [x2, #6*32]
        ldp     q16, q17, [x2, #7*32]
        ldp     q18, q19, [x2, #8*32]
        ldp     q20, q21, [x2, #9*32]
        ldr     s22, [x2, #10*32]
    
    // loop: one iteration per matrix row
    .balign 64
    1:
    // columns 0-31 of the current row
        ldp     q24, q25, [x1, #0*32]
        ldp     q26, q27, [x1, #1*32]
        ldp     q28, q29, [x1, #2*32]
        ldp     q30, q31, [x1, #3*32]
    // decrement the row counter early. No later instruction in the loop body
    // sets flags, so the b.ne at the bottom still tests this result.
        subs    w3, w3, #1
    
    // v23 = sum of row[0..31] * vec[0..31], lane-wise
        fmul    v23.4s, v24.4s, v0.4s
        fmla    v23.4s, v25.4s, v1.4s
        fmla    v23.4s, v26.4s, v2.4s
        fmla    v23.4s, v27.4s, v3.4s
        fmla    v23.4s, v28.4s, v4.4s
        fmla    v23.4s, v29.4s, v5.4s
        fmla    v23.4s, v30.4s, v6.4s
        fmla    v23.4s, v31.4s, v7.4s
    
    // columns 32-63 (vector elements held in v8-v13, v16, v17)
        ldp     q24, q25, [x1, #4*32]
        ldp     q26, q27, [x1, #5*32]
        ldp     q28, q29, [x1, #6*32]
        ldp     q30, q31, [x1, #7*32]
    
        fmla    v23.4s, v24.4s, v8.4s
        fmla    v23.4s, v25.4s, v9.4s
        fmla    v23.4s, v26.4s, v10.4s
        fmla    v23.4s, v27.4s, v11.4s
        fmla    v23.4s, v28.4s, v12.4s
        fmla    v23.4s, v29.4s, v13.4s
        fmla    v23.4s, v30.4s, v16.4s
        fmla    v23.4s, v31.4s, v17.4s
    
    // columns 64-79, plus column 80 via the scalar s28 load (upper lanes zero,
    // so the last fmla adds only row[80] * vec[80])
        ldp     q24, q25, [x1, #8*32]
        ldp     q26, q27, [x1, #9*32]
        ldr     s28, [x1, #10*32]
    
        fmla    v23.4s, v24.4s, v18.4s
        fmla    v23.4s, v25.4s, v19.4s
        fmla    v23.4s, v26.4s, v20.4s
        fmla    v23.4s, v27.4s, v21.4s
        fmla    v23.4s, v28.4s, v22.4s
    
    // advance to the next row (81 floats = 324 bytes)
        add     x1, x1, #81*4
    
    // horizontal sum: two pairwise adds leave the total of all four lanes in s23
        faddp   v23.4s, v23.4s, v23.4s
        faddp   v23.2s, v23.2s, v23.2s
    
    // store this row's result and advance x0 by one float; loop while w3 != 0
        str     s23, [x0], #4
        b.ne    1b
    
    .balign 8
    // restore registers d8-d13 and release the 48-byte stack frame
    
        ldp     d10, d11, [sp, #1*16]
        ldp     d12, d13, [sp, #2*16]
        ldp     d8, d9, [sp], #48
    
    // return
        ret
    
    .endfunc
    .end
    

    Test results on RK3368:
    Clang intrinsics: 10.41ms
    assembly: 9.59ms

    The compilers didn't perform that badly in this case, but more often than not they are unbelievably stupid. I strongly recommend learning assembly.