Tags: python, optimization, jit, loop-unrolling

Why does 2x2 loop unrolling run slower in Python (but not when compiled with @jit(nopython=True))?


These functions give identical results (assuming an even-length array, and up to floating-point rounding). The 2x2 unrolled function, however, runs 30% slower when fed 10,000,000 floats. When I change the functions to run in nopython mode, I see an approximately 2.5x speedup with the unrolled function.

Why does normal python fail to gain a speedup from loop unrolling? What's the underlying difference in how these two versions of the code are compiled and run?

# from numba import jit  # needed only when the @jit decorators are uncommented
# @jit(nopython=True)
def func1(array):
    numSum = 0

    for i in range(array.shape[0]):
        numSum += array[i]

    return numSum

# 2x2 unrolled (assumes an even-length array; with an odd length the last element is skipped)
# @jit(nopython=True)
def unrolledFunc1(array):
    numSum0 = 0
    numSum1 = 0 

    for i in range(0, array.shape[0]-1, 2):
        numSum0 += array[i] 
        numSum1 += array[i+1] 

    numSum = numSum0 + numSum1 

    return numSum

Solution

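  • Firstly, you're not doing loop unrolling properly: the leftover elements when the length isn't a multiple of the unroll factor are not handled (search for Duff's device).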
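A complete unrolled loop has to deal with the elements left over when the length is not a multiple of the unroll factor. Duff's device does this in C by using switch fall-through to jump into the middle of the unrolled body; Python has no such construct, so the sketch below uses the common alternative instead: an unrolled main loop followed by a short remainder loop. It is a plain-Python sketch under assumptions: it uses len() instead of array.shape[0] so it runs without NumPy, and the name unrolled_sum_4 is hypothetical.

```python
def unrolled_sum_4(values):
    """Sum a sequence with the loop body unrolled 4 times.

    The main loop handles complete blocks of 4 elements using 4 independent
    accumulators; a short remainder loop picks up the 0-3 leftover elements.
    """
    t0 = t1 = t2 = t3 = 0.0
    n = len(values)
    main_end = n - n % 4  # exclusive end of the part covered by full blocks
    for i in range(0, main_end, 4):
        t0 += values[i]
        t1 += values[i + 1]
        t2 += values[i + 2]
        t3 += values[i + 3]
    # Remainder loop: the leftovers that Duff's device would handle in C
    # by jumping into the middle of the unrolled body.
    for i in range(main_end, n):
        t0 += values[i]
    return t0 + t1 + t2 + t3


# Works for any length, not only multiples of 4.
for length in range(10):
    data = [float(x) for x in range(length)]
    assert unrolled_sum_4(data) == sum(data)
```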

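    Secondly, Python has A LOT of overhead when accessing array elements: every array[i] goes through the interpreter, is checked for out-of-bounds and negative (wraparound) indices, and creates a new Python object for the result, and every += is dispatched dynamically on the operand types. That is where most of the program time is spent, not in the actual addition. Unrolling removes none of that per-element work; it only saves a little loop bookkeeping, and the unrolled version even adds an extra i+1 computation for every second element.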
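Some of that per-access work can be observed from pure Python. The sketch below uses the standard-library array.array of C doubles as a stand-in for a NumPy array (an assumption; NumPy returns its own scalar type rather than float, but it likewise creates a new object on every access). Each subscript builds a fresh Python float from the raw 8-byte value, negative indices are translated, and every index is bounds-checked.

```python
from array import array

# A typed buffer of C doubles, standing in for a NumPy float64 array.
values = array("d", [1.5, 2.5, 3.5])

# Each subscript boxes the raw double into a brand-new Python float object.
first = values[0]
second = values[0]
print(type(first).__name__)              # float
print(first == second, first is second)  # True False

# Negative indices are translated ("wraparound") ...
print(values[-1])  # 3.5

# ... and every access is bounds-checked.
try:
    values[3]
except IndexError as exc:
    print("IndexError:", exc)
```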

    Thirdly, you can use dis.dis or other Python disassembler tools to see the generated bytecode and exactly how the unrolled and normal loops differ (in most cases not by much, since the bulk of the CPU time is "wasted" on performing checks).
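As a starting point, the two functions from the question can be disassembled with the standard-library dis module. Disassembly does not execute them, so no NumPy array is needed. The listing and the instruction counts vary between Python versions, and the count is only a rough guide, because a subscript on a NumPy array costs far more than, say, loading a local variable.

```python
import dis


def func1(array):
    numSum = 0
    for i in range(array.shape[0]):
        numSum += array[i]
    return numSum


def unrolledFunc1(array):
    numSum0 = 0
    numSum1 = 0
    for i in range(0, array.shape[0] - 1, 2):
        numSum0 += array[i]
        numSum1 += array[i + 1]
    numSum = numSum0 + numSum1
    return numSum


# Full bytecode listing of each version (exact output depends on the Python version).
dis.dis(func1)
print("-" * 40)
dis.dis(unrolledFunc1)

# Compact summary: how many bytecode instructions each function compiles to.
for f in (func1, unrolledFunc1):
    count = sum(1 for _ in dis.get_instructions(f))
    print(f"{f.__name__}: {count} instructions")
```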

    Lastly, manual loop unrolling will rarely give you a meaningful improvement on modern hardware! You may be surprised, but a for loop is really just a while loop, which in machine code looks something like:

    loop_begin:
        i = 0
        target = 100
    repeated_section:
        if i < target:
            # DO SOME PROCESSING
            i = i+1
            jump to repeated_section
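In Python terms, the pseudocode above corresponds to rewriting a range-based for loop as an explicit while loop with a counter, a test, the work, an increment and a jump back. This sketch (function names hypothetical) shows the logical shape only: CPython actually runs the for version through a FOR_ITER bytecode on a range iterator, not through literal compare-and-jump machine instructions.

```python
def for_loop_sum(values):
    total = 0.0
    for i in range(len(values)):
        total += values[i]
    return total


def while_loop_sum(values):
    # Same loop spelled out like the pseudocode:
    # initialise the counter, test it, do the work, increment, jump back.
    total = 0.0
    i = 0
    target = len(values)
    while i < target:       # the conditional branch
        total += values[i]  # DO SOME PROCESSING
        i = i + 1
    return total


data = [1.0, 2.0, 3.5]
print(for_loop_sum(data), while_loop_sum(data))  # 6.5 6.5
```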
    

    If you don't believe me, look at Python's for ... else and while ... else: both loops support the same else clause, which runs once the loop's test finally fails. Modern CPUs predict branches and speculatively execute as much work as they can ahead of time. The branch in a for loop is extremely predictable (it goes the same way on every iteration except the very last), so the branch predictor takes care of your loop without you realizing it.

    I'm not sure exactly how you ran your benchmarks, but I used Cython and wrote several loop-unrolling variations (manually unrolling 100 iterations, using Duff's device, and a normal loop), and they all had very similar run times (within each other's margin of error), especially with the right compiler flags (-O3 or -Ofast for GCC on Linux, or /O2i for the Visual Studio C++ compiler on Windows).
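Because the benchmarking method matters, here is a minimal, reproducible way to compare the two pure-Python versions with the standard-library timeit module. It is a sketch under assumptions: it uses an array.array of doubles instead of a NumPy array and len() instead of .shape[0], so it runs without extra packages, and it uses 200,000 elements instead of 10 million so it finishes quickly. Absolute times, and the gap between the versions, depend on your machine and Python version.

```python
import functools
import timeit
from array import array


def func1(values):
    num_sum = 0.0
    for i in range(len(values)):
        num_sum += values[i]
    return num_sum


def unrolled_func1(values):
    num_sum0 = 0.0
    num_sum1 = 0.0
    for i in range(0, len(values) - 1, 2):  # even length assumed, as in the question
        num_sum0 += values[i]
        num_sum1 += values[i + 1]
    return num_sum0 + num_sum1


# 200,000 doubles: even length, and small enough to time quickly.
data = array("d", range(200_000))
assert func1(data) == unrolled_func1(data)  # check correctness before timing

for f in (func1, unrolled_func1):
    # Several repeats of one call each; the minimum is the run least
    # disturbed by other activity on the machine.
    times = timeit.repeat(functools.partial(f, data), number=1, repeat=5)
    print(f"{f.__name__}: best of 5 = {min(times):.4f} s")
```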

    If you're using black magic (@jit), I wouldn't worry about performance differences in the code it generates. Instead, try to write code the JIT compiler can understand, so it can optimize it for you. That said, if you're truly trying to optimize code rather than just trying @jit and seeing what sticks, you need to control the actual source code, compiler flags, etc. BEFORE running benchmarks, so you can tell what your change actually did, AND YOU MUST USE A PROFILER to find out WHAT IS ACTUALLY THE BOTTLENECK in your code, rather than guess and check.
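For pure-Python code, the standard-library cProfile and pstats modules are a minimal place to start. The sketch below profiles a plain loop and a two-accumulator unrolled loop (hypothetical names, operating on a list instead of a NumPy array) and renders the report into a string. cProfile reports time per function, not per line, so for a single hot loop it tells you which function to look at rather than which line; a line-level profiler is the next step after that.

```python
import cProfile
import io
import pstats


def loop_sum(values):
    total = 0.0
    for i in range(len(values)):
        total += values[i]
    return total


def unrolled_sum(values):
    total0 = 0.0
    total1 = 0.0
    n = len(values)
    for i in range(0, n - 1, 2):
        total0 += values[i]
        total1 += values[i + 1]
    if n % 2:  # odd length: add the last element too
        total0 += values[n - 1]
    return total0 + total1


data = [float(x) for x in range(200_000)]

profiler = cProfile.Profile()
profiler.enable()
a = loop_sum(data)
b = unrolled_sum(data)
profiler.disable()

# Render the report into a string, largest cumulative time first.
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats()
report = stream.getvalue()
print(report)
```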

    This is explained similarly here: https://stackoverflow.com/a/2349219/22723166

    Edit: I ran a quick sanity check on your specific example (just calculating the sum). Keep in mind that this only works for arrays whose length is an exact multiple of N (N being the number of unrolled lines), since I didn't bother with a proper Duff's device; it is only meant as a sanity check:

    cpdef  loop_sum(double [::1] arr):
        cdef unsigned int i = 0
        cdef double total = 0
        for i in range(arr.shape[0]):
            total += arr[i]
        return total
    
    cpdef  unroll_10_sum(double [::1] arr):
        cdef unsigned int i = 0, shape = arr.shape[0]
        cdef double total0 = 0, total1 = 0, total2 = 0, total3 = 0, total4 = 0, total5 = 0, total6 = 0, total7 = 0, total8 = 0, total9 = 0
        for i in range(0,shape,10):
            total0 += arr[i]
            total1 += arr[i+1]
            total2 += arr[i+2]
            total3 += arr[i+3]
            total4 += arr[i+4]
            total5 += arr[i+5]
            total6 += arr[i+6]
            total7 += arr[i+7]
            total8 += arr[i+8]
            total9 += arr[i+9]
        return total0+total1+total2+total3+total4+total5+total6+total7+total8+total9
       
    cpdef  unroll_100_sum(double [::1] arr):
        cdef unsigned int i = 0, shape = arr.shape[0]
        cdef double total0 = 0, total1 = 0, total2 = 0, total3 = 0, total4 = 0, total5 = 0, total6 = 0, total7 = 0, total8 = 0, total9 = 0, total10 = 0, total11 = 0, total12 = 0, total13 = 0, total14 = 0, total15 = 0, total16 = 0, total17 = 0, total18 = 0, total19 = 0, total20 = 0, total21 = 0, total22 = 0, total23 = 0, total24 = 0, total25 = 0, total26 = 0, total27 = 0, total28 = 0, total29 = 0, total30 = 0, total31 = 0, total32 = 0, total33 = 0, total34 = 0, total35 = 0, total36 = 0, total37 = 0, total38 = 0, total39 = 0, total40 = 0, total41 = 0, total42 = 0, total43 = 0, total44 = 0, total45 = 0, total46 = 0, total47 = 0, total48 = 0, total49 = 0, total50 = 0, total51 = 0, total52 = 0, total53 = 0, total54 = 0, total55 = 0, total56 = 0, total57 = 0, total58 = 0, total59 = 0, total60 = 0, total61 = 0, total62 = 0, total63 = 0, total64 = 0, total65 = 0, total66 = 0, total67 = 0, total68 = 0, total69 = 0, total70 = 0, total71 = 0, total72 = 0, total73 = 0, total74 = 0, total75 = 0, total76 = 0, total77 = 0, total78 = 0, total79 = 0, total80 = 0, total81 = 0, total82 = 0, total83 = 0, total84 = 0, total85 = 0, total86 = 0, total87 = 0, total88 = 0, total89 = 0, total90 = 0, total91 = 0, total92 = 0, total93 = 0, total94 = 0, total95 = 0, total96 = 0, total97 = 0, total98 = 0, total99 = 0
        for i in range(0,shape,100):
            total0 += arr[i]
            total1 += arr[i+1]
            total2 += arr[i+2]
            total3 += arr[i+3]
            total4 += arr[i+4]
            total5 += arr[i+5]
            total6 += arr[i+6]
            total7 += arr[i+7]
            total8 += arr[i+8]
            total9 += arr[i+9]
            total10 += arr[i+10]
            total11 += arr[i+11]
            total12 += arr[i+12]
            total13 += arr[i+13]
            total14 += arr[i+14]
            total15 += arr[i+15]
            total16 += arr[i+16]
            total17 += arr[i+17]
            total18 += arr[i+18]
            total19 += arr[i+19]
            total20 += arr[i+20]
            total21 += arr[i+21]
            total22 += arr[i+22]
            total23 += arr[i+23]
            total24 += arr[i+24]
            total25 += arr[i+25]
            total26 += arr[i+26]
            total27 += arr[i+27]
            total28 += arr[i+28]
            total29 += arr[i+29]
            total30 += arr[i+30]
            total31 += arr[i+31]
            total32 += arr[i+32]
            total33 += arr[i+33]
            total34 += arr[i+34]
            total35 += arr[i+35]
            total36 += arr[i+36]
            total37 += arr[i+37]
            total38 += arr[i+38]
            total39 += arr[i+39]
            total40 += arr[i+40]
            total41 += arr[i+41]
            total42 += arr[i+42]
            total43 += arr[i+43]
            total44 += arr[i+44]
            total45 += arr[i+45]
            total46 += arr[i+46]
            total47 += arr[i+47]
            total48 += arr[i+48]
            total49 += arr[i+49]
            total50 += arr[i+50]
            total51 += arr[i+51]
            total52 += arr[i+52]
            total53 += arr[i+53]
            total54 += arr[i+54]
            total55 += arr[i+55]
            total56 += arr[i+56]
            total57 += arr[i+57]
            total58 += arr[i+58]
            total59 += arr[i+59]
            total60 += arr[i+60]
            total61 += arr[i+61]
            total62 += arr[i+62]
            total63 += arr[i+63]
            total64 += arr[i+64]
            total65 += arr[i+65]
            total66 += arr[i+66]
            total67 += arr[i+67]
            total68 += arr[i+68]
            total69 += arr[i+69]
            total70 += arr[i+70]
            total71 += arr[i+71]
            total72 += arr[i+72]
            total73 += arr[i+73]
            total74 += arr[i+74]
            total75 += arr[i+75]
            total76 += arr[i+76]
            total77 += arr[i+77]
            total78 += arr[i+78]
            total79 += arr[i+79]
            total80 += arr[i+80]
            total81 += arr[i+81]
            total82 += arr[i+82]
            total83 += arr[i+83]
            total84 += arr[i+84]
            total85 += arr[i+85]
            total86 += arr[i+86]
            total87 += arr[i+87]
            total88 += arr[i+88]
            total89 += arr[i+89]
            total90 += arr[i+90]
            total91 += arr[i+91]
            total92 += arr[i+92]
            total93 += arr[i+93]
            total94 += arr[i+94]
            total95 += arr[i+95]
            total96 += arr[i+96]
            total97 += arr[i+97]
            total98 += arr[i+98]
            total99 += arr[i+99]
        return total0+total1+total2+total3+total4+total5+total6+total7+total8+total9+total10+total11+total12+total13+total14+total15+total16+total17+total18+total19+total20+total21+total22+total23+total24+total25+total26+total27+total28+total29+total30+total31+total32+total33+total34+total35+total36+total37+total38+total39+total40+total41+total42+total43+total44+total45+total46+total47+total48+total49+total50+total51+total52+total53+total54+total55+total56+total57+total58+total59+total60+total61+total62+total63+total64+total65+total66+total67+total68+total69+total70+total71+total72+total73+total74+total75+total76+total77+total78+total79+total80+total81+total82+total83+total84+total85+total86+total87+total88+total89+total90+total91+total92+total93+total94+total95+total96+total97+total98+total99
    

    [Image: results]

    The numbers add up correctly, since 10 million is a multiple of both 10 and 100, so the functionality checks out. Now to performance:

    You may be tempted to jump to conclusions and say: "WOW, unrolling increased speed by 2 to 2.5x! Let's use it everywhere!" The problem is the simplicity of the function. It's just one addition! No switch statements, no multiplications, nothing! By the time you add those, the 2.5x performance gain is gone too. Your efforts will be much better spent trying to parallelize the code to run on multiple threads. Have a look at this post for a very similar 2.5x performance gain with multi-threading: Increasing CPU usage with Python during for loops. The advantage of the multi-threaded approach is that you actually use your resources (CPU and memory bandwidth) rather than relying on decades-old tricks with questionable benefits on modern hardware.
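A parallel sum follows the usual split-then-combine structure of a reduction: cut the data into contiguous chunks, sum each chunk independently, then add up the partial sums. The sketch below (hypothetical name chunked_parallel_sum) shows that structure with the standard-library ThreadPoolExecutor. Be aware that in a standard CPython build with the GIL, threads running pure-Python work such as the built-in sum() on a list take turns, so this demonstrates the decomposition, not a speedup; an actual gain needs per-chunk work that releases the GIL (for example Cython prange with nogil, or Numba's parallel mode) or separate processes.

```python
from concurrent.futures import ThreadPoolExecutor


def chunked_parallel_sum(values, workers=4):
    """Split values into up to `workers` contiguous chunks, sum each chunk
    in a worker thread, then combine the partial sums."""
    n = len(values)
    if n == 0:
        return 0.0
    step = -(-n // workers)  # ceiling division, so no element is dropped
    chunks = [values[start:start + step] for start in range(0, n, step)]
    with ThreadPoolExecutor(max_workers=workers) as pool:
        partial_sums = list(pool.map(sum, chunks))
    return sum(partial_sums)


data = [float(x) for x in range(1001)]
print(chunked_parallel_sum(data, workers=3))  # 500500.0
```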

    Edit 2: Here is evidence that unrolling makes no difference in a very slightly different use case. Imagine that instead of just the sum, you want to calculate the sum of the cubes of the numbers in the array. We can modify the functions like this:

    cpdef  loop_sum(double [::1] arr):
        cdef unsigned int i = 0
        cdef double total = 0
        for i in range(arr.shape[0]):
            total += arr[i]**3
        return total
    
    cpdef  unroll_10_sum(double [::1] arr):
        cdef unsigned int i = 0, shape = arr.shape[0]
        cdef double total0 = 0, total1 = 0, total2 = 0, total3 = 0, total4 = 0, total5 = 0, total6 = 0, total7 = 0, total8 = 0, total9 = 0
        for i in range(0,shape,10):
            total0 += arr[i]**3
            total1 += arr[i+1]**3
            total2 += arr[i+2]**3
            total3 += arr[i+3]**3
            total4 += arr[i+4]**3
            total5 += arr[i+5]**3
            total6 += arr[i+6]**3
            total7 += arr[i+7]**3
            total8 += arr[i+8]**3
            total9 += arr[i+9]**3
        return total0+total1+total2+total3+total4+total5+total6+total7+total8+total9
    

    Unsurprisingly, all methods now have the same execution time, since exponentiation costs far more than a simple addition and now dominates the work done per element.

    [Image: results for the sum-of-cubes versions]

    The moral of the story remains the same: don't try to be too smart on modern hardware. Most of the "tricks" you read about were designed decades ago and make little practical difference today. In my experience, the only arcane knowledge that still holds true is branchless programming.
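Branchless programming means replacing a data-dependent if with arithmetic on the comparison result, so the CPU has no unpredictable branch to guess. The sketch below (hypothetical function names) only shows the shape of the rewrite in Python: inside the interpreter the difference is negligible, because evaluating each comparison still involves the interpreter's own internal branching. The benefit shows up in compiled code (C, Cython, Numba), and even there an optimizing compiler may already emit the branchless form on its own.

```python
def count_above_branchy(values, threshold):
    count = 0
    for x in values:
        if x > threshold:  # data-dependent branch: unpredictable on random data
            count += 1
    return count


def count_above_branchless(values, threshold):
    count = 0
    for x in values:
        # The comparison is a bool (True == 1, False == 0), so it is added
        # directly instead of choosing between two code paths.
        count += x > threshold
    return count


data = [0.5, 3.0, -1.0, 7.0, 2.0]
print(count_above_branchy(data, 1.0), count_above_branchless(data, 1.0))  # 3 3
```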