Tags: c, assembly, arm, arm64, neon

How do I cast a vector to a float64_t to check a SIMD compare for all-zero?


I'm working with ARM NEON using the intrinsics published by ARM, and I would like my code to work on both ARMv7 and AArch64. To check whether a vector is all zero, I use this assembly idiom:

shrn v0.8b, v0.8h, #4    // compress bytes in v0 into nibbles
fcmp d0, #0.0

This works correctly if all bytes of v0 are either 0x00 or 0xff, which is the case in my code.
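
To see why the 0x00/0xff restriction matters, here is a small scalar model of the two instructions in portable C. It is only a sketch: the names model_shrn4 and model_veq_zero are hypothetical, it assumes a little-endian host with IEEE 754 doubles, and it does not use NEON at all.

```c
#include <stdint.h>
#include <string.h>

/* Scalar model of "shrn v0.8b, v0.8h, #4": treat the 16 input bytes as
 * eight little-endian 16-bit lanes, shift each lane right by 4 and keep
 * only the low 8 bits of the result. */
static void
model_shrn4(const uint8_t in[16], uint8_t out[8])
{
    for (int i = 0; i < 8; i++) {
        unsigned lane = in[2 * i] | ((unsigned)in[2 * i + 1] << 8);

        out[i] = (uint8_t)(lane >> 4);
    }
}

/* Scalar model of "fcmp d0, #0.0": reinterpret the 8 narrowed bytes as a
 * double (assumes little-endian, IEEE 754) and compare against 0.0. */
static int
model_veq_zero(const uint8_t in[16])
{
    uint8_t narrowed[8];
    double d;

    model_shrn4(in, narrowed);
    memcpy(&d, narrowed, sizeof d);

    return (d == 0.0);
}
```

Under this model, an input whose only nonzero byte is 0x08 in the last position narrows to the bit pattern of -0.0, which compares equal to 0.0, and a byte such as 0x01 in the first position loses its only set bit in the shift. Neither case can occur when every byte is 0x00 or 0xff.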

To translate this idiom into C, I tried the following code, which builds successfully on AArch64 and yields the desired machine code:

static int
veq_zero(uint8x16_t v)
{
    uint8x8_t narrowed;

    /* narrow each byte to a nibble */
    narrowed = vshrn_n_u16(vreinterpretq_u16_u8(v), 4);

    /* check if that vector is all zero */
    return (vdupd_lane_f64(vreinterpret_f64_u8(narrowed), 0) == 0.0);
}

Unfortunately, neither vdupd_lane_f64 nor vreinterpret_f64_u8 is available on ARMv7; indeed, it seems that the entire float64x1_t type and all of its related functions are absent.

What's the best way to express this idiom in C with intrinsics such that it compiles both for ARMv7 with NEON and for AArch64?


Solution

  • To work around this shortcoming, I ended up type-punning through a union:

    static int
    veq_zero(uint8x16_t v)
    {
    #ifdef __arm__
        union { uint8x8_t v; double d; } narrowed;
    
        /* narrow each byte to a nibble */
        narrowed.v = vshrn_n_u16(vreinterpretq_u16_u8(v), 4);
    
        /* check if that vector is all zero */
        return (narrowed.d == 0.0);
    #else /* AArch64 */
        uint8x8_t narrowed;
    
        /* narrow each byte to a nibble */
        narrowed = vshrn_n_u16(vreinterpretq_u16_u8(v), 4);
    
        /* check if that vector is all zero */
        return (vdupd_lane_f64(vreinterpret_f64_u8(narrowed), 0) == 0.0);
    #endif
    }
    

    This is legal in C. For C++, you'd want memcpy or C++20 std::bit_cast<double>(narrowed_v), where narrowed_v is a uint8x8_t.
    That said, MSVC and GCC/Clang define the behaviour of union type-punning as an extension to C++, so the above code is also safe with those compilers.

    I'm not happy with this solution and hope there is a nicer one.
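
The memcpy variant mentioned above can be sketched as a single function that needs no #ifdef between ARMv7 and AArch64. This is a sketch under assumptions rather than a drop-in replacement: the names veq_zero_memcpy and veq_zero_bytes are hypothetical, whether the compiler reduces the copy to a plain floating-point compare on the D register depends on the compiler and optimization level, and the #else branch is only a scalar stand-in (assuming a little-endian host) so that the snippet also builds on machines without NEON.

```c
#include <stdint.h>
#include <string.h>

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>

/* Same idea as the union version, but the reinterpretation goes through
 * memcpy, which is valid in both C and C++. */
static int
veq_zero_memcpy(uint8x16_t v)
{
    uint8x8_t narrowed;
    double d;

    /* narrow each byte to a nibble */
    narrowed = vshrn_n_u16(vreinterpretq_u16_u8(v), 4);

    /* copy the 8 bytes into a double and compare against zero */
    memcpy(&d, &narrowed, sizeof d);

    return (d == 0.0);
}

/* Hypothetical convenience wrapper: load 16 bytes and run the check. */
static int
veq_zero_bytes(const uint8_t bytes[16])
{
    return (veq_zero_memcpy(vld1q_u8(bytes)));
}
#else
/* Hypothetical scalar stand-in for hosts without NEON; it mimics the
 * narrowing shift on little-endian 16-bit lanes. */
static int
veq_zero_bytes(const uint8_t bytes[16])
{
    uint8_t narrowed[8];
    double d;

    for (int i = 0; i < 8; i++) {
        unsigned lane = bytes[2 * i] | ((unsigned)bytes[2 * i + 1] << 8);

        narrowed[i] = (uint8_t)(lane >> 4);
    }

    memcpy(&d, narrowed, sizeof d);

    return (d == 0.0);
}
#endif
```

As with the original idiom, this is only meaningful when every input byte is 0x00 or 0xff. It is worth checking the generated assembly on both targets before relying on it.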