c++ arm intrinsics neon

Efficiently accumulate sign bits in arm neon


I have a loop that does some computations and then stores sign bits into a vector:

uint16x8_t rotate(const uint16_t* x);

void compute(const uint16_t* src, uint16_t* dst)
{
    uint16x8_t sign0 = vmovq_n_u16(0);
    uint16x8_t sign1 = vmovq_n_u16(0);
    for (int i=0; i<16; ++i)
    {
        uint16x8_t r0 = rotate(src++);
        uint16x8_t r1 = rotate(src++);
        // pseudo code:
        sign0 |= (r0 >> 15) << i;
        sign1 |= (r1 >> 15) << i;
    }
    vst1q_u16(dst, sign0);   // dst+1 would overlap the second store at dst+8
    vst1q_u16(dst+8, sign1);
}

What's the best way to accumulate sign bits in neon that follows that pseudo code?

Here's what I came up with:

    r0 = vshrq_n_u16(r0, 15);
    r1 = vshrq_n_u16(r1, 15);
    sign0 = vsraq_n_u16(vshlq_n_u16(r0, 15), sign0, 1);
    sign1 = vsraq_n_u16(vshlq_n_u16(r1, 15), sign1, 1);

Also, note that the "pseudo code" actually works and generates pretty much the same code, performance-wise. What can be improved here? Note that in the actual code there are no function calls in the loop; I trimmed the code down to make it simple to understand. Another point: the NEON immediate-shift intrinsics need a compile-time constant, so i cannot be used to specify the shift count.


Solution

  • ARM can do this in one vsri instruction (thanks @Jake'Alquimista'LEE).

    Given a new vector that you want sign bits from, replace the low 15 bits of each element with the accumulator right-shifted by 1.

    You should unroll by 2 so the compiler doesn't need a mov instruction to copy the result back into the same register: vsri is a 2-operand instruction, and the way we need to use it here leaves the result in a different register than the old sign0 accumulator.

    sign0 = vsriq_n_u16(r0, sign0, 1);
    // insert already-accumulated bits below the new bit we want

    After 15 inserts (or 16 if you start with sign0 = 0 instead of peeling the first iteration and using sign0 = r0), all 16 bits (per element) of sign0 will be sign bits from r0 values.
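As a sanity check on the bit order, here is a minimal scalar sketch of a single 16-bit lane. The helper names `sri1_lane` and `accumulate_signs_sri` are hypothetical, and the per-lane formula assumes that vsriq_n_u16(a, b, 1) keeps the top bit of a and fills the low 15 bits with b >> 1, as described above. It is a model for reasoning about the result, not the NEON code itself.

```cpp
#include <cassert>
#include <cstdint>

// Scalar model of one 16-bit lane of vsriq_n_u16(a, b, 1) (assumed semantics):
// keep the top bit of a, fill the low 15 bits with b >> 1.
inline uint16_t sri1_lane(uint16_t a, uint16_t b) {
    return static_cast<uint16_t>((a & 0x8000u) | (b >> 1));
}

// Accumulate the sign bits of up to 16 values, starting from sign = 0.
// With count == 16, the sign bit of r[i] ends up in bit i of the result,
// which is the same order as the question's pseudo code.
inline uint16_t accumulate_signs_sri(const uint16_t* r, int count) {
    assert(r != nullptr && count >= 0 && count <= 16);
    uint16_t sign = 0;
    for (int i = 0; i < count; ++i)
        sign = sri1_lane(r[i], sign);   // new sign bit on top, old bits shifted down
    return sign;
}
```

Each insert pushes the previously collected bits down by one position, so the first value's sign bit reaches bit 0 only after the 16th insert.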


    Previous suggestion: AND with a vector constant to isolate the sign bit. It's more efficient than two shifts.

    Your idea of accumulating with VSRA to shift the accumulator and add in the new bit is good, so we can keep that and get down to 2 instructions total.

    tmp = r0 & 0x8000;            // VAND
    sign0 = (sign0 >> 1) + tmp;   // VSRA
    

    or using neon intrinsics:

    uint16x8_t mask80 = vmovq_n_u16(0x8000);
    r0 = vandq_u16(r0, mask80);        // VAND
    sign0 = vsraq_n_u16(r0, sign0, 1); // VSRA
    

    Implement with intrinsics or asm however you like, and write the scalar version the same way to give the compiler a better chance to auto-vectorize.
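One way the scalar version "written the same way" might look is sketched below. The function name `accumulate_step` and the fixed 8-lane arrays are assumptions made for illustration; whether a given compiler actually turns the loop into VAND + VSRA is not guaranteed and would need checking in the generated assembly.

```cpp
#include <cassert>
#include <cstdint>

// One accumulation step over 8 lanes, mirroring the VAND + VSRA pair:
// isolate the sign bit, then add it to the accumulator shifted right by 1.
inline void accumulate_step(uint16_t* sign, const uint16_t* r) {
    assert(sign != nullptr && r != nullptr);
    for (int lane = 0; lane < 8; ++lane) {
        const uint16_t tmp = static_cast<uint16_t>(r[lane] & 0x8000u);   // VAND
        sign[lane] = static_cast<uint16_t>((sign[lane] >> 1) + tmp);     // VSRA
    }
}
```

Calling it 16 times with successive input vectors leaves the sign bit of the i-th input in bit i of each lane. The add never carries, because bit 15 of the shifted accumulator is always zero.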


    This does need a vector constant in a register. If you're very tight on registers, then 2 shifts could be better, but 3 shifts total seems likely to bottleneck on shifter throughput unless ARM chips typically spend a lot of real-estate on SIMD barrel shifters.

    In that case, maybe use this generic SIMD idea, which needs neither ARM's shift+accumulate nor shift+insert:

    tmp = r0 >> 15;     // logical right shift
    sign0 += sign0;     // add instead of left shifting
    sign0 |= tmp;       // or add or xor or whatever.
    

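    This gives you the bits in the opposite order: the first vector's sign bit ends up in the most significant position instead of bit 0. If you can produce the input vectors in the opposite order, then great.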

    Otherwise, does ARM have SIMD bit-reverse or only for scalar? (Generate in reverse order and flip them at the end, with some extra work for every vector bitmap, hopefully only one instruction.)

    Update: yes, AArch64 has a vector rbit that reverses the bits within each byte, so you could apply it and then swap the two bytes of each 16-bit element to put the bits in the right order. x86 could use a pshufb LUT to bit-reverse within bytes in two 4-bit chunks. This might not come out ahead of doing more work as you accumulate the bits on x86, though.
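To make the ordering issue concrete, here is a small scalar sketch of the add-then-OR accumulation together with a plain bit reversal. The names `accumulate_signs_add` and `reverse_bits16` are hypothetical, and the reversal uses ordinary mask-and-shift steps rather than rbit or pshufb, so it only illustrates the result, not the fast SIMD path.

```cpp
#include <cassert>
#include <cstdint>

// Portable accumulation: double the accumulator (a left shift by 1),
// then OR in the new sign bit at bit 0.
inline uint16_t accumulate_signs_add(const uint16_t* r, int count) {
    assert(r != nullptr && count >= 0 && count <= 16);
    uint16_t sign = 0;
    for (int i = 0; i < count; ++i) {
        const uint16_t tmp = static_cast<uint16_t>(r[i] >> 15);  // logical right shift
        sign = static_cast<uint16_t>(sign + sign);               // add instead of shifting left
        sign |= tmp;
    }
    return sign;   // with count == 16, the sign of r[i] sits in bit 15 - i
}

// Reverse all 16 bits: swap neighbouring bits, then pairs, then nibbles, then bytes.
inline uint16_t reverse_bits16(uint16_t v) {
    v = static_cast<uint16_t>(((v >> 1) & 0x5555u) | ((v & 0x5555u) << 1));
    v = static_cast<uint16_t>(((v >> 2) & 0x3333u) | ((v & 0x3333u) << 2));
    v = static_cast<uint16_t>(((v >> 4) & 0x0F0Fu) | ((v & 0x0F0Fu) << 4));
    v = static_cast<uint16_t>((v >> 8) | (v << 8));
    return v;
}
```

Reversing the accumulated value once at the end restores the order of the original pseudo code, where the sign of the i-th input sits in bit i.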