python performance built-in branchless

Why are the branchless and built-in functions slower in Python?


I found two branchless functions that return the maximum of two numbers in Python, and compared them with an if-statement version and the built-in max function. I expected the branchless or built-in versions to be the fastest, but the if-statement function was the fastest by a large margin. Does anybody know why? Here are the functions:

If-statement (2.16 seconds for 25000 operations):

def max1(a, b):
    if a > b:
        return a
    return b

Built-in (4.69 seconds for 25000 operations):

def max2(a, b):
    return max(a, b)

Branchless 1 (4.12 seconds for 25000 operations):

def max3(a, b):
    return (a > b) * a + (a <= b) * b

Branchless 2 (5.34 seconds for 25000 operations):

def max4(a, b):
    diff = a - b
    return a - (diff & diff >> 31)

Solution

  • Your expectations about branching versus branchless code apply to low-level languages like assembly and C. Branchless code can be faster there because it avoids the stalls caused by branch mispredictions. (Note: this means branchless code can be faster, not that it necessarily will be.)

    Python is a high-level language. Assuming you are using the CPython interpreter: for every bytecode instruction you execute, the interpreter has to branch on the kind of opcode, and typically many other things. For example, even the simple < operator requires a branch to check for the < opcode, another branch to check whether the object's class implements a __lt__ method, more branches to check whether the right-hand-side value is of a valid type for the comparison to be performed, and probably several other branches. Even your so-called "branchless" code will in practice result in a lot of branching for these reasons.
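As a rough illustration of the dispatch described above, the sketch below logs which comparison hooks the interpreter ends up calling for a single < expression. The Tracked class and its calls list are hypothetical names made up for this example; the fallback from int.__lt__ to the reflected __gt__ is standard Python behaviour.

```python
class Tracked:
    """Hypothetical wrapper that records which comparison hooks get called."""

    calls = []  # shared log, for demonstration purposes only

    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        Tracked.calls.append("__lt__")
        if isinstance(other, Tracked):
            return self.value < other.value
        return NotImplemented  # tell the interpreter to try something else

    def __gt__(self, other):
        Tracked.calls.append("__gt__")
        if isinstance(other, Tracked):
            return self.value > other.value
        if isinstance(other, int):
            return self.value > other
        return NotImplemented


# Both operands are Tracked: the interpreter finds and calls Tracked.__lt__.
first = Tracked(1) < Tracked(2)

# Left operand is an int: int.__lt__ returns NotImplemented for a Tracked
# right-hand side, so the interpreter falls back to the reflected
# Tracked.__gt__ on the right operand.
second = 3 < Tracked(5)

print(first, second, Tracked.calls)
```

Every one of those lookups and fallbacks is a decision the interpreter makes at run time, even though the source line contains no if statement.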

    Because Python is so high-level, each bytecode instruction is actually doing quite a lot of work compared to a single machine-code instruction. So the performance of simple code like this will mainly depend on how many bytecode instructions have to be interpreted:

    • Your max1 function executes three loads of local variables, a comparison, a conditional jump and a return on any given call. That's six instructions.
    • Your max2 function does two loads of local variables, one load of a global variable (referencing the built-in max), and also makes a function call; that requires passing arguments, and is relatively expensive compared to other bytecode instructions. On top of that, the built-in function itself has to do the same work as your own max1 function, so no wonder max2 is slower.
    • Your max3 function does six loads of local variables, two comparisons, two multiplications, one addition, and one return. That's twelve instructions, so we should expect it to take about twice as long as max1.
    • Likewise max4 does five loads of local variables, one store to a local variable, one load of a constant, two subtractions, one bitshift, one bitwise "and", and one return. That's twelve instructions again.
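The counts above can be checked with the standard dis module. The following is a minimal sketch; opnames is a helper name chosen for this example. Note that it lists every instruction in the compiled function (so both return paths of max1 appear), and that the exact opcode names and counts vary between CPython versions, so the numbers may not match the ones quoted above exactly.

```python
import dis


def max1(a, b):
    if a > b:
        return a
    return b


def max3(a, b):
    return (a > b) * a + (a <= b) * b


def opnames(func):
    """Return the opcode names in func's compiled bytecode, in order."""
    return [ins.opname for ins in dis.get_instructions(func)]


for func in (max1, max3):
    names = opnames(func)
    # Print the function name, the static instruction count, and the opcodes.
    print(func.__name__, len(names), names)
```

Calling dis.dis(max1) instead prints the same information as a formatted listing, which can be easier to read when comparing functions side by side.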

    That said, if we compare your max1 with the built-in max directly, instead of with your max2 (which adds an extra function call), max1 is still a bit faster. This is probably because the built-in max accepts a variable number of arguments, which may involve building a tuple of arguments. It also has a branch to check whether it was called with a single iterable argument (e.g. max([3, 1, 4, 2])) and handles that case differently. Your max1 function does neither of those things.
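One way to try the direct comparison yourself is with the standard timeit module. This is only a sketch: time_call is a helper name invented here, the arguments 3 and 7 and the repetition count are arbitrary, and the lambda wrapper adds the same call overhead to both measurements. Absolute timings depend on the machine and the Python version.

```python
import timeit


def max1(a, b):
    if a > b:
        return a
    return b


def time_call(func, a, b, number=100_000):
    """Time `number` calls of func(a, b) and return the elapsed seconds."""
    return timeit.timeit(lambda: func(a, b), number=number)


t_if = time_call(max1, 3, 7)
t_builtin = time_call(max, 3, 7)  # the built-in, called without a wrapper function
print(f"max1: {t_if:.4f} s, built-in max: {t_builtin:.4f} s")
```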