Search code examples
Tags: c++, arm, simd, neon

NEON vectorize sum of products of unsigned bytes: (a[i]-int1) * (b[i]-int2)


I need to speed up a loop because my application calls it thousands of times. I suppose I need to use NEON, but I don't know where to begin.

Assumptions / pre-conditions:

  • w is always 320 (multiple of 16/32).
  • pa and pb are 16-byte aligned
  • ma and mb are positive.

 /*
 * Scalar reference: returns the sum over i in [0, w) of
 * (pa[i] - ma) * (pb[i] - mb), accumulated in an int.
 *
 * pa, pb : byte arrays with at least w readable elements each.
 * ma, mb : values subtracted from every element of pa / pb.
 * w      : number of elements; w <= 0 returns 0 without reading pa or pb.
 *
 * The count is checked before the first element is touched, so an empty
 * range is a no-op instead of reading one element and then counting w down
 * past zero.
 */
int whileInstruction(const unsigned char *pa, const unsigned char *pb, int ma, int mb, int w)
{
    int sum = 0;

    for (int i = 0; i < w; ++i) {
        sum += (pa[i] - ma) * (pb[i] - mb);
    }

    return sum;
}

This attempt at vectorizing it is not working well, and isn't safe (missing clobbers), but demonstrates what I'm trying to do:

/*
 * First (non-working) attempt at a NEON version of the scalar loop, kept as
 * posted. Review notes on why it cannot work as written:
 *  - `sum` is never declared in this function, so it does not compile.
 *  - With no output operands the operands are numbered %0=pa, %1=pb, %2=w,
 *    %3=ma, %4=mb, %5=sum. The template loads from %1 (pb) and from %2 (w,
 *    the loop counter, used as an address); pa (%0), ma (%3) and mb (%4) are
 *    never referenced.
 *  - %1 and %2 are modified (post-increment "!", lsr, subs) although they are
 *    declared as input-only "r" operands; they would need "+r" constraints.
 *  - No output operand receives a result, so the asm cannot produce a sum.
 *  - Each vld4.8 with four d registers loads 32 bytes, but the iteration
 *    count is w >> 3, i.e. 8 elements per iteration.
 *  - vaddl.u8 adds (the C loop subtracts) and takes NEON d-register sources,
 *    not the core register r7; vmlal.u8 also takes d-register sources, not q.
 *  - The clobber list names r4-r9 but not the NEON registers written
 *    (d0-d7, q7, q8), the flags set by subs ("cc"), or "memory" for the
 *    buffers being read.
 *  - The fixed label `.loop` collides if the asm is emitted more than once
 *    (e.g. after inlining); GCC local numeric labels (1: / bne 1b) avoid that.
 */
int whileInstruction (const unsigned char *pa,const unsigned char *pb,int ma,int mb,int w)
{

    asm volatile("lsr          %2, %2, #3      \n"
                 ".loop:                       \n"
                 "# load 8 elements:             \n"
                 "vld4.8      {d0-d3}, [%1]!   \n"
                 "vld4.8      {d4-d7}, [%2]!   \n"
                 "# do the operation:     \n"
                 "vaddl.u8    q7, d0, r7       \n"
                 "vaddl.u8    q8, d1, d8       \n"
                 "vmlal.u8    q7, q7, q8       \n"
                 "# Sum the vector a save in sum (this is wrong):\n"
                 "vaddl.u8    q7, d0, r7       \n"
                 "subs        %2, %2, #1       \n" // Decrement iteration count
                 "bne         .loop            \n" // Repeat until iteration count is zero
                 :
                 : "r"(pa), "r"(pb), "r"(w),"r"(ma),"r"(mb),"r"(sum)
                 : "r4", "r5", "r6","r7","r8","r9"
                 );

    return sum;
}

Solution

  • Here is a simple NEON implementation. I have tested it against the scalar code to make sure that it works. Note that for best performance both pa and pb should be 16-byte aligned.

    #if defined(__ARM_NEON) || defined(__ARM_NEON__)
    #include <arm_neon.h>
    #endif
    
    #if defined(__ARM_NEON) || defined(__ARM_NEON__)
    /*
     * Widen 16 unsigned bytes to four int32x4_t vectors (lanes 0-3, 4-7,
     * 8-11, 12-15 of v) and subtract m from every lane.
     */
    static inline int32x4x4_t widen_u8_minus(uint8x16_t v, int32x4_t m)
    {
        /* u8 -> u16 zero-extends; reinterpreting as s16 is lossless since every value is <= 255. */
        int16x8_t lo = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(v)));
        int16x8_t hi = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(v)));
        int32x4x4_t r;

        r.val[0] = vsubq_s32(vmovl_s16(vget_low_s16(lo)), m);
        r.val[1] = vsubq_s32(vmovl_s16(vget_high_s16(lo)), m);
        r.val[2] = vsubq_s32(vmovl_s16(vget_low_s16(hi)), m);
        r.val[3] = vsubq_s32(vmovl_s16(vget_high_s16(hi)), m);
        return r;
    }
    #endif

    /*
     * NEON version of whileInstruction: returns the sum over i in [0, w) of
     * (pa[i] - ma) * (pb[i] - mb) using the same 32-bit int arithmetic as the
     * scalar loop. Each byte is widened to int32 before ma/mb is subtracted,
     * so ma and mb are not limited to the 0..255 range.
     *
     * pa, pb : byte arrays with at least w readable elements each; 16-byte
     *          alignment helps performance but vld1q_u8 does not require it.
     * ma, mb : values subtracted from every element of pa / pb.
     * w      : number of elements; w <= 0 returns 0.
     *
     * Full 16-byte blocks are processed with NEON when the compiler targets it
     * (__ARM_NEON or __ARM_NEON__); the remaining elements -- or the whole
     * range on other targets -- go through the scalar loop at the end.
     */
    int whileInstruction_neon(const unsigned char *pa, const unsigned char *pb, int ma, int mb, int w)
    {
        int sum = 0;
        int i = 0;

    #if defined(__ARM_NEON) || defined(__ARM_NEON__)
        /* vdupq_n_s32 is the portable ACLE way to broadcast a scalar into all lanes. */
        const int32x4_t vma = vdupq_n_s32(ma);
        const int32x4_t vmb = vdupq_n_s32(mb);

        /* One accumulator per quarter of the 16-byte block. */
        int32x4_t acc0 = vdupq_n_s32(0);
        int32x4_t acc1 = vdupq_n_s32(0);
        int32x4_t acc2 = vdupq_n_s32(0);
        int32x4_t acc3 = vdupq_n_s32(0);

        /* `w - i >= 16` cannot overflow (unlike `w - 16` for very negative w). */
        for (; w - i >= 16; i += 16) {
            int32x4x4_t a = widen_u8_minus(vld1q_u8(pa + i), vma);
            int32x4x4_t b = widen_u8_minus(vld1q_u8(pb + i), vmb);

            acc0 = vmlaq_s32(acc0, a.val[0], b.val[0]);
            acc1 = vmlaq_s32(acc1, a.val[1], b.val[1]);
            acc2 = vmlaq_s32(acc2, a.val[2], b.val[2]);
            acc3 = vmlaq_s32(acc3, a.val[3], b.val[3]);
        }

        /* Combine the partial sums, then reduce the final vector horizontally. */
        int32x4_t vsum = vaddq_s32(vaddq_s32(acc0, acc1), vaddq_s32(acc2, acc3));
        sum = vgetq_lane_s32(vsum, 0) + vgetq_lane_s32(vsum, 1)
            + vgetq_lane_s32(vsum, 2) + vgetq_lane_s32(vsum, 3);
    #endif

        /* Elements left over after the 16-wide blocks (all of them without NEON). */
        for (; i < w; ++i) {
            sum += (pa[i] - ma) * (pb[i] - mb);
        }

        return sum;
    }