Search code examples
armneon

ARM NEON: Convert a binary 8-bit-per-pixel image (only 0/1) to 1-bit-per-pixel?


I am working on a task to convert a large binary label image, which stores 8 bits (one uint8_t) per pixel where each pixel is either 0 or 1 (or 255), into an array of uint64_t numbers in which each bit represents one label pixel.

For example,

input array: 0 1 1 0 ... (00000000 00000001 00000001 00000000 ...)

or input array: 0 255 255 0 ... (00000000 11111111 11111111 00000000 ...)

output array (number): 6 (because after converting each uint8_t to a bit, the result is 0110)

Currently the C code to achieve this is:

 for (int j = 0; j < width >> 6; j++) {
        uint8_t* in_ptr= in + (j << 6);
        uint64_t out_bits = 0;
        if (in_ptr[0]) out_bits |= 0x0000000000000001;
        if (in_ptr[1]) out_bits |= 0x0000000000000002;
        .
        .
        .
        if (in_ptr[63]) out_bits |= 0x8000000000000000;
       *output = out_bits; output++;
    }

Can ARM NEON optimize this functionality? Please help. Thank you!


Solution

  • Assuming the input value is either 0 or 255, below is the basic version which is rather straightforward, especially for people with Intel SSE/AVX experience.

    /*
     * Pack a byte mask into a bitmap: every 8 consecutive input bytes become
     * one output byte, with bit i set iff input byte i is non-zero (LSB
     * first, matching the scalar reference code). Assumes each input byte is
     * exactly 0x00 or 0xFF.
     *
     * pDst:   output bitmap, receives length/8 bytes
     * pSrc:   input pixels, one byte each (0x00 or 0xFF)
     * length: number of input bytes; must be >= 64 and a multiple of 8
     *
     * NOTE(review): vpaddq_u8 is an AArch64-only intrinsic; 32-bit ARMv7
     * NEON only provides the 64-bit vpadd_u8 — confirm the intended target.
     */
    void foo_basic(uint8_t *pDst, uint8_t *pSrc, intptr_t length)
    {
        //assert(length >= 64);
        //assert((length & 7) == 0);  /* fixed precedence: was length & 7 == 0 */
        uint8x16_t in0, in1, in2, in3;
        uint8x8_t out;
        /* byte k within each group of 8 carries bit weight 1 << k */
        const uint8x16_t mask = {1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128};

        length -= 64;

        do {
            do {
                /* consume 64 source bytes per iteration */
                in0 = vld1q_u8(pSrc); pSrc += 16;
                in1 = vld1q_u8(pSrc); pSrc += 16;
                in2 = vld1q_u8(pSrc); pSrc += 16;
                in3 = vld1q_u8(pSrc); pSrc += 16;

                /* keep only each byte's own bit weight (input is 0x00/0xFF) */
                in0 &= mask;
                in1 &= mask;
                in2 &= mask;
                in3 &= mask;

                /* three rounds of pairwise adds sum each group of 8 masked
                 * bytes into a single byte; the weights are distinct powers
                 * of two, so the byte sums can never carry */
                in0 = vpaddq_u8(in0, in1);
                in2 = vpaddq_u8(in2, in3);

                in0 = vpaddq_u8(in0, in2);

                out = vpadd_u8(vget_low_u8(in0), vget_high_u8(in0));

                vst1_u8(pDst, out); pDst += 8;

                length -= 64;
            } while (length >= 0);

            /* Tail: length is now in [-64, -1]. Rewind so one final
             * (overlapping) pass covers the exact last 64 source bytes.
             * length counts SOURCE bytes, so the source pointer moves by
             * length and the destination by length/8.
             * BUG FIX: the original had these two adjustments swapped
             * (pSrc += length>>3; pDst += length), which makes the tail
             * pass read and write at the wrong offsets; the hand-written
             * asm version in this answer does it correctly
             * (add r1,r1,r2 / add r0,r0,r2,asr #3). */
            pSrc += length;
            pDst += length >> 3;
        } while (length > -64);
    }
    

    Neon, however, has very user-friendly and efficient permutation and bit-manipulation instructions that allow us to go "vertical":

    /*
     * Same bit-packing contract as foo_basic (8 input bytes of 0x00/0xFF ->
     * 1 output byte, LSB first), but done "vertically": vld4 plus vuzp
     * de-interleave the input so that bytes occupying the same bit position
     * within each group of 8 land in the same vector, and vsli merges them.
     *
     * pDst:   output bitmap, receives length/8 bytes
     * pSrc:   input pixels, one byte each (0x00 or 0xFF)
     * length: number of input bytes; must be >= 128 and a multiple of 8
     */
    void foo_advanced(uint8_t *pDst, uint8_t *pSrc, intptr_t length)
    {
        //assert(length >= 128);
        //assert((length & 7) == 0);  /* fixed precedence: was length & 7 == 0 */
        uint8x16x4_t in0, in1;
        uint8x16x2_t row04, row15, row26, row37;

        length -= 128;

        do {
            do {
                /* consume 128 source bytes per iteration */
                in0 = vld4q_u8(pSrc); pSrc += 64;
                in1 = vld4q_u8(pSrc); pSrc += 64;

                /* stride-4 loads + unzip give stride-8 separation: val[0] of
                 * row04 holds bytes 0,8,16,..., row15.val[0] bytes 1,9,..., etc. */
                row04 = vuzpq_u8(in0.val[0], in1.val[0]);
                row15 = vuzpq_u8(in0.val[1], in1.val[1]);
                row26 = vuzpq_u8(in0.val[2], in1.val[2]);
                row37 = vuzpq_u8(in0.val[3], in1.val[3]);

                /* vsli (shift left and insert) layers the 0x00/0xFF lanes on
                 * top of each other: rounds of shift 1, 2, 4 leave each byte
                 * of row04.val[0] holding all 8 packed bits */
                row04.val[0] = vsliq_n_u8(row04.val[0], row15.val[0], 1);
                row26.val[0] = vsliq_n_u8(row26.val[0], row37.val[0], 1);
                row04.val[1] = vsliq_n_u8(row04.val[1], row15.val[1], 1);
                row26.val[1] = vsliq_n_u8(row26.val[1], row37.val[1], 1);

                row04.val[0] = vsliq_n_u8(row04.val[0], row26.val[0], 2);
                row04.val[1] = vsliq_n_u8(row04.val[1], row26.val[1], 2);

                row04.val[0] = vsliq_n_u8(row04.val[0], row04.val[1], 4);

                vst1q_u8(pDst, row04.val[0]); pDst += 16;

                length -= 128;
            } while (length >= 0);

            /* Tail: length is now in [-128, -1]; rewind so one final
             * (overlapping) pass covers the last 128 source bytes. length
             * counts SOURCE bytes, so pSrc moves by length, pDst by length/8.
             * BUG FIX: the original had these two adjustments swapped
             * (pSrc += length>>3; pDst += length); the hand-written asm
             * version in this answer does it correctly
             * (add r1,r1,r2 / add r0,r0,r2,asr #3). */
            pSrc += length;
            pDst += length >> 3;
        } while (length > -128);
    }
    

    The Neon-only advanced version is shorter and faster, but GCC is extremely bad at dealing with Neon specific permutation instructions such as vtrn, vzip, and vuzp.

    https://godbolt.org/z/bGdbohqKe

    Clang isn't any better: it spams unnecessary vorr where GCC does the same with vmov.

        @ foo_asm(r0 = pDst, r1 = pSrc, r2 = length)
        @ Hand-written ARMv7 NEON equivalent of foo_advanced: packs 128
        @ input bytes (each 0x00 or 0xFF) into 16 output bytes per
        @ inner-loop iteration.
        .syntax unified
        .arm
        .arch   armv7-a
        .fpu    neon
        .global foo_asm
        .text
    
    .func
    .balign 64
    foo_asm:
        sub     r2, r2, #128        @ bias length so subs/bpl drives the loop
    
    .balign 16
    1:
        @ vld4 de-interleaves with stride 4; loading low then high d halves
        @ fills q8..q11 and q12..q15 so vuzp can finish the stride-8 split
        vld4.8      {d16, d18, d20, d22}, [r1]!
        vld4.8      {d17, d19, d21, d23}, [r1]!
        vld4.8      {d24, d26, d28, d30}, [r1]!
        vld4.8      {d25, d27, d29, d31}, [r1]!
        subs    r2, r2, #128        @ decrement early; flags consumed by bpl
    
        vuzp.8      q8, q12
        vuzp.8      q9, q13
        vuzp.8      q10, q14
        vuzp.8      q11, q15
    
        @ shift-left-and-insert rounds (1, 2, 4) layer the 0x00/0xFF lanes
        @ so each byte of q8 ends up holding 8 packed bits
        vsli.8      q8, q9, #1
        vsli.8      q10, q11, #1
        vsli.8      q12, q13, #1
        vsli.8      q14, q15, #1
    
        vsli.8      q8, q10, #2
        vsli.8      q12, q14, #2
    
        vsli.8      q8, q12, #4
    
        vst1.8      {q8}, [r0]!
        bpl     1b                  @ loop while biased length >= 0
    
        @ tail: r2 in [-128, -1]; rewind pSrc by r2 (source bytes) and
        @ pDst by r2/8, then run one final overlapping pass if a tail exists
        add     r1, r1, r2
        cmp     r2, #-128
        add     r0, r0, r2, asr #3
    
        bgt     1b
    .balign 8
        bx      lr
    
    .endfunc
    .end
    

    The innermost loop consists of:
    GCC: 32 instructions
    Clang: 30 instructions
    Asm: 18 instructions

    It doesn't take rocket science to figure out which one is the fastest and by how much: Never trust compilers if you are about to do permutations.