Search code examples
Tags: c, assembly, sse, intrinsics, avx512

Fastest way to calculate a digit-sum for a large number (as a decimal string)


I use GMP (gmplib) to get a big number as a decimal string, and I calculate its numeric value: the sum of its digits, repeated until a single digit remains (123 -> 6, 74 -> 11 -> 2).

Here is what I did:

/*
 * Return the digital root (digit sum, repeated until one digit remains) of a
 * decimal number string, e.g. "123" -> 6 and "74" -> 11 -> 2.
 *
 * in_str: NUL-terminated string of ASCII digits, optionally preceded by one
 *         '-' (mpz_get_str emits it for negative values); the sign is
 *         ignored. Must not be NULL.
 * Returns 0 for an empty or all-zero string, otherwise a value in 1..9.
 */
unsigned short getnumericvalue(const char *in_str)
{
    /* unsigned long long: unsigned long is only 32 bits on some platforms,
       and a wrapped sum would no longer be correct modulo 9. */
    unsigned long long number = 0;
    const char *ptr = in_str;

    if (*ptr == '-')
        ptr++;

    /* Test before reading the digit, so an empty string never steps past
       its terminator. No special case for '9': it contributes 0 mod 9
       anyway, and a branch-free loop is easier for the compiler to
       vectorise. */
    for (; *ptr != '\0'; ptr++)
        number += (unsigned long long)(*ptr - '0');

    /* Digital root: 0 only for 0, otherwise 1 + (n - 1) mod 9, so that a
       non-zero multiple of 9 yields 9 while "0" yields 0. */
    if (number == 0)
        return 0;
    return (unsigned short)(1 + (number - 1) % 9);
}

It works well, but is there a faster way to do it on a Xeon W-3235?


Solution

  • You can use code like the following (x86-64 GNU assembler, AT&T syntax; it requires AVX-512F and AVX-512BW). The general idea of the algorithm is:

    1. process data bytewise until we reach cacheline alignment
    2. read one cacheline at a time, check for end of string, and add digits to 8 accumulators
    3. reduce the 8 accumulators to one and add the sum from the scalar head
    4. process remainder bytewise

    Note that the code below has not been tested.

            // unsigned short getnumericvalue(const char *str)
            //
            // Digital root (digit sum, repeated until one digit remains) of a
            // NUL-terminated decimal string. One optional leading '-' is
            // skipped. Returns 0 for an empty or all-zero string, else 1..9.
            //
            // ABI:      SysV AMD64, leaf function, no stack use
            // In:       rdi = str (must not be NULL)
            // Out:      eax = digital root
            // Clobbers: rcx, rdx, rdi, zmm0-zmm3, k0, flags
            // Requires: AVX-512F and AVX-512BW (vptestmb, kortestq)
            // The // comments assume the file goes through the C
            // preprocessor (e.g. it is assembled as a .S file).
            //
            // The vector loop loads whole 64-byte aligned blocks, so it may
            // read up to 63 bytes past the NUL. An aligned 64-byte block never
            // crosses a page boundary, so those reads cannot fault.
            .section .text
            .type getnumericvalue, @function
            .globl getnumericvalue
    getnumericvalue:
            xor %eax, %eax                  // rax = running digit sum

            // skip an optional minus sign; the digit sum ignores the sign
            cmpb $'-', (%rdi)
            jne .Lhead_check
            inc %rdi

            // scalar head: process bytes until ptr is 64-byte aligned
    .Lhead_check:
            test $64-1, %dil                // is ptr aligned to 64 bytes?
            jz .Lvector
    .Lhead:
            movzbl (%rdi), %edx             // load a byte from the string
            inc %rdi                        // advance pointer
            test %edx, %edx                 // is this the NUL byte?
            jz .Lend                        // if yes, finish
            sub $'0', %edx                  // ASCII character -> digit
            add %edx, %eax                  // head is < 64 bytes: 32 bits suffice
            test $64-1, %dil                // aligned yet?
            jnz .Lhead

            // vector body: one aligned cache line per iteration, until a
            // line containing the NUL byte is found
    .Lvector:
            vpbroadcastd zero(%rip), %zmm1  // zmm1 = 64 x '0'
            vpxor %xmm3, %xmm3, %xmm3       // zmm3 = 8 qword accumulators (VEX write zeroes the upper bits)

            vmovdqa32 (%rdi), %zmm0         // load one cache line
            vptestmb %zmm0, %zmm0, %k0      // k0 bit i = (byte i != 0)
            kortestq %k0, %k0               // CF = 1 iff no byte is NUL
            jnc .Lreduce                    // NUL in the first line: skip loop

            .balign 16
    .Lloop:
            add $64, %rdi                   // advance pointer
            vpsadbw %zmm1, %zmm0, %zmm0     // per 8 bytes: sum of |c - '0'| -> one qword
            vpaddq %zmm3, %zmm0, %zmm3      // accumulate
            vmovdqa32 (%rdi), %zmm0         // load the next cache line
            vptestmb %zmm0, %zmm0, %k0      // k0 bit i = (byte i != 0)
            kortestq %k0, %k0               // CF = 1 iff no byte is NUL
            jc .Lloop                       // continue while no NUL was found

            // horizontal sum of the 8 qword accumulators into rdx
    .Lreduce:
            vextracti64x4 $1, %zmm3, %ymm2  // high 4 qwords
            vpaddq %ymm2, %ymm3, %ymm3      // add to the low 4
            vextracti128 $1, %ymm3, %xmm2   // high 2 qwords
            vpaddq %xmm2, %xmm3, %xmm3      // add to the low 2
            vpshufd $0x4e, %xmm3, %xmm2     // swap the two qwords
            vpaddq %xmm2, %xmm3, %xmm3      // low qword = total
            vmovq %xmm3, %rdx
            add %rdx, %rax                  // plus the scalar head sum

            // scalar tail: the remainder of the line that contains the NUL
    .Ltail:
            movzbl (%rdi), %edx             // load a byte from the string
            inc %rdi                        // advance pointer
            test %edx, %edx                 // is this the NUL byte?
            jz .Lend                        // if yes, finish
            sub $'0', %edx                  // ASCII character -> digit
            add %rdx, %rax                  // add to the sum
            jmp .Ltail                      // unconditional: the flags from add do not
                                            // say whether the string has ended

            // digital root: 0 if the sum is 0, else 1 + (sum - 1) % 9
    .Lend:
            test %rax, %rax
            jz .Lret                        // sum == 0: eax already holds 0
            dec %rax
            xor %edx, %edx                  // rdx:rax = sum - 1
            mov $9, %ecx
            div %rcx                        // rdx = (sum - 1) % 9
            lea 1(%rdx), %eax               // result in 1..9
    .Lret:
            vzeroupper                      // avoid AVX/SSE transition penalties in the caller
            ret
            .size getnumericvalue, .-getnumericvalue

            // constants
            .section .rodata
            .balign 4
    zero:   .byte '0', '0', '0', '0'