
How does clang generate non-looping code for sum of squares?


I admit the answer to this may be 'some very specific magic', but I'm kind of shocked by what I've observed here. I was wondering if anyone had insight into how these types of optimizations work. I find compiler design to be quite interesting, and I really can't imagine how this works. I'm sure the answer is somewhere in the clang source code, but I don't even know where I would look.

I'm a TA for a class at college, and I was recently asked to help with a simple homework question. This led me down an interesting path...

The question is simple enough: In x86_64 assembly, write a function which given a (positive) integer n returns 1^2 + 2^2 + 3^2 + ... + n^2.

I decided to play around a bit. After helping them write this in x86_64 assembly, I wanted to see if I could write a nice solution in arm64 assembly, since I have an M1 MacBook. I came up with this relatively simple and straightforward solution:

_sum_squares:
    mov x1, x0  ; Do multiplication from x1
    mov x0, xzr ; Clear x0

    Lloop:
        ; x0 <-- (x1 * x1) + x0
        madd x0, x1, x1, x0

        ; Loop until x1 == 0
        subs x1, x1, #1
        bne Lloop

    ret

(I wish there were a nice way to branch if --x1 == 0 in one instruction, but I couldn't think of one.)

Note: There is a simple formula for this from any basic number theory class, which is [n(n + 1)(2n + 1)] / 6, but I decided this wasn't really in the spirit of the question.
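For comparison, the closed form can be written in C as below. This is a minimal sketch, not part of the original question: sum_squares_closed and sum_squares_loop are illustrative names, the arithmetic is 64-bit unsigned, and the halving is done before the final multiply so the intermediate value stays smaller (the result is exact while n(n + 1)(2n + 1) / 2 fits in 64 bits).

```c
#include <stdint.h>

/* Hypothetical helper: 1^2 + 2^2 + ... + n^2 = n(n + 1)(2n + 1) / 6.
   One of n and n + 1 is even, so halving first is exact. The remaining
   product n(n + 1)(2n + 1) / 2 is then divisible by 3, because the full
   product is divisible by 6. */
uint64_t sum_squares_closed(uint64_t n) {
    uint64_t half = (n % 2 == 0) ? (n / 2) * (n + 1) : n * ((n + 1) / 2);
    return half * (2 * n + 1) / 3;
}

/* Reference loop, summing n^2 + ... + 1^2 like the arm64 version above. */
uint64_t sum_squares_loop(uint64_t n) {
    uint64_t a = 0;
    for (uint64_t i = 1; i <= n; i++)
        a += i * i;
    return a;
}
```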

I then wondered how clang would compile a simple C version. With -Og, clang generates assembly that seems a bit verbose but generally works as expected, with a loop and an accumulator (although it is very inefficient):

int sum_squares(int n)
{
    int a = 0;

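    while (n--)          // n is decremented before the body runs, so this loop
        a += (n * n);    // sums 0^2 + 1^2 + ... + (n-1)^2, not 1^2 + ... + n^2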

    return a;
}

(clang -Og -S, annotated myself, cfi removed, labels renamed)

_sum_squares:
    sub sp, sp, #16         ; create stack space
    str w0, [sp, #12]       ; store n
    str wzr, [sp, #8]       ; store 0
    b   Ldec                ; silly clang, this just falls through...

Ldec:                       ; n-- and return if n == 0
    ldr w8, [sp, #12]       ; load n
    subs    w9, w8, #1      ; w9 = (n - 1)
    str w9, [sp, #12]       ; store (n - 1) over n
    subs    w8, w8, #0      ; w8 = n - 0 (set flags based on n)
    cset    w8, eq          ; set w8 = 1 if n == 0 else w8 = 0
    tbnz    w8, #0, Lret    ; branch to return if n == 0, else fall through
    b   Ladd                ; silly clang, this falls through again...

Ladd:                       ; a += n^2
    ldr w8, [sp, #12]       ; load n
    ldr w9, [sp, #12]       ; load n
    mul w9, w8, w9          ; w9 = n * n
    ldr w8, [sp, #8]        ; load a
    add w8, w8, w9          ; a += w9
    str w8, [sp, #8]        ; store a
    b   Ldec                ; go back to start of loop

Lret:                       ; return a from top of stack
    ldr w0, [sp, #8]        ; w0 = a
    add sp, sp, #16         ; cleanup temp stack
    ret                     ; back to caller

This is altogether reasonable for a direct translation of the C code to arm64 assembly. With more optimization (-O1 uses a similar formula; -O2 and -O3 produce identical code), clang comes up with some magic. I have no clue how it came up with this code. It appears to be somewhat similar to the basic formula for this summation, except with bit magic. I didn't imagine the compiler would be able to derive a formula for this without a loop, but it appears I was wrong. The generated code is as follows (with my best attempt at a commentary; n is the input in w0):

_sum_squares:
        cbz     w0, Lret             ; return if n == 0

        sub     w8, w0, #1           ; w8 = (n - 1)
        mul     w9, w8, w8           ; w9 = (n - 1)^2
        orr     w9, w9, #0x2         ; w9 = ((n - 1)^2) | 2
        sub     w9, w9, w0           ; w9 = [((n - 1)^2) | 2] - n

        mov     w10, #5              ; w10 = 5
        sub     w10, w10, w0, lsl #1 ; w10 = 5 - (n << 1) = 5 - 2n

        sub     w11, w0, #2          ; w11 = n - 2
        umull   x11, w8, w11         ; x11 = (n - 1)(n - 2), full 64-bit product

        lsr     x12, x11, #1         ; x12 = ((n - 1)(n - 2)) / 2
        mul     w10, w10, w12        ; w10 = (5 - 2n)(((n - 1)(n - 2)) / 2)

        sub     w12, w0, #3          ; w12 = n - 3
        mul     x11, x11, x12        ; x11 = (n - 1)(n - 2)(n - 3)
        lsr     x11, x11, #1         ; x11 = ((n - 1)(n - 2)(n - 3)) / 2

        mov     w12, #21846          ; w12 = 0x5556
        movk    w12, #21845, lsl #16 ; w12 = 0x55555556

        ; w8 = ((n - 1)([((n - 1)^2) | 2] - n)) + (5 - 2n)(((n - 1)(n - 2)) / 2)
        madd    w8, w9, w8, w10

        ; let A = w8 (set in last instruction)
        ; w0 = (0x55555556 * (((n - 1)(n - 2)(n - 3)) / 2)) + A
        madd    w0, w11, w12, w8
        ; somehow, this is the correct result?
        ; this feels like magic to me...

Lret:
        ret                          ; return. Result already in w0.

My question: How in the world does this work? How can a C compiler be given a loop like this and deduce a formula not even involving a loop? I expected some loop unwinding perhaps, but nothing like this. Does anyone have references involving this type of optimization?

I especially don't understand what certain steps, like orr w9, w9, #0x2 or the magic number 0x55555556, do. Any insight into these steps would be extra appreciated.


Solution

  • TL;DR: Yes, clang knows the closed-form formulas for sums of integer power series, and can detect such loops. Smart humans have taught modern compilers to recognize certain patterns of operations and replace them with operations not present in the source, e.g. for rotates and even popcount loops and bithacks. And for clang/LLVM specifically, also closed-form formulae for sums of i^power, including with a stride other than 1. Yay math! So you can get asm logic that's not just an unrolled or vectorized version of the source.

    See also the blog article "How LLVM optimizes power sums", which talks about how compilers find these loops by looking at how variables are updated across loop iterations.

    Matthieu M. comments that "Closed form formulas are derived by the Scalar Evolution optimization in LLVM. The comments in the code say that it's used primarily to analyze expressions involving induction variables in loops", and cites references for the techniques it uses for chains of recurrences.
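    To make the chains-of-recurrences idea concrete, here is a small C sketch. It is not LLVM's code; binom, eval_chrec and SUM_SQUARES_CHREC are hypothetical names. A chrec {c0,+,c1,+,...,+,ck} describes a value that starts at c0 and is bumped each iteration by the chrec {c1,+,...,+,ck}. Its value after k iterations is the sum of cj * C(k, j), so an accumulator that adds i*i each iteration can be summarized as {0,+,0,+,1,+,2} and evaluated without looping over i. The binomial coefficients C(k, 2) = k(k-1)/2 and C(k, 3) = k(k-1)(k-2)/6 are products of the same shape as the (n-1)(n-2)/2 and (n-1)(n-2)(n-3) terms in the asm from the question.

```c
#include <stdint.h>

/* C(k, j) for small j, built incrementally:
   C(k, t+1) = C(k, t) * (k - t) / (t + 1), and that division is always exact.
   Assumes the intermediate values fit in 64 bits. */
static uint64_t binom(uint64_t k, unsigned j) {
    uint64_t r = 1;
    for (unsigned t = 0; t < j; t++) {
        if (k <= t)          /* C(k, j) = 0 when j > k */
            return 0;
        r = r * (k - t) / (t + 1);
    }
    return r;
}

/* Value of the chain of recurrences {c[0],+,c[1],+,...,+,c[len-1]}
   after k iterations: the sum over j of c[j] * C(k, j). */
static int64_t eval_chrec(const int64_t *c, unsigned len, uint64_t k) {
    int64_t v = 0;
    for (unsigned j = 0; j < len; j++)
        v += c[j] * (int64_t)binom(k, j);
    return v;
}

/* The loop being summarized: a accumulates i*i for i = 0 .. k-1. */
static int64_t sum_squares_below(int64_t k) {
    int64_t a = 0;
    for (int64_t i = 0; i < k; i++)
        a += i * i;
    return a;
}

/* i*i is {0,+,1,+,2}: values 0, 1, 4, 9, ...; differences 1, 3, 5, ...;
   second differences 2. a starts at 0 and adds i*i each iteration,
   so a is {0,+,0,+,1,+,2}. */
static const int64_t SUM_SQUARES_CHREC[] = {0, 0, 1, 2};
```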


    Modern C compilers can recognize patterns in some loops or short sequences of logic, in the internal representation of the code. Humans (compiler devs) have told the compiler what to look for, and provided a hand-crafted replacement "formula". I expect this happens in GIMPLE (GCC) or LLVM-IR, not just really late in compilation like a peephole optimization while generating asm.

    So I'd guess the logic inside LLVM's optimizer checks every loop it finds for one or more of the following possibilities, with some code to look for some property of the LLVM-IR that represents the program logic of that loop:

    • Does it copy one array to another unmodified? If so replace with __builtin_memcpy, which might later get expanded inline or compiled to call memcpy. And if it has other side effects like leaving a pointer variable incremented, also represent that in the new LLVM-IR for the function containing the loop.
    • Does it set every byte of a range of memory to a constant byte value? If so, memset
    • Is its sequence of operations equivalent to this sequence which does popcnt? Then emit a popcnt if hardware support exists, otherwise keep the loop strategy. (So it's not treating the loop as if it were __builtin_popcount in general: it doesn't replace a loop with a bithack or a call to a helper function. That makes sense because some strategies are good for numbers with few bits set, and the programmer might have chosen with that in mind.)
    • Is the loop variable updated with the sum of a range of integers (with some stride), or that raised to a power? Then use a closed-form formula that avoids overflow of a fixed-width integer. (And if the start point and stride aren't 1, add an offset and/or scale factor.)
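    The offset-and-scale adjustment in the last bullet can be sketched in C for the plain (power 1) case. This is an illustration under assumptions, not LLVM's formula: strided_sum_closed and strided_sum_loop are hypothetical names, and both use unsigned 64-bit arithmetic so that any overflow wraps the same way. The sum of start, start + stride, ... over count terms is count*start + stride * count*(count-1)/2. Halving before multiplying keeps that second term exact modulo 2^64 even when count*(count-1) itself would not fit.

```c
#include <stdint.h>

/* Hypothetical helper: sum of `count` terms start, start+stride, ...
   = count*start + stride * (count*(count-1)/2).
   One of count and count-1 is even, so halving that one first is exact. */
uint64_t strided_sum_closed(uint64_t start, uint64_t stride, uint64_t count) {
    uint64_t pairs = (count % 2 == 0) ? (count / 2) * (count - 1)
                                      : count * ((count - 1) / 2);
    return count * start + stride * pairs;
}

/* The loop it replaces; wraps modulo 2^64 exactly like the closed form. */
uint64_t strided_sum_loop(uint64_t start, uint64_t stride, uint64_t count) {
    uint64_t a = 0, x = start;
    for (uint64_t k = 0; k < count; k++, x += stride)
        a += x;
    return a;
}
```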

    The checking might work in terms of considering a variable modified by a loop, which is read after the loop. So it knows what variable to consider when looking at the operations. (Loops with no used results get removed.)

    GCC doesn't look for sums of integer sequences, but clang does. IDK how many real-world codebases this actually speeds up; the closed-form formula is fairly well-known, having famously been re-discovered by Gauss as a schoolboy. (So hopefully a lot of code uses the formula instead of a loop). And not many programs would need to do exactly this, I'd have thought, except as an exercise.

    (The existence of a closed-form sum-of-squares formula is less well-known, but there is one, and apparently also for powers in general.)
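    As one concrete instance of a higher-power formula (standard math, not something taken from LLVM's source), the sum of cubes is the square of the sum of integers: 1^3 + ... + n^3 = (n(n+1)/2)^2. A minimal C sketch with illustrative names, using wrapping 64-bit unsigned arithmetic:

```c
#include <stdint.h>

/* 1^3 + 2^3 + ... + n^3 = (n(n+1)/2)^2; exact while the result fits in 64 bits. */
uint64_t sum_cubes_closed(uint64_t n) {
    uint64_t t = (n % 2 == 0) ? (n / 2) * (n + 1) : n * ((n + 1) / 2);
    return t * t;
}

/* Reference loop for comparison. */
uint64_t sum_cubes_loop(uint64_t n) {
    uint64_t a = 0;
    for (uint64_t i = 1; i <= n; i++)
        a += i * i * i;
    return a;
}
```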


    Clang's implementation of the formula of course has to give the exact correct result for every input integer where the C abstract machine doesn't encounter undefined behaviour (for signed integer overflow), or match the truncation of unsigned multiplies. Otherwise it wouldn't satisfy the as-if rule, or could only be used when inlining into places with a known limited value-range. (In practice, it seemed clang wasn't using the closed-form optimization for unsigned, but maybe I just had a mistake in the version I was trying. Using a 64-bit integer could safely calculate sums of 32-bit integers. And then truncating that could give the same result as the source.)

    n*(n+1) can overflow in cases where n*(n+1)/2 is still in range, so this is non-trivial. For 32-bit int on a 64-bit machine, LLVM can and does simply use a 64-bit multiply and right shift. This may be a peephole optimization of the general case of using a double-width output and an extended-precision right shift, across two registers if the product didn't fit in one. (e.g. x86 shrd eax, edx, 1 to shift the low bit of the high half (EDX) into the top of EAX, after a mul r32 produced the 64-bit product in EDX:EAX.)
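    A minimal C sketch of that widening approach, assuming 0 <= n and a true sum that fits in int32_t (sum_to_n_widened is a made-up name). For example, with n = 65535, n*(n+1) = 4294901760 does not fit in a 32-bit int, but the sum 2147450880 does.

```c
#include <stdint.h>

/* Hypothetical helper: 0 + 1 + ... + n for 32-bit n. The product is formed
   in 64 bits so n*(n+1) cannot overflow, then halved with a right shift.
   Assumes n >= 0 and that the result fits in int32_t. */
int32_t sum_to_n_widened(int32_t n) {
    uint64_t prod = (uint64_t)(uint32_t)n * ((uint64_t)(uint32_t)n + 1);
    return (int32_t)(prod >> 1);
}
```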

    It also does n * (n-1) / 2 + n instead of the usual n * (n+1)/2; not sure how that helps. It avoids overflow of the input type, I guess, in case that matters for unsigned types where the original loop would just wrap, not have UB. Except it doesn't do this optimization for unsigned. (BTW, of n and n±1, one is always even, so the division (right shift) is exact, which is good because the sum of integers had better be an integer.)

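    In your sum-of-squares asm, you can see it using umull x, w, w to do a widening multiply and a 64-bit right shift, before the 32-bit multiplicative inverse for division by 3. (0x55555556 is 2 times the inverse of 3 modulo 2^32, since 3 * 0xAAAAAAAB wraps to 1, so multiplying an exact multiple of 3 by it divides by 3 and doubles the result in one step.)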
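    To check the question's commentary, here is a line-by-line C transcription of clang's -O2 output (a sketch: sum_squares_clang and sum_squares_ref are made-up names, w registers are modeled as uint32_t and x registers as uint64_t so that all arithmetic wraps like the hardware does). The orr is explained by the fact that a square is always 0 or 1 mod 4, so bit 1 of (n-1)^2 is always clear and OR-ing in 2 is the same as adding 2. With m = n - 1, the three terms are then m^3 - m^2 + m, (3 - 2m) * m(m-1)/2 and m(m-1)(m-2)/3, which add up to m(m+1)(2m+1)/6: the sum 0^2 + ... + (n-1)^2 that the question's C loop actually computes.

```c
#include <stdint.h>

/* Hypothetical transcription of clang's -O2 arm64 output for sum_squares(). */
int32_t sum_squares_clang(int32_t n_signed) {
    uint32_t n = (uint32_t)n_signed;                   /* w0 */
    if (n == 0)                                        /* cbz   w0, Lret */
        return 0;
    uint32_t w8 = n - 1;                               /* sub   w8, w0, #1 */
    uint32_t w9 = w8 * w8;                             /* mul   w9, w8, w8 */
    w9 |= 2;                                           /* orr: same as += 2 (squares are 0 or 1 mod 4) */
    w9 -= n;                                           /* sub   w9, w9, w0 */
    uint32_t w10 = 5u - (n << 1);                      /* sub   w10, w10, w0, lsl #1 */
    uint64_t x11 = (uint64_t)w8 * (uint32_t)(n - 2);   /* umull x11, w8, w11 */
    uint32_t w12 = (uint32_t)(x11 >> 1);               /* lsr   x12, x11, #1 (low half used) */
    w10 = w10 * w12;                                   /* mul   w10, w10, w12 */
    x11 = x11 * (uint64_t)(n - 3);                     /* mul   x11, x11, x12 (x12 = zero-extended n-3) */
    x11 >>= 1;                                         /* lsr   x11, x11, #1 */
    w8 = w9 * w8 + w10;                                /* madd  w8, w9, w8, w10 */
    /* madd w0, w11, w12, w8: 3 * 0x55555556 wraps to 2, so for x11 a
       multiple of 3 this adds 2 * x11 / 3 */
    uint32_t result = (uint32_t)x11 * 0x55555556u + w8;
    return (int32_t)result;
}

/* The question's loop with a 64-bit accumulator: sums 0^2 .. (n-1)^2. */
int64_t sum_squares_ref(int32_t n) {
    int64_t a = 0;
    while (n--)
        a += (int64_t)n * n;
    return a;
}
```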


    Playing around with your code and a simplified version without the squaring, I found that counting down vs. counting up makes a small difference in code-gen.

    int sum_ints(int n) {
        int a = 0;
        //for (int i=0 ; i<n ; i++)  a += i;        // count up, skipping n
        while (n--) a += n;                      // count down, skipping n
        return a;
    }
    

    Negative n would cause UB with your version: the loop would keep decrementing n toward INT_MIN (and INT_MIN-- would itself overflow), but a would overflow first. So clang might or might not be using that to assume that the initial n is non-negative. But if not, IDK why it makes more complicated code that multiplies twice.

    // count down version, starting with a += n-1, so x = n-1 in the usual formulae.
    //  clang15 -O3
    sum_ints(int):
            cbz     w0, .LBB0_2        // only bail on zero, not negative.
            sub     w8, w0, #1         // n-1
            sub     w9, w0, #2         // n-2
            umull   x10, w8, w9        // (n-1)*(n-2)
            madd    w8, w8, w9, w0     // w8 = (n-1)*(n-2) + n
            lsr     x9, x10, #1        // w9 = (n-1)*(n-2)/2
            mvn     w9, w9             // w9 = ~w9 = -w9 - 1
            add     w0, w8, w9         // (n-1)*(n-2) + n - (n-1)*(n-2)/2 - 1 = (n-1)*(n-2)/2 + n-1 = n*(n-1)/2
    .LBB0_2:
            ret
    
    // count up version, ending with n-1.  clang15 -O3
    sum_ints(int):
            subs    w8, w0, #1       // n-1
            b.lt    .LBB0_2
            sub     w9, w0, #2       // n-2
            umull   x9, w8, w9       // (n-1)*(n-2)
            lsr     x9, x9, #1       // . / 2
            add     w0, w8, w9       // (n-1)*(n-2)/2 + (n-1) = (n-1)*(n-2 + 2)/2
                                     // = the usual               x  * (x+1   )/2 for x=n-1
            ret
    .LBB0_2:
            mov     w0, wzr         // separate return path for all negative inputs
            ret
    
    

    Other types of loop pattern-recognition / replacement

    GCC and clang do pattern-recognition for loops that count set bits, as well as the standard bithack that people will have copy/pasted from SO. (This is useful because ISO C didn't provide a portable way to express this operation, which most modern CPUs have, until C23's <stdbit.h>. ISO C++ only fixed that deficiency in C++20 with std::popcount in <bit>, or via std::bitset<32>::count().) So some real codebases just have a bithack or simple loop over set bits instead of __builtin_popcount, because people prefer simplicity and want to leave performance up to the compiler.

    These pattern-recognizers only work on some specific ways to implement popcount, such as the x &= x-1; count++ loop; it would presumably cost too much compile time to try to prove equivalence for every possible loop. From that, we can be pretty sure that these work by looking for a specific implementation, not at what the result actually is for every possible integer.

    The variable names of course don't matter, but the sequence of operations on the input variable does. I assume there's some flexibility in reordering operations in ways that give the same result when checking for equivalence. In GCC's case, apparently number_of_iterations_popcount is the name of the function that discovers this: compilers often want to know how many iterations a loop will run for: if it's a small constant they may fully unroll it. If it can be calculated from other variables before starting the loop, it's a candidate for auto-vectorization. (GCC/clang can't auto-vectorize search loops, or anything else with a data-dependent if()break.)

    As shown in the top answer on Count the number of set bits in a 32-bit integer, GCC10 and clang10 (Godbolt) can also recognize a popcount using a SWAR bithack, so you get the best of both worlds: ideally a single instruction, but if not then at least a good strategy.
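    For reference, this is one common spelling of that SWAR ("SIMD within a register") bithack, as a sketch (popcount_swar is a made-up name). Whether a given GCC/clang version recognizes exactly this spelling and emits popcnt is worth checking on Godbolt rather than assuming.

```c
#include <stdint.h>

/* Classic SWAR popcount: sum bits in 2-bit, then 4-bit, then 8-bit fields. */
int popcount_swar(uint32_t x) {
    x = x - ((x >> 1) & 0x55555555u);                  /* 2-bit field sums */
    x = (x & 0x33333333u) + ((x >> 2) & 0x33333333u);  /* 4-bit field sums */
    x = (x + (x >> 4)) & 0x0F0F0F0Fu;                  /* 8-bit field sums */
    return (int)((x * 0x01010101u) >> 24);             /* add all 4 bytes into the top byte */
}
```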

    Counting iterations of x &= x-1 until x == 0 is ok when the expected number of set bits is small, so is also a sensible choice sometimes, as the other thing that GCC / clang can replace if hardware popcount is available. (And is simple to write, without needing the masking constants, and can compile to small machine-code size with -Os if not being replaced with a single instruction.)

    int popcount_blsr_until_zero(unsigned x){
        int count = 0;
        while (x){
            count++;  // loop trip-count = popcount, this is what GCC looks for
            x &= x - 1;
        }
        return count;
    }
    

    GCC and clang for x86-64, -O3 -march=nehalem or later, on Godbolt for this and some other versions.

    # gcc12 -O3 -march=znver2
    popcount_blsr_until_zero(unsigned int):
            popcnt  eax, edi
            ret
    
    // clang -O3        for AArch64
    popcount_blsr_until_zero(unsigned int):
            mov     w8, w0            // this is pointless, GCC doesn't do it.
            fmov    d0, x8
            cnt     v0.8b, v0.8b      // ARM unfortunately only has vector popcnt
            uaddlv  h0, v0.8b         // hsum bytes
            fmov    w0, s0            // copy back to GP-integer
            ret
    

    One of the simplest forms of code replacement by pattern-recognition is compiling (n<<5) | (n>>(32-5)) into a rotate-left by 5. (See this Q&A for run-time variable counts, and how to safely write something that gets recognized but also avoids UB even for a count of 0.)
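    A short C sketch of both forms (rotl5 and rotl32 are made-up names). The variable-count version uses one commonly recommended UB-free spelling: masking keeps both shift counts in 0..31, and for a count of 0 the right shift is by 0, so the OR just reproduces x. Whether a particular compiler turns it into a single rotate instruction is something to confirm on Godbolt.

```c
#include <stdint.h>

/* Constant-count rotate left by 5: the pattern described above. */
uint32_t rotl5(uint32_t n) {
    return (n << 5) | (n >> (32 - 5));
}

/* Variable-count rotate left, well-defined for every c (including 0 and >= 32). */
uint32_t rotl32(uint32_t x, unsigned c) {
    c &= 31;
    return (x << c) | (x >> (-c & 31));
}
```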

    But that might happen late enough in the compilation process that you'd call it a peephole optimization. CISC ISAs tend to have more peephole optimizations, like x86 having special-case shorter instructions to sign-extend within the accumulator (cwde instead of movsx eax, ax). x86 xor-zeroing to set a register to zero can still be called a peephole, despite sometimes needing to rearrange things because it clobbers FLAGS while mov eax, 0 doesn't.

    GCC enables xor-zeroing with -fpeephole2 (part of -O2). Perhaps treating it as just a peephole is why GCC sometimes does a bad job and fails to reorder things into xor-zero / cmp / setcc instead of cmp / setcc / movzx. That matters because x86 setcc, which sets a register according to a FLAGS condition, sucks: it only writes the low 8 bits. AArch64 has much better instructions, like csinc, which can be used with the zero-register to materialize a 0/1, or with other registers to conditionally select and increment.

    But sum-of-series loops are a larger-scale replacement, not exactly what I'd think of as a peephole, especially since it's not target-specific.

    Also related:

    • Would a C compiler be allowed to replace an algorithm with another? - yes. But usually they don't, because compilers are mechanical enough that they wouldn't always be right, and picking an efficient algorithm for the data is something that C programmers would expect a compiler to respect, unless there's a single instruction that's obviously always faster.

    • clang knows how to auto-vectorize __builtin_popcount over an array with AVX2 vpshufb as a lookup table for nibbles. It's not just making a SIMD version out of the same operation; again it's using an expansion that human compiler devs put there for it to use.

    • Why does compiler optimization not generate a loop for sum of integers from 1..N? is about cases where this optimization doesn't happen, e.g. one with j <= n for unsigned, which is a potentially infinite loop.

      Comments there found some interesting limitations on when clang can optimize: e.g. if the loop was for (int j=0 ; j < n; j += 3), the trip count is less predictable / harder to calculate, which defeats this optimisation.
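    The nibble-lookup idea from the vpshufb bullet can be sketched in scalar C. This is not clang's generated code, only the arithmetic it relies on; popcount_bytes_lut and NIBBLE_POPCNT are made-up names. With AVX2, one vpshufb performs 32 of these 16-entry table lookups at once, one per byte.

```c
#include <stddef.h>
#include <stdint.h>

/* Popcount of each 4-bit value 0..15: the 16-entry table vpshufb can index. */
static const uint8_t NIBBLE_POPCNT[16] = {0, 1, 1, 2, 1, 2, 2, 3,
                                          1, 2, 2, 3, 2, 3, 3, 4};

/* Total set bits in a byte array: one lookup for each nibble of each byte. */
size_t popcount_bytes_lut(const uint8_t *p, size_t len) {
    size_t total = 0;
    for (size_t i = 0; i < len; i++)
        total += NIBBLE_POPCNT[p[i] & 0x0F] + NIBBLE_POPCNT[p[i] >> 4];
    return total;
}
```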