Tags: c++, arm, intrinsics, neon

How to clear all but the first non-zero lane in NEON?


I have a mask in a uint32x4_t NEON register. At least one of the four lanes is set (e.g. to 0xffffffff), but more than one lane may be set. How can I ensure that only one lane remains set?

In C pseudocode:

uint32x4_t clearmask(uint32x4_t m)
{
         if (m[0]) { m[1] = m[2] = m[3] = 0; }
    else if (m[1]) { m[2] = m[3] = 0; }
    else if (m[2]) { m[3] = 0; }
    return m;
}

Basically, I want to clear every set lane except the first one. A straightforward implementation with NEON intrinsics could be:

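uint32x4_t cleanmask(uint32x4_t m)
{
    // Assumes every lane of m is either 0 or 0xffffffff.
    uint32x4_t mx;

    // If lane 0 is set, clear lanes 1..3.
    mx = vdupq_lane_u32(vget_low_u32(vmvnq_u32(m)), 0);
    mx = vsetq_lane_u32(0xffffffff, mx, 0);
    m = vandq_u32(m, mx);

    // If lane 1 is still set, clear lanes 2..3 (lane 0 is already zero then).
    mx = vdupq_lane_u32(vget_low_u32(vmvnq_u32(m)), 1);
    mx = vsetq_lane_u32(0xffffffff, mx, 1);
    m = vandq_u32(m, mx);

    // If lane 2 is still set, clear lane 3 (lanes 0..1 are already zero then).
    mx = vdupq_lane_u32(vget_high_u32(vmvnq_u32(m)), 0);
    mx = vsetq_lane_u32(0xffffffff, mx, 2);
    m = vandq_u32(m, mx);

    return m;
}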

How can this be done more efficiently with ARM NEON?


Solution

  • Very simple:

    vceq.u32    q1, q0, #0
    vmov.i8     d7, #0xff
    vext.8      q2, q3, q1, #12
    
    vand        q0, q0, q2
    vand        d1, d1, d2
    vand        d1, d1, d4
    

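    6 instructions total, 5 if you can keep q3 as a constant. The last two vand instructions only touch d1, the upper half of q0, because the lower half needs no further masking after the first vand.

    The aarch64 version below should be easier to understand: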

    cmeq    v1.4s, v0.4s, #0
    movi    v31.16b, #0xff
    
    ext     v2.16b, v31.16b, v1.16b, #12
    ext     v3.16b, v31.16b, v1.16b, #8
    ext     v4.16b, v31.16b, v1.16b, #4
    
    and     v0.16b, v0.16b, v2.16b
    and     v0.16b, v0.16b, v3.16b
    and     v0.16b, v0.16b, v4.16b
    

    How this works

    ext/vext extracts a 16-byte window from the concatenation of two vectors, so we use it to build the following masks (lanes listed from highest to lowest):

    v0 = [  d   c   b   a ]
    
    v2 = [ !c  !b  !a  -1 ]
    v3 = [ !b  !a  -1  -1 ]
    v4 = [ !a  -1  -1  -1 ]
    

    The highest element (d) is zeroed if any of the lower elements (a, b or c) is non-zero.

    The second-highest element (c) is zeroed if any of its preceding elements (a or b) is non-zero. And so on.
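
To make the mask construction concrete, here is a minimal portable C++ sketch that models the lanes with a plain array instead of real NEON registers. The names Vec4, ext_lanes, keep_first_reference, keep_first_ext_model and ext_model_matches_reference are made up for this illustration, and ext_lanes only imitates ext/vext at 32-bit lane granularity (lane counts 3, 2 and 1 stand in for byte offsets 12, 8 and 4).

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Lane 0 first, like the in-memory order of a 128-bit NEON register.
using Vec4 = std::array<std::uint32_t, 4>;

// Scalar reference: the behaviour described by the question's pseudocode.
Vec4 keep_first_reference(Vec4 m) {
    if (m[0])      { m[1] = m[2] = m[3] = 0; }
    else if (m[1]) { m[2] = m[3] = 0; }
    else if (m[2]) { m[3] = 0; }
    return m;
}

// Hypothetical model of ext/vext on 32-bit lanes:
// result = lanes n..3 of `first`, followed by lanes 0..n-1 of `second`.
Vec4 ext_lanes(const Vec4& first, const Vec4& second, int n) {
    assert(n >= 0 && n <= 4);
    Vec4 r{};
    for (int i = 0; i < 4; ++i) {
        const int src = i + n;
        r[i] = (src < 4) ? first[src] : second[src - 4];
    }
    return r;
}

// Same steps as the aarch64 listing, written against the lane model.
Vec4 keep_first_ext_model(const Vec4& m) {
    const Vec4 ones = {0xffffffffu, 0xffffffffu, 0xffffffffu, 0xffffffffu};

    Vec4 is_zero{};                                   // cmeq ..., #0
    for (int i = 0; i < 4; ++i) is_zero[i] = (m[i] == 0) ? 0xffffffffu : 0u;

    const Vec4 v2 = ext_lanes(ones, is_zero, 3);      // [ !c  !b  !a  -1 ]
    const Vec4 v3 = ext_lanes(ones, is_zero, 2);      // [ !b  !a  -1  -1 ]
    const Vec4 v4 = ext_lanes(ones, is_zero, 1);      // [ !a  -1  -1  -1 ]

    Vec4 r{};
    for (int i = 0; i < 4; ++i) r[i] = m[i] & v2[i] & v3[i] & v4[i];
    return r;
}

// Compares the model with the reference for all 16 set/clear combinations.
bool ext_model_matches_reference() {
    for (unsigned bits = 0; bits < 16; ++bits) {
        Vec4 m{};
        for (int i = 0; i < 4; ++i) m[i] = ((bits >> i) & 1u) ? 0xffffffffu : 0u;
        if (keep_first_ext_model(m) != keep_first_reference(m)) return false;
    }
    return true;
}
```

Calling ext_model_matches_reference() checks the three-mask scheme against the question's pseudocode for every combination of set and clear lanes; it says nothing about instruction timing on real hardware.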


    With elements guaranteed to be 0 or -1, mvn also works instead of a compare against zero.
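
Since the question uses intrinsics, here is a rough sketch of the same idea in that form, assuming every lane is exactly 0 or 0xffffffff so that vmvnq_u32 can stand in for the compare. The function name keep_first_lane and its array-based interface are hypothetical, and the scalar branch exists only so the block still compiles where __ARM_NEON is not defined.

```cpp
#include <cassert>
#include <cstdint>
#if defined(__ARM_NEON)
#include <arm_neon.h>
#endif

// Keeps only the lowest-indexed set lane of a 4-lane mask.
// Assumption: every lane of `in` is either 0 or 0xffffffff.
void keep_first_lane(const std::uint32_t in[4], std::uint32_t out[4]) {
    assert(in != nullptr && out != nullptr);
#if defined(__ARM_NEON)
    const uint32x4_t m    = vld1q_u32(in);
    const uint32x4_t ones = vdupq_n_u32(0xffffffffu);
    const uint32x4_t inv  = vmvnq_u32(m);                       // lanes 0..3 = !a !b !c !d

    // vextq_u32(x, y, n) yields lanes n..3 of x followed by lanes 0..n-1 of y.
    uint32x4_t r = vandq_u32(m, vextq_u32(ones, inv, 3));       // lanes 0..3 = -1 !a !b !c
    r = vandq_u32(r, vextq_u32(ones, inv, 2));                  // lanes 0..3 = -1 -1 !a !b
    r = vandq_u32(r, vextq_u32(ones, inv, 1));                  // lanes 0..3 = -1 -1 -1 !a
    vst1q_u32(out, r);
#else
    // Scalar fallback (not part of the NEON technique).
    std::uint32_t seen = 0;
    for (int i = 0; i < 4; ++i) {
        const std::uint32_t v = in[i];
        out[i] = v & ~seen;   // cleared once any lower lane was set
        seen |= v;
    }
#endif
}
```

If lanes may hold non-zero values other than 0xffffffff, replacing vmvnq_u32(m) with vceqq_u32(m, vdupq_n_u32(0)) should restore the compare-against-zero behaviour of the assembly above.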