Search code examples
Tags: c++, matrix, sse, avx, avx2

Find largest element in matrix and its column and row indexes using SSE and AVX


I need to find the largest element in a matrix that is stored as a 1-D array, together with its row and column indexes.

Because the matrix is stored as a 1-D array, I only need the linear index of the maximum element; the row and column are then easy to derive from it.

My problem is that I cannot get that index.

I have a working function that finds largest element and uses SSE, here it is:

/*
 * Returns the largest value among the dims * dims floats stored contiguously
 * at m (a dims x dims matrix kept as a 1-D array).
 *
 * m    - pointer to the first element; must not be NULL.
 * dims - matrix dimension; must be at least 1.
 *
 * Full groups of four elements are processed with SSE; any leftover elements
 * (when dims * dims is not a multiple of 4, or is smaller than 4) are handled
 * by a scalar loop so that no load ever reads past the end of the matrix.
 */
float find_largest_element_in_matrix_SSE(float* m, unsigned const int dims)
{
    const size_t count = (size_t)dims * dims;
    const size_t vec_end = count - (count % 4);   /* elements covered by whole 4-float groups */
    size_t i;
    float result;

    if (vec_end >= 4)
    {
        float lanes[4];
        __m128 max_el = _mm_loadu_ps(m);

        for (i = 4; i < vec_end; i += 4)
        {
            max_el = _mm_max_ps(max_el, _mm_loadu_ps(m + i));
        }

        /* Unaligned store: no compiler-specific alignment attribute needed. */
        _mm_storeu_ps(lanes, max_el);

        /* Each lane holds the maximum of every 4th element; reduce the lanes. */
        result = lanes[0];
        for (i = 1; i < 4; ++i)
        {
            if (lanes[i] > result)
            {
                result = lanes[i];
            }
        }
        i = vec_end;
    }
    else
    {
        /* Fewer than four elements: nothing to vectorise. */
        result = m[0];
        i = 1;
    }

    /* Scalar tail for the elements that did not fill a whole vector. */
    for (; i < count; ++i)
    {
        if (m[i] > result)
        {
            result = m[i];
        }
    }

    return result;
}

and also I have a non-working function that uses AVX:

/*
 * Attempted AVX version of the maximum search (the asker reports it as not
 * working). It takes the lane-wise maximum over groups of 8 floats from m and
 * then tries to reduce the 8 lanes to a single value.
 *
 * NOTE(review): as written, the function returns m[0] - the first input
 * element - rather than anything derived from the reduction below.
 */
float find_largest_element_in_matrix_AVX(float* m, unsigned const int dims)
{
    size_t i;
    /* NOTE(review): index is initialised but never used in this function. */
    int index = -1;
    /* Seed the running maximum with the first 8 elements. */
    __m256 max_el = _mm256_loadu_ps(m);
    __m256 curr;

    /*
     * Lane-wise maximum over the remaining elements, 8 at a time.
     * NOTE(review): every iteration loads a full 8 floats, so when dims * dims
     * is not a multiple of 8 the last load extends past element dims * dims - 1.
     */
    for (i = 8; i < dims * dims; i += 8)
    {
        curr = _mm256_loadu_ps(m + i);
        max_el = _mm256_max_ps(max_el, curr);
    }

    /* NOTE(review): max_v is filled here but never read afterwards. */
    __declspec(align(32))float max_v[8] = { 0 };
    _mm256_store_ps(max_v, max_el);

    /* Swap the two 128-bit halves, so the next max combines element k with element k + 4. */
    __m256 y = _mm256_permute2f128_ps(max_el, max_el, 1);
    /*
     * NOTE(review): the text after the first ';' on the next line is not valid
     * code (no terminating ';', and it subscripts __m256 values); it reads like
     * an explanatory comment that lost its comment marker.
     */
    __m256 m1 = _mm256_max_ps(max_el, y);m1[1] = max(max_el[1], max_el[3])
    /*
     * NOTE(review): with immediate 5, each 128-bit half of m2 is
     * { m1[1], m1[1], m1[0], m1[0] }, so m_res holds max(m1[0], m1[1]), m1[1],
     * max(m1[2], m1[0]) and max(m1[3], m1[0]) - no single element ends up
     * holding the maximum of all lanes.
     */
    __m256 m2 = _mm256_permute_ps(m1, 5); 
    __m256 m_res = _mm256_max_ps(m1, m2); 

    /* NOTE(review): m_res is discarded; the first matrix element is returned instead. */
    return m[0];
}

Could anyone help me with actually finding the index of the max element and with making my AVX version work?


Solution

  • Here's a working SSE (SSE 4) implementation that returns the max val and corresponding index, along with a scalar reference implementation and test harness:

    #include <stdio.h>
    #include <stdint.h>
    #include <stdlib.h>
    #include <time.h>
    #include <smmintrin.h>  // SSE 4.1
    
    /*
     * Scalar reference implementation.
     *
     * Scans the dims * dims floats at m and returns the largest value. The
     * linear index of its first occurrence is written to *maxIndex.
     */
    float find_largest_element_in_matrix_ref(const float* m, int dims, int *maxIndex)
    {
        const int count = dims * dims;
        int best = 0;
        int pos;

        /* Strict '>' keeps the earliest position when the maximum repeats. */
        for (pos = 1; pos < count; ++pos)
        {
            if (m[pos] > m[best])
            {
                best = pos;
            }
        }

        *maxIndex = best;
        return m[best];
    }
    
    /*
     * SSE 4.1 implementation: returns the largest of the dims * dims floats at
     * m and writes the linear index of its first occurrence to *maxIndex, the
     * same result as the scalar reference.
     *
     * m        - matrix stored as a 1-D array; must not be NULL.
     * dims     - matrix dimension; must be at least 1.
     * maxIndex - receives the index of the maximum; must not be NULL.
     *
     * Whole groups of four elements are processed with SSE, tracking a running
     * maximum and its index per lane. Elements that do not fill a whole vector
     * are handled by a scalar loop, so no load reads past the end of the matrix.
     */
    float find_largest_element_in_matrix_SSE(const float* m, int dims, int *maxIndex)
    {
        const int n = dims * dims;
        const int nVec = n & ~3;    /* elements covered by whole 4-float groups */
        float maxVal = m[0];
        int i;

        *maxIndex = 0;

        if (nVec >= 4)
        {
            float aMaxVal[4];
            int32_t aMaxIndex[4];
            const __m128i vIndexInc = _mm_set1_epi32(4);
            __m128i vMaxIndex = _mm_setr_epi32(0, 1, 2, 3);
            __m128i vIndex = vMaxIndex;
            __m128 vMaxVal = _mm_loadu_ps(m);

            for (i = 4; i < nVec; i += 4)
            {
                __m128 v = _mm_loadu_ps(&m[i]);
                /* Lanes where the new value is strictly greater take the new index. */
                __m128 vcmp = _mm_cmpgt_ps(v, vMaxVal);
                vIndex = _mm_add_epi32(vIndex, vIndexInc);
                vMaxVal = _mm_max_ps(vMaxVal, v);
                vMaxIndex = _mm_blendv_epi8(vMaxIndex, vIndex, _mm_castps_si128(vcmp));
            }
            _mm_storeu_ps(aMaxVal, vMaxVal);
            _mm_storeu_si128((__m128i *)aMaxIndex, vMaxIndex);

            /*
             * Reduce the four lanes. On equal values prefer the smaller index:
             * a lower-numbered lane may have found its maximum later in the
             * array than a higher-numbered lane did.
             */
            maxVal = aMaxVal[0];
            *maxIndex = aMaxIndex[0];
            for (i = 1; i < 4; ++i)
            {
                if (aMaxVal[i] > maxVal ||
                    (aMaxVal[i] == maxVal && aMaxIndex[i] < *maxIndex))
                {
                    maxVal = aMaxVal[i];
                    *maxIndex = aMaxIndex[i];
                }
            }
        }

        /*
         * Scalar tail. These elements all come after the vectorised part, so a
         * strict '>' preserves the first-occurrence rule. When nothing was
         * vectorised, element 0 is already the current maximum.
         */
        for (i = (nVec >= 4) ? nVec : 1; i < n; ++i)
        {
            if (m[i] > maxVal)
            {
                maxVal = m[i];
                *maxIndex = i;
            }
        }
        return maxVal;
    }
    
    /*
     * Test harness: fills a dims x dims matrix with pseudo-random values in
     * [0, 1], runs the scalar reference and the SSE implementation on it, and
     * prints PASS when both report the same maximum and the same index, FAIL
     * otherwise. Returns 1 only if the matrix cannot be allocated.
     */
    int main()
    {
        const int dims = 1024;
        /*
         * The matrix is 4 MiB of floats; keep it on the heap, since an
         * automatic array of that size can overflow smaller default stacks.
         */
        float *m = malloc((size_t)dims * dims * sizeof *m);
        float maxVal_ref, maxVal_SSE;
        int maxIndex_ref = -1, maxIndex_SSE = -1;
        int i;

        if (m == NULL)
        {
            fprintf(stderr, "failed to allocate %d x %d matrix\n", dims, dims);
            return 1;
        }

        srand((unsigned)time(NULL));

        for (i = 0; i < dims * dims; ++i)
        {
            m[i] = (float)rand() / RAND_MAX;
        }

        maxVal_ref = find_largest_element_in_matrix_ref(m, dims, &maxIndex_ref);
        maxVal_SSE = find_largest_element_in_matrix_SSE(m, dims, &maxIndex_SSE);

        if (maxVal_ref == maxVal_SSE && maxIndex_ref == maxIndex_SSE)
        {
            printf("PASS: maxVal = %f, maxIndex = %d\n",
                          maxVal_ref, maxIndex_ref);
        }
        else
        {
            printf("FAIL: maxVal_ref = %f, maxVal_SSE = %f, maxIndex_ref = %d, maxIndex_SSE = %d\n",
                          maxVal_ref, maxVal_SSE, maxIndex_ref, maxIndex_SSE);
        }

        free(m);
        return 0;
    }
    

    Compile and run:

    $ gcc -Wall -msse4 Yakovenko.c && ./a.out 
    PASS: maxVal = 0.999999, maxIndex = 120409
    

    Obviously you can get the row and column indices if needed:

    int rowIndex = maxIndex / dims;
    int colIndex = maxIndex % dims;
    

    From here it should be fairly straightforward to write an AVX2 implementation.