Search code examples
c, assembly, simd, intrinsics, avx

How to convert this assembly code to intrinsic code?


The inline assembly below looks like something that could be written with intrinsics; however, I am not familiar with intrinsic functions. Please help me convert it to real intrinsic code. testFunc() in particular is unclear to me: I guess it also computes the dot product of two float vectors, but the labels Lrep and Lexit confuse me. Please explain clearly what it does. Also, are intrinsics available for mobile processors?

// Single-precision matrix product C = A * B on row-major arrays.
//
//   A : M x K input  (row i starts at A + i*K)
//   B : K x N input  (read sequentially, K rows of N floats)
//   C : M x N output (every element is overwritten, in row-major order)
//
// B is first transposed into a temporary N x K buffer so that each output
// element becomes a dot product of two contiguous K-float vectors: row i of A
// and row k of the transposed B. The dot product is computed 8 floats at a
// time with AVX/FMA inline assembly; the last K % 8 products are added in
// scalar code.
//
// The __asm block keeps pointers in 32-bit registers, so this only builds as
// 32-bit x86 code with a compiler that accepts MSVC-style inline assembly,
// and it needs a CPU with AVX and FMA (vfmadd231ps).
void testFunc(int M, int N, int K, float* A, float* B, float* C)
{
    float *a;                       // current row of A
    float *b = new float[K*N];      // transposed copy of B; released before returning
    float *pointb = B;              // sequential read cursor over B
    float *bb;                      // write cursor into b, later the current row of b
    float *answer = C;              // sequential write cursor over C
    float c[8];                     // the 8 partial sums stored from ymm3

    // Transpose: b[k*K + j] = B[j*N + k].
    for (int j = 0; j < K; j++) {
        bb = b + j;
        for (int k = 0; k < N; k++) {
            *bb = *pointb++; bb += K;
        }
    }

    // Largest multiple of 8 that is <= K: the part handled by the vector loop.
    int K8 = K / 8 * 8;

    for (int i = 0; i < M; i++) for (int k = 0; k < N; k++) {
        a = A + i * K;
        bb = b + k * K;
        __asm {
            mov             esi, K8;    // esi = byte offset of the last full 8-float group:
            sub             esi, 8;     //       (K8 - 8) * sizeof(float); this is -32 when
            shl             esi, 2;     //       K8 == 0, so the loop body is then skipped
            xor             edi, edi;   // edi = current byte offset, starts at 0
            mov             edx, a;     // edx = row of A
            mov             ebx, bb;    // ebx = row of transposed B
            vxorps          ymm3, ymm3, ymm3;   // ymm3 = 8 partial sums, all zero
        Lrep:                           // loop head: while (edi <= esi), signed compare
            cmp             edi, esi;
            jg              Lexit;
            vmovups         ymm0, ymmword ptr[edx + edi];           // 8 floats of a
            vfmadd231ps     ymm3, ymm0, ymmword ptr[ebx + edi];     // ymm3 += a[..] * bb[..]
            add             edi, 32;    // advance by 8 floats
            jmp             Lrep;
        Lexit:                          // loop exit: spill the 8 partial sums to c[]
            vmovups         ymmword ptr[c], ymm3;
        }

        // Scalar tail: the remaining K - K8 products are folded into c[0].
        for (int j = K8; j < K; j++) {
            *c += *(a + j) * *(bb + j);
        }

        // Horizontal sum of the partial sums gives the dot product.
        *answer = (c[0] + c[1] + c[2] + c[3] + c[4] + c[5] + c[6] + c[7]);
        answer++;
    }

    // The original never freed the scratch buffer, leaking K*N floats per call.
    delete[] b;
}

and

pA = A;
// Accumulating matrix product on row-major arrays: for every k, adds
// A[i][k] * (row k of B) into row i of C. Results are added to whatever C
// already holds. A is indexed as M x K, B as K x N and C as M x N.
// The _asm block keeps pointers in 32-bit registers (MSVC-style inline
// assembly, 32-bit x86) and uses AVX/FMA instructions.
for (k = 0; k < K; k++) {
    pC = C;                     // restart at C[0][0]; pC then walks all of C row by row
    for (i = 0; i < M; i++) {
        pA = A + i * K + k;     // the scalar A[i][k]
        pB = B + k * N;         // start of row k of B

        // Vector part: N / 32 groups of four 8-float steps, written as a
        // single loop that runs four steps per group.
        for (j = N / 32 * 4; j > 0; j--) {
            _asm {
                mov             eax, pC
                mov             ebx, pA
                mov             ecx, pB
                vmovups         ymm0, ymmword ptr [eax]             // 8 floats of C
                vbroadcastss    ymm4, dword ptr [ebx]               // A[i][k] in all 8 lanes
                vfmadd231ps     ymm0, ymm4, ymmword ptr [ecx]       // C += A[i][k] * B[k][..]
                vmovups         ymmword ptr [eax], ymm0             // write the 8 floats back
            }
            pC += 8;
            pB += 8;
        }

        // Scalar tail: the N % 32 columns the vector part did not cover.
        for (j = N / 32 * 32; j < N; j++) {
            *pC += *pA * *pB;
            ++pC;
            ++pB;
        }
    }
}

Solution

  • In intrinsics, the unrolled loop body of the second snippet is this code repeated 4 times.

    {
        // One 8-float step: C[0..7] += (*pA) * B[0..7].
        // Each intrinsic is annotated with the instruction(s) it replaces.

        // vmovups ymm2, ymmword ptr[ecx]     -> unaligned load of 8 floats of B
        const __m256 rowB = _mm256_loadu_ps(pB);

        // vmovss xmm1, dword ptr[ebx]
        // vbroadcastss ymm4, xmm1            -> the scalar *pA copied into all 8 lanes
        const __m256 scalarA = _mm256_set1_ps(*pA);

        // vmovups ymm0, ymmword ptr[eax]     -> unaligned load of 8 floats of C
        const __m256 accC = _mm256_loadu_ps(pC);

        // vfmadd231ps ymm0, ymm4, ymm2       -> scalarA * rowB + accC
        // vmovups ymmword ptr[eax], ymm0     -> unaligned store back to C
        _mm256_storeu_ps(pC, _mm256_fmadd_ps(scalarA, rowB, accC));
    }

    // Advance both cursors by the 8 floats just processed.
    pC += 8;
    pB += 8;
    

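    Re-broadcasting the same value from pA in every repetition is redundant, though: pA does not change inside the loop over j, so the broadcast can be done once before that loop and reused.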