Tags: c++, bit-manipulation, simd, arm64, neon

Reducing NEON vector with variable amounts of bits in each element into a single 32-bit value (concatenate variable-length bitfields)


I have the output of some computation which is in the form of two NEON uint8x16_t SIMD registers, one which contains some significant information in the lower N bits of each element, and a second register in which each element contains the value of N for each corresponding element in the other register. The remaining 8-N bits in the first register are zeroed, and the sum of N across the entire vector can also be guaranteed to be <= 32.

What I need to do is to take these bits and concatenate them into a single register. I can make changes to the code producing these results and have them store the bits in the upper N bits instead of the lower N, and I can also reverse the order of the elements in the register.

The only hard restriction that can't potentially be worked around is that there cannot be any padding bits between the concatenated bits, and the order of concatenation must be the same regardless of the values in either register.

The naive solution to this is of course to simply iterate over each element and accumulate:

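#include <arm_neon.h>
#include <cstddef>
#include <cstdint>

uint32_t bit_reduce_naive(uint8x16_t bits, uint8x16_t bit_counts)
{
    uint32_t result = 0;
    for (size_t i = 0; i != 16; ++i)
    {
        // element subscripts on NEON vector types are a GCC/Clang vector extension
        result <<= bit_counts[i];  // make room for this element's N bits
        result |= bits[i];         // append them (bits above N are zero)
    }
    return result;
}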
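For testing SIMD versions against this loop off-target, it can help to have the same algorithm on plain arrays. The sketch below is a hypothetical host-side reference: `Lanes8` and `bit_reduce_reference` are illustrative names, with `std::array<std::uint8_t, 16>` standing in for `uint8x16_t`. It assumes, as the question states, that each count is at most 8, that bits above each count are zero, and that the counts sum to at most 32.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side stand-in for uint8x16_t: one byte per lane.
using Lanes8 = std::array<std::uint8_t, 16>;

// Same loop as bit_reduce_naive, but on std::array so it builds on any host.
// Element 0 ends up in the most significant of the used bits and element 15
// in the least significant. Assumes counts[i] <= 8, bits above counts[i] are
// zero, and the counts sum to at most 32.
std::uint32_t bit_reduce_reference(const Lanes8& bits, const Lanes8& counts)
{
    std::uint32_t result = 0;
    for (std::size_t i = 0; i != bits.size(); ++i)
    {
        result <<= counts[i];  // make room for this element's bits
        result |= bits[i];     // append them at the bottom
    }
    return result;
}
```

For example, elements `0b1`, `0b10`, `0b101` with counts 1, 2, 3 (all other lanes zero) concatenate to `0b110101`, so `bit_reduce_reference(Lanes8{1, 2, 5}, Lanes8{1, 2, 3})` returns 53.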

After this, my next attempt was a more SIMD-oriented approach: iteratively shift the upper half of the elements left by the counts of the lower half, then OR the upper and lower halves together to merge them into a new vector with half as many elements, each twice as wide. This is the u8 -> u16 step; the others are the same but with wider element sizes:

#include <utility>

inline std::pair<uint16x8_t, uint16x8_t> reduce_step_u8(uint8x16_t bits, uint8x16_t bit_counts)
{
    const uint16x8_t expanded_count_lower = vmovl_u8(vget_low_u8(bit_counts));
    const uint16x8_t expanded_bits_lower  = vmovl_u8(vget_low_u8(bits));
    const uint16x8_t expanded_bits_higher = vmovl_high_u8(bits);
    return std::pair
    {
        // shift the upper 8 elements left by the counts of the lower 8, then OR;
        // vshlq_u16 takes a signed shift vector, hence the reinterpret
        vorrq_u16(expanded_bits_lower,
                  vshlq_u16(expanded_bits_higher, vreinterpretq_s16_u16(expanded_count_lower))),
        // combined bit count of each merged pair
        vaddl_u8(vget_low_u8(bit_counts), vget_high_u8(bit_counts)),
    };
}
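To check what one of these halving steps does lane by lane, here is a scalar model of it, generalised over lane width so the same template stands in for the u8 -> u16, u16 -> u32 and u32 -> u64 steps. This is a sketch, not NEON code: `reduce_step_model` is a hypothetical name, and `std::array` stands in for the vector registers. Lane `i` of the output combines input lanes `i` and `i + N/2`, mirroring `vmovl_u8`/`vmovl_high_u8`, `vshlq_u16`, `vorrq_u16` and `vaddl_u8` above.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>

// Scalar model (not NEON) of one halving step: output lane i combines input
// lanes i and i + N/2, with the upper-half lane shifted left by the lower
// lane's bit count. Wide is the output lane type (uint16_t, uint32_t or
// uint64_t) and must be given explicitly; Narrow and N are deduced.
template <typename Wide, typename Narrow, std::size_t N>
std::pair<std::array<Wide, N / 2>, std::array<Wide, N / 2>>
reduce_step_model(const std::array<Narrow, N>& bits, const std::array<Narrow, N>& counts)
{
    std::array<Wide, N / 2> out_bits{};
    std::array<Wide, N / 2> out_counts{};
    for (std::size_t i = 0; i != N / 2; ++i)
    {
        const Wide lo = bits[i];          // like vmovl_u8(vget_low_u8(bits))
        const Wide hi = bits[i + N / 2];  // like vmovl_high_u8(bits)
        out_bits[i]   = static_cast<Wide>(lo | (hi << counts[i]));              // vshlq + vorrq
        out_counts[i] = static_cast<Wide>(Wide{counts[i]} + counts[i + N / 2]); // vaddl
    }
    return {out_bits, out_counts};
}
```

Chaining it is a matter of calling it again on its own output, e.g. `reduce_step_model<std::uint32_t>(b16, c16)` after a `std::uint16_t` step.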

These steps can then be chained together to reduce the input to two uint64_t values, which can be dealt with in normal integer registers and combined into the returned value:

uint32_t bit_reduce_better(uint8x16_t bits, uint8x16_t bit_counts)
{
    // reduce_step_u16 and reduce_step_u32 are the wider-element versions of reduce_step_u8
    const auto [bits_u16, counts_u16] = reduce_step_u8(bits, bit_counts);
    const auto [bits_u32, counts_u32] = reduce_step_u16(bits_u16, counts_u16);
    const auto [bits_u64, counts_u64] = reduce_step_u32(bits_u32, counts_u32);

    // lane 0 holds the low group and lane 1 the high group; the high group is
    // shifted up by the number of bits in the low group
    const uint64_t low_bits       = vgetq_lane_u64(bits_u64, 0);
    const uint64_t high_bits      = vgetq_lane_u64(bits_u64, 1);
    const uint64_t low_bits_count = vgetq_lane_u64(counts_u64, 0);

    return static_cast<uint32_t>(low_bits | (high_bits << low_bits_count));
}

Note that this solution doesn't concatenate the bits in the same order as they appear in the vector, but that isn't a hard requirement for this algorithm (in-order output would simplify the code that consumes the return value, but I'll have to modify that code regardless).
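Working through the three pairing steps by hand (lane `i` with lane `i + 8`, then `i + 4`, then `i + 2`, then the final lane 0 / lane 1 combine) suggests the tree's order is the 4-bit bit-reversal of the element index, read from the least significant bit upwards: elements 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15. The sketch below makes that checkable on any host. It is a scalar model, not the NEON code: `Lanes8`, `bit_reduce_tree_model` and `concat_bit_reversed_order` are hypothetical names, and it assumes the final combine takes `high_bits` from lane 1 of `bits_u64` and `low_bits_count` from lane 0 of `counts_u64`.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical host-side stand-in for uint8x16_t.
using Lanes8 = std::array<std::uint8_t, 16>;

// Scalar model of bit_reduce_better: three halving steps (16 -> 8 -> 4 -> 2
// lanes), each pairing lane i with lane i + half, then the scalar combine of
// the last two lanes. Every lane is held in a uint64_t for simplicity.
std::uint32_t bit_reduce_tree_model(const Lanes8& bits_in, const Lanes8& counts_in)
{
    std::array<std::uint64_t, 16> bits{};
    std::array<std::uint64_t, 16> counts{};
    for (std::size_t i = 0; i != 16; ++i)
    {
        bits[i] = bits_in[i];
        counts[i] = counts_in[i];
    }
    for (std::size_t half = 8; half != 1; half /= 2)  // the u8, u16 and u32 steps
    {
        for (std::size_t i = 0; i != half; ++i)
        {
            bits[i] |= bits[i + half] << counts[i];  // upper lane goes above lower lane
            counts[i] += counts[i + half];
        }
    }
    // Final combine: lane 1 above lane 0, shifted by lane 0's count.
    return static_cast<std::uint32_t>(bits[0] | (bits[1] << counts[0]));
}

// Concatenation from the least significant bit upwards, visiting elements in
// 4-bit bit-reversed index order: 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, ...
std::uint32_t concat_bit_reversed_order(const Lanes8& bits, const Lanes8& counts)
{
    std::uint64_t result = 0;  // 64-bit so a shift by 32 on trailing empty lanes is defined
    unsigned used = 0;
    for (unsigned pos = 0; pos != 16; ++pos)
    {
        const unsigned idx = ((pos & 1u) << 3) | ((pos & 2u) << 1) |
                             ((pos & 4u) >> 1) | ((pos & 8u) >> 3);
        result |= std::uint64_t{bits[idx]} << used;
        used += counts[idx];
    }
    return static_cast<std::uint32_t>(result);
}
```

As a small example, elements `0b1` (1 bit) and `0b10` (2 bits) in lanes 0 and 1 come out of the tree as `0b101` (element 0 lowest), whereas `bit_reduce_naive` produces `0b110` (element 0 highest).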

At this point the solution is sufficiently fast and compact for my purposes, but I thought it might be worth asking in case anyone has a better solution, since a faster version, if one exists, would be appreciated.


Solution

  • Using the method suggested by @Peter Cordes, I was able to marginally improve on the speed of my original code while also making it concatenate the bits in vector order. At each step, the lower element of each adjacent pair is selectively shifted left so that the bits of the two elements meet in the centre of the pair; the pair is then treated as a single element of twice the width, and the counts are combined with pairwise addition.

    Manipulating the offsets so that in the next step the lower elements would be shifted left and the upper elements right (so that they would again meet in the middle) turned out to be more expensive than simply negating the lower element's offset and shifting the concatenated bits right to reset them for the next step. I found that process fairly confusing to work out, so there may be room for improvement in that direction.

    Regardless, the improved method is very slightly faster (~0.2 ns in microbenchmarks on an Apple M2), but it takes 2 extra instructions in the generated assembly. It also has the advantage of concatenating the bits in order, which the original method does not.

    Improved per-step code below (the main reduce method is the same):

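    inline std::pair<uint16x8_t, uint16x8_t> reduce_step_u8(uint8x16_t bits, uint8x16_t bit_counts)
    {
        // 0xFF in the even (low) byte of each 16-bit lane, 0x00 in the odd (high) byte
        const uint8x16_t mask = vzip1q_u8(vdupq_n_u8(0xFF), vdupq_n_u8(0x00));
        const uint8x16_t shift = vsubq_u8(vdupq_n_u8(8), bit_counts);
        // vshlq_u16 only uses the least significant byte of each 16-bit shift
        // element, as a signed amount. Negating the 16-bit lane leaves that byte
        // equal to -(8 - count of the low element), so the non-zero upper byte
        // of each 16-bit shift value does not matter
        const int16x8_t  reverse_shift = vnegq_s16(vreinterpretq_s16_u8(shift));

        // shift each low (even) element up so its bits end at bit 7;
        // high (odd) elements get a shift of 0
        const uint8x16_t shifted_bits = vshlq_u8(bits, vreinterpretq_s8_u8(vandq_u8(shift, mask)));

        return std::pair
        {
            // as 16-bit lanes the two fields now meet in the middle; shift them back down
            vshlq_u16(vreinterpretq_u16_u8(shifted_bits), reverse_shift),
            vpaddlq_u8(bit_counts),
        };
    }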
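The "meet in the middle" step can be sanity-checked with a scalar model. The sketch below is not NEON code: `reduce_step_pairwise_model` is a hypothetical name, it assumes little-endian lane layout (as on arm64), and it assumes the u16 -> u32 and u32 -> u64 steps, which aren't shown, follow the same pattern with `16 - count` and `32 - count` as the shift. For each adjacent pair, the low lane is shifted left by `lane_bits - count` (discarding bits that leave the narrow lane, like `vshlq_u8`), the pair is read as one wide lane, and the result is shifted right by the same amount.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>

// Scalar model (not NEON) of the pairwise "meet in the middle" step.
// Adjacent lanes 2k (low) and 2k+1 (high) merge into one lane of twice the
// width, so element order is preserved. Wide is given explicitly; Narrow and
// N are deduced. Assumes counts <= lane width and bits above each count are zero.
template <typename Wide, typename Narrow, std::size_t N>
std::pair<std::array<Wide, N / 2>, std::array<Wide, N / 2>>
reduce_step_pairwise_model(const std::array<Narrow, N>& bits,
                           const std::array<Narrow, N>& counts)
{
    constexpr unsigned lane_bits = 8 * sizeof(Narrow);
    std::array<Wide, N / 2> out_bits{};
    std::array<Wide, N / 2> out_counts{};
    for (std::size_t k = 0; k != N / 2; ++k)
    {
        // "lane_bits - bit_counts", only applied to the low lane of each pair
        const unsigned shift = lane_bits - counts[2 * k];
        // low lane: bits moved to the top of the narrow lane; anything shifted
        // past the lane width is dropped, as with the NEON per-lane shift
        const Narrow lo = static_cast<Narrow>(Wide{bits[2 * k]} << shift);
        const Wide hi = bits[2 * k + 1];  // high lane is not shifted
        // reading the pair as one wide lane: the two fields touch at bit lane_bits
        const Wide joined = static_cast<Wide>((hi << lane_bits) | lo);
        // shift right by the same amount to bring them back down to bit 0
        out_bits[k] = static_cast<Wide>(joined >> shift);
        out_counts[k] = static_cast<Wide>(Wide{counts[2 * k]} + counts[2 * k + 1]);  // pairwise add
    }
    return {out_bits, out_counts};
}
```

Each output lane ends up as `bits[2k] | bits[2k + 1] << counts[2k]`, so after three steps and a final `lane0 | lane1 << count0`, element 0 occupies the lowest bits and element 15 the highest.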