Search code examples
Tags: count, comparison, sse, intrinsics, sse2

SSE2 intrinsics: comparing two __m128i values containing four int32s each to count how many are equal


I'm diving into SSE2 intrinsics for the first time and I'm not sure how to do this.

I want to compare four int32s to four other int32s and count how many are equal. So I read my first four int32s into a __m128i, do the same for the second set, and use _mm_cmpeq_epi32 for the comparison.

This should result in a __m128i containing 4 int32's, each one either 0xffffffff or 0 depending on whether the ints were equal.

But I have no idea how to get from that resulting __m128i to a count specifying how many were actually equal.

Can anyone point me in the right direction?

Here is the code I have pieced together so far:

        int* source = blah;
        int* reference = otherblah;

        // Load the 4 source int32's (they are actually 4 int32s apart)
        __m128i first_4_int32s = _mm_set_epi32(*(source + 12), *(source + 8), *(source + 4), *(source));

        // Load the 4 reference int32's (also actually 4 int32s apart)
        __m128i second_4_int32s = _mm_set_epi32(*(reference + 12), *(reference + 8), *(reference + 4), *(reference));

        // Compare the int32's: each lane of result becomes 0xFFFFFFFF where the
        // two ints were equal and 0 where they differed
        __m128i result = _mm_cmpeq_epi32(first_4_int32s, second_4_int32s);

        // Still to do: count how many lanes of result are equal (0, 1, 2, 3 or all 4)

Solution

  • First turn each lane of the compare result (all ones for equal, all zeros otherwise) into 0 or 1. You can AND it with a vector of ones, or logically shift each 32-bit lane right by 31. Then use a horizontal add operation to sum the ones. Here are some possibilities.

    #include <stdio.h>
    #include <stdint.h>
    #if defined(_MSC_VER)
    #include <intrin.h>
    #else
    #include <immintrin.h>
    #endif
    
    //----------------------------------------------------------------------------
    // non-SSE method (reference for result check)
    //
    // Counts the 32-bit lanes of `value` that are all ones (0xFFFFFFFF), i.e.
    // the lanes that _mm_cmpeq_epi32 reported as equal.  The vector is copied
    // out with _mm_storeu_si128 instead of being read through a casted pointer,
    // which avoids a strict-aliasing violation and makes no alignment assumption.
    static int method0 (__m128i value)
        {
        uint32_t lanes [4];
        int index, total = 0;

        _mm_storeu_si128 ((__m128i *) lanes, value);
        for (index = 0; index < 4; index++)
            total += lanes [index] == 0xFFFFFFFF;
        return total;
        }
    
    //----------------------------------------------------------------------------
    //
    // horizontalAddBytes - return integer total of all 16 bytes in xmm argument
    //
    // _mm_sad_epu8 against zero leaves the sum of bytes 0-7 in the low bits of
    // 32-bit lane 0 and the sum of bytes 8-15 in the low bits of lane 2 (lanes
    // 1 and 3 are zero).  Shuffle 0xAA broadcasts lane 2 so that one 32-bit add
    // puts the full total in lane 0, which _mm_cvtsi128_si32 extracts.
    // Extracting 64 bits instead (_mm_cvtsi128_si64) would also pick up lane 1,
    // which holds the upper-half sum after the add, and is only available on x64.
    //
    static int horizontalAddBytes (__m128i byteArray)
        {
        const __m128i zero = _mm_setzero_si128 ();
        const __m128i sums = _mm_sad_epu8 (byteArray, zero);

        return _mm_cvtsi128_si32 (_mm_add_epi32 (sums, _mm_shuffle_epi32 (sums, 0xAA)));
        }
    
    //----------------------------------------------------------------------------
    // requires SSE2
    //
    // An equal lane is 0xFFFFFFFF, so a logical right shift by 31 turns it into
    // 1 while unequal lanes stay 0; the byte total of the vector is the count.
    static int method1 (__m128i value)
        {
        const __m128i zeroOrOne = _mm_srli_epi32 (value, 31);
        return horizontalAddBytes (zeroOrOne);
        }
    
    //----------------------------------------------------------------------------
    // requires SSE3
    //
    // ANDing each lane with the bit pattern of 1.0f yields 1.0f for equal lanes
    // (all ones) and 0.0f otherwise; two horizontal adds then leave the sum of
    // all four lanes in lane 0, which is converted back to an int.
    static int method2 (__m128i value)
        {
        const __m128 oneF = _mm_set1_ps (1);
        __m128 sum = _mm_and_ps (_mm_castsi128_ps (value), oneF);

        sum = _mm_hadd_ps (sum, sum);   // [a+b, c+d, a+b, c+d]
        sum = _mm_hadd_ps (sum, sum);   // [a+b+c+d, ...]
        return _mm_cvtss_si32 (sum);
        }
    
    //----------------------------------------------------------------------------
    // requires SSSE3
    //
    // Shift each lane right by 31 to get 0 or 1, then two integer horizontal
    // adds accumulate all four lanes into lane 0.
    static int method3 (__m128i value)
        {
        __m128i ones = _mm_srli_epi32 (value, 31);
        __m128i pairSums = _mm_hadd_epi32 (ones, ones);
        __m128i total = _mm_hadd_epi32 (pairSums, pairSums);
        return _mm_cvtsi128_si32 (total);
        }
    
    //----------------------------------------------------------------------------
    
    // Write bit `lane` (0..3) of mask into data[lane * 4] as 1 (set) or 0
    // (clear); every other element of data is left untouched.
    static void createTestData (uint32_t *data, int mask)
        {
        int lane;
        for (lane = 0; lane < 4; lane++)
            {
            data [lane * 4] = (mask & (1 << lane)) ? 1u : 0u;
            }
        }
    
    //----------------------------------------------------------------------------
    
    // Self-check: for every pair of 4-bit test patterns (16 x 16 cases), build
    // the source/reference vectors, compare them, and check method1..method3
    // against the scalar reference method0.  Each mismatch is printed, and the
    // exit status is nonzero if any mismatch occurred so scripts can detect it.
    int main (void)
        {
        int index1, index2, expected, result1, result2, result3;
        int failures = 0;
        uint32_t source [16];
        uint32_t reference [16];

        for (index1 = 0; index1 < 16; index1++)
            for (index2 = 0; index2 < 16; index2++)
                {
                __m128i first_4_int32s, second_4_int32s, result;
                createTestData (source, index1);
                createTestData (reference, index2);

                // Load the 4 source int32's (they are actually 4 int32s apart)
                first_4_int32s = _mm_set_epi32(source [12], source [8], source [4], source [0]);

                // Load the 4 reference int32's (also actually 4 int32s apart)
                second_4_int32s = _mm_set_epi32(reference [12], reference [8], reference [4], reference [0]);

                // Compare the int32's: each lane becomes 0xFFFFFFFF if equal, 0 otherwise
                result = _mm_cmpeq_epi32(first_4_int32s, second_4_int32s);

                expected = method0 (result);
                result1 = method1 (result);
                result2 = method2 (result);
                result3 = method3 (result);
                if (result1 != expected)
                    {
                    printf ("method1, index %d,%d expected %d, actual %d\n", index1, index2, expected, result1);
                    failures++;
                    }
                if (result2 != expected)
                    {
                    printf ("method2, index %d,%d expected %d, actual %d\n", index1, index2, expected, result2);
                    failures++;
                    }
                if (result3 != expected)
                    {
                    printf ("method3, index %d,%d expected %d, actual %d\n", index1, index2, expected, result3);
                    failures++;
                    }
                }

        return failures ? 1 : 0;
        }
    
    //----------------------------------------------------------------------------