
Find min and position of the min element in uint8x8_t neon register


Consider this piece of code:

uint8_t v[8] = { ... };
int ret = 256;
int ret_pos = -1;
for (int i=0; i<8; ++i)
{
    if (v[i] < ret)
    {
        ret = v[i];
        ret_pos = i;
    }
}

It finds the minimum and the position of the minimum element (ret and ret_pos). With ARM NEON I could use pairwise min to find the minimum element of v, but how do I find its position?

Update: see my own answer. What would you suggest to improve it?


Solution

  • Here's how I've done it after spending some time fiddling with bits and math:

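    #include <arm_neon.h>

    /* Finds the minimum of the 8 lanes of x (value) and the lane of its
       first occurrence (index). Expects a uint8x8_t named vmask, loaded
       from mask[] below, to be in scope where the macro is expanded. */
    #define VMIN8(x, index, value)                               \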
    do {                                                         \
        uint8x8_t m = vpmin_u8(x, x);                            \
        m = vpmin_u8(m, m);                                      \
        m = vpmin_u8(m, m);                                      \
        uint8x8_t r = vceq_u8(x, m);                             \
                                                                 \
        uint8x8_t z = vand_u8(vmask, r);                         \
                                                                 \
        z = vpadd_u8(z, z);                                      \
        z = vpadd_u8(z, z);                                      \
        z = vpadd_u8(z, z);                                      \
                                                                 \
        unsigned u32 = vget_lane_u32(vreinterpret_u32_u8(z), 0); \
        index = __lzcnt(u32);                                    \
        value = vget_lane_u8(m, 0);                              \
    } while (0)
    
    
    uint8_t v[8] = { ... };
    
    static const uint8_t mask[] = { 0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01 };
    uint8x8_t vmask = vld1_u8(mask);
    
    uint8x8_t v8 = vld1_u8(v);
    int ret;
    int ret_pos;
    VMIN8(v8, ret_pos, ret);
    

    where __lzcnt is a count-leading-zeros (clz) operation (__builtin_clz in GCC).
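
Since the spelling of clz differs between compilers, a small wrapper can hide the difference. The following is only a sketch: the name clz32 is made up for illustration, and the non-GCC branch is a plain loop rather than any particular vendor intrinsic. In VMIN8 the argument is never zero, because at least one lane always equals the minimum, but the wrapper handles zero anyway since __builtin_clz is undefined for it.

```c
#include <stdint.h>

/* Hypothetical helper: count leading zeros of a 32-bit value.
   Returns 32 for x == 0, where __builtin_clz would be undefined. */
static inline int clz32(uint32_t x)
{
    if (x == 0)
        return 32;
#if defined(__GNUC__) || defined(__clang__)
    return __builtin_clz(x);
#else
    /* Portable fallback: shift left until the top bit is set. */
    int n = 0;
    while ((x & 0x80000000u) == 0) {
        x <<= 1;
        ++n;
    }
    return n;
#endif
}
```

With this in place, the macro line would read index = clz32(u32); instead of calling __lzcnt directly.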

    Here's how it works. First, use pairwise min to set all u8 lanes of the uint8x8_t to the minimum value:

        uint8x8_t m = vpmin_u8(x, x);
        m = vpmin_u8(m, m);
        m = vpmin_u8(m, m);
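
To see why three passes are enough, here is a scalar model of the pairwise min that runs on any machine. It is an illustration only: pmin_u8_model and broadcast_min_model are made-up names, and the model assumes the usual vpmin_u8 lane layout (pair minima of the first operand in lanes 0..3, of the second operand in lanes 4..7).

```c
#include <stdint.h>

/* Scalar model of vpmin_u8(a, b): lanes 0..3 receive the minima of
   adjacent pairs of a, lanes 4..7 the minima of adjacent pairs of b. */
static void pmin_u8_model(const uint8_t a[8], const uint8_t b[8], uint8_t out[8])
{
    uint8_t tmp[8];
    for (int i = 0; i < 4; ++i) {
        tmp[i]     = a[2 * i] < a[2 * i + 1] ? a[2 * i] : a[2 * i + 1];
        tmp[i + 4] = b[2 * i] < b[2 * i + 1] ? b[2 * i] : b[2 * i + 1];
    }
    /* Copy through tmp so that out may alias a or b. */
    for (int i = 0; i < 8; ++i)
        out[i] = tmp[i];
}

/* Three passes with both operands equal: 8 candidates become 4, then 2,
   then 1, and the surviving minimum ends up in every lane. */
static void broadcast_min_model(const uint8_t x[8], uint8_t m[8])
{
    pmin_u8_model(x, x, m);
    pmin_u8_model(m, m, m);
    pmin_u8_model(m, m, m);
}
```

For the input {5, 3, 9, 7, 2, 8, 6, 4} the lanes go through {3, 7, 2, 4, 3, 7, 2, 4}, then {3, 2, 3, 2, 3, 2, 3, 2}, and finally all lanes hold 2.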
    

    Then use a vector compare to set every lane equal to the minimum to all ones, and all other lanes to zero:

        uint8x8_t r = vceq_u8(x, m);
    

    Then perform a bitwise AND with the mask, whose lanes hold 1<<7, 1<<6, 1<<5, ..., 1<<1, 1<<0, so that a matching lane i keeps only bit 7-i:

        uint8x8_t z = vand_u8(vmask, r);
    

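    After that, use pairwise add to sum all 8 bytes of z. Each lane holds a different bit, so the sum cannot overflow and every lane ends up with the same combined bit mask:

        z = vpadd_u8(z, z);
        z = vpadd_u8(z, z);
        z = vpadd_u8(z, z);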
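
The compare, AND, and pairwise-add steps can be checked together with a scalar model. This is a sketch under the assumption that vpadd_u8(z, z) places the adjacent-pair sums of z in both halves of the result; padd_self_model and min_lane_bits_model are hypothetical names used only here.

```c
#include <stdint.h>

/* Scalar model of z = vpadd_u8(z, z): the four adjacent-pair sums of z
   (modulo 256) appear in lanes 0..3 and again in lanes 4..7. */
static void padd_self_model(uint8_t z[8])
{
    uint8_t t[4];
    for (int i = 0; i < 4; ++i)
        t[i] = (uint8_t)(z[2 * i] + z[2 * i + 1]);
    for (int i = 0; i < 4; ++i)
        z[i] = z[i + 4] = t[i];
}

/* Models vceq_u8, vand_u8 and the three vpadd_u8 passes.
   Returns the byte that every lane holds afterwards. */
static uint8_t min_lane_bits_model(const uint8_t x[8], uint8_t min)
{
    static const uint8_t mask[8] = { 0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01 };
    uint8_t z[8];
    for (int i = 0; i < 8; ++i) {
        uint8_t r = (x[i] == min) ? 0xFF : 0x00;  /* vceq_u8: all ones on match */
        z[i] = mask[i] & r;                       /* vand_u8: keep bit 7-i      */
    }
    padd_self_model(z);
    padd_self_model(z);
    padd_self_model(z);
    return z[0];
}
```

With x = {5, 2, 9, 7, 2, 8, 6, 4} and a minimum of 2, lanes 1 and 4 match, so the result is 0x40 + 0x08 = 0x48. Because the bits are distinct, the sum is the same as a bitwise OR.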
    

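    Finally, use clz to calculate the position of the first min element. Every byte of z now holds the same bit mask, so the top byte of the 32-bit lane does too; because lane 0 maps to bit 7, the number of leading zeros equals the lowest lane index that holds the minimum:

        unsigned u32 = vget_lane_u32(vreinterpret_u32_u8(z), 0);
        index = __lzcnt(u32);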
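
The last step can also be modelled in scalar code. The sketch below assumes the byte-replicated layout described above (the same bit mask in all four bytes of the 32-bit lane); first_min_index_model is a hypothetical name, and the clz is written as a plain loop so the example does not depend on any compiler intrinsic.

```c
#include <stdint.h>

/* Hypothetical helper: given the combined bit mask (bit 7-i set when lane i
   holds the minimum), return the lowest such lane index, or -1 if no bit is set. */
static int first_min_index_model(uint8_t bits)
{
    if (bits == 0)
        return -1;  /* cannot happen in VMIN8: some lane always matches */

    /* All four bytes of the 32-bit lane hold the same mask. */
    uint32_t u32 = (uint32_t)bits * 0x01010101u;

    /* Portable count of leading zeros. */
    int n = 0;
    while ((u32 & 0x80000000u) == 0) {
        u32 <<= 1;
        ++n;
    }
    return n;
}
```

When the minimum occurs more than once, for example in lanes 1 and 4 (mask 0x48), the highest set bit wins and the result is 1, which matches the scalar loop in the question.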
    

    In real code I use VMIN8 multiple times per loop iteration, and the compiler is able to interleave the multiple VMIN8 expansions well enough to avoid data stalls.
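
One possible tidy-up, offered as a sketch and not as a measured improvement: the same steps written as a static inline function, so that vmask becomes an explicit parameter and the arguments are evaluated once. The names vmin8 and min8 are made up for this example, the NEON branch assumes GCC or Clang (for __builtin_clz), and the scalar branch exists only so the example still builds on targets without NEON.

```c
#include <stdint.h>

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>

/* Same steps as VMIN8, with vmask passed in explicitly. */
static inline void vmin8(uint8x8_t x, uint8x8_t vmask, int *index, int *value)
{
    uint8x8_t m = vpmin_u8(x, x);                 /* broadcast the minimum */
    m = vpmin_u8(m, m);
    m = vpmin_u8(m, m);

    uint8x8_t z = vand_u8(vmask, vceq_u8(x, m));  /* bit 7-i for matching lane i */
    z = vpadd_u8(z, z);                           /* sum the bits into every lane */
    z = vpadd_u8(z, z);
    z = vpadd_u8(z, z);

    uint32_t u32 = vget_lane_u32(vreinterpret_u32_u8(z), 0);
    *index = __builtin_clz(u32);                  /* u32 is never zero here */
    *value = vget_lane_u8(m, 0);
}

static void min8(const uint8_t v[8], int *index, int *value)
{
    static const uint8_t mask[8] = { 0x80, 0x40, 0x20, 0x10, 0x08, 0x04, 0x02, 0x01 };
    vmin8(vld1_u8(v), vld1_u8(mask), index, value);
}
#else
/* Scalar fallback: the loop from the question. */
static void min8(const uint8_t v[8], int *index, int *value)
{
    *value = 256;
    *index = -1;
    for (int i = 0; i < 8; ++i) {
        if (v[i] < *value) {
            *value = v[i];
            *index = i;
        }
    }
}
#endif
```

Whether the compiler schedules several inlined calls as well as it does the macro expansions would need to be checked in the generated code.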