
Why does adding a break statement significantly slow down the Numba function?


I have the following Numba function:

import numba

@numba.njit
def count_in_range(arr, min_value, max_value):
    count = 0
    for a in arr:
        if min_value < a < max_value:
            count += 1
    return count

It counts how many values in the array fall within the range.

However, I realized that I only needed to determine whether any such value exists. So I modified it as follows:

@numba.njit
def count_in_range2(arr, min_value, max_value):
    count = 0
    for a in arr:
        if min_value < a < max_value:
            count += 1
            break  # <---- break here
    return count

Surprisingly, the modified function is slower than the original. Under certain conditions, it can be more than 10 times slower.

Benchmark code:

from timeit import timeit

import numpy as np

rng = np.random.default_rng(0)
arr = rng.random(10 * 1000 * 1000)

# To compare under equal conditions, choose an empty range (max_value < min_value),
# so that count_in_range2 never terminates early.
min_value = 0.5
max_value = min_value - 1e-10
assert not np.any(np.logical_and(min_value <= arr, arr <= max_value))

n = 100
for f in (count_in_range, count_in_range2):
    f(arr, min_value, max_value)  # Warm-up call so that JIT compilation is not timed
    elapsed = timeit(lambda: f(arr, min_value, max_value), number=n) / n
    print(f"{f.__name__}: {elapsed * 1000:.3f} ms")

Result:

count_in_range: 3.351 ms
count_in_range2: 42.312 ms

Experimenting further, I found that the speed varies greatly depending on the search range (i.e., min_value and max_value).

Results at various search ranges:

count_in_range2: 5.802 ms, range: (0.0, -1e-10)
count_in_range2: 15.408 ms, range: (0.1, 0.09999999990000001)
count_in_range2: 29.571 ms, range: (0.25, 0.2499999999)
count_in_range2: 42.514 ms, range: (0.5, 0.4999999999)
count_in_range2: 24.427 ms, range: (0.75, 0.7499999999)
count_in_range2: 12.547 ms, range: (0.9, 0.8999999999)
count_in_range2: 5.747 ms, range: (1.0, 0.9999999999)

Can someone explain to me what is going on?


I am using Numba 0.58.1 under Python 3.10.11. Confirmed on both Windows 10 and Ubuntu 22.04.


EDIT:

As an appendix to Jérôme Richard's answer:

As he pointed out in the comments, the dependence of the performance on the search range is likely due to branch prediction.
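A toy model, which is an assumption and not a description of how real CPUs predict branches, shows why the timings could peak at 0.5. Suppose the compiled loop branches on the first comparison min_value < a, and a predictor simply repeats the last outcome. It then mispredicts exactly when the outcome flips between consecutive elements. For independent uniform values, the outcome is true with probability p = 1 - min_value, so the expected flip rate is 2p(1 - p), which is 0 at both ends and largest at 0.5. The sketch below (flip_rate is a hypothetical helper) measures this on random data:

```python
import numpy as np


def flip_rate(arr, threshold):
    """Fraction of consecutive elements where the outcome of threshold < a changes.

    Toy model: a 'repeat the last outcome' predictor mispredicts exactly on these flips.
    """
    outcomes = threshold < arr             # boolean outcome of the first comparison
    flips = outcomes[1:] != outcomes[:-1]  # True where the outcome changed
    return flips.mean()


rng = np.random.default_rng(0)
arr = rng.random(1_000_000)
for threshold in (0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0):
    p = 1.0 - threshold                    # probability that threshold < a for uniform a
    print(f"threshold={threshold:.2f}  measured={flip_rate(arr, threshold):.3f}  "
          f"expected 2p(1-p)={2 * p * (1 - p):.3f}")
```

On sorted or partitioned data the outcome flips only once, so the same model predicts almost no mispredictions regardless of the threshold.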

For example, when min_value is 0.1, min_value < a has a 90% chance of being true, and a < max_value has a 90% chance of being false. So, mathematically, the outcome can be predicted correctly with 81% accuracy. I have no idea how the CPU does this, but I have come up with a way to check whether this reasoning is correct.

The check has two parts: first, partition the array into values below and above the threshold; second, mix in a certain fraction of misplaced elements (errors). When the array is partitioned, the number of branch mispredictions should not depend on the threshold. When errors are included, the number of mispredictions should vary with the error rate.

Here is the updated benchmark code:

from timeit import timeit
import numba
import numpy as np


@numba.njit
def count_in_range(arr, min_value, max_value):
    count = 0
    for a in arr:
        if min_value < a < max_value:
            count += 1
    return count


@numba.njit
def count_in_range2(arr, min_value, max_value):
    count = 0
    for a in arr:
        if min_value < a < max_value:
            count += 1
            break  # <---- break here
    return count


def partition(arr, threshold):
    """Place the elements smaller than the threshold in the front and the elements larger than the threshold in the back."""
    less = arr[arr < threshold]
    more = arr[~(arr < threshold)]
    return np.concatenate((less, more))


def partition_with_error(arr, threshold, error_rate):
    """Same as partition, but includes errors with a certain probability."""
    less = arr[arr < threshold]
    more = arr[~(arr < threshold)]
    less_error, less_correct = np.split(less, [int(len(less) * error_rate)])
    more_error, more_correct = np.split(more, [int(len(more) * error_rate)])
    mostly_less = np.concatenate((less_correct, more_error))
    mostly_more = np.concatenate((more_correct, less_error))
    rng = np.random.default_rng(0)
    rng.shuffle(mostly_less)
    rng.shuffle(mostly_more)
    out = np.concatenate((mostly_less, mostly_more))
    assert np.array_equal(np.sort(out), np.sort(arr))
    return out


def bench(f, arr, min_value, max_value, n=10, info=""):
    f(arr, min_value, max_value)
    elapsed = timeit(lambda: f(arr, min_value, max_value), number=n) / n
    print(f"{f.__name__}: {elapsed * 1000:.3f} ms, min_value: {min_value:.1f}, {info}")


def main():
    rng = np.random.default_rng(0)
    arr = rng.random(10 * 1000 * 1000)
    thresholds = np.linspace(0, 1, 11)

    print("#", "-" * 10, "As for comparison", "-" * 10)
    bench(
        count_in_range,
        arr,
        min_value=0.5,
        max_value=0.5 - 1e-10,
    )

    print("\n#", "-" * 10, "Random Data", "-" * 10)
    for min_value in thresholds:
        bench(
            count_in_range2,
            arr,
            min_value=min_value,
            max_value=min_value - 1e-10,
        )

    print("\n#", "-" * 10, "Partitioned (Yet Still Random) Data", "-" * 10)
    for min_value in thresholds:
        bench(
            count_in_range2,
            partition(arr, threshold=min_value),
            min_value=min_value,
            max_value=min_value - 1e-10,
        )

    print("\n#", "-" * 10, "Partitioned Data with Probabilistic Errors", "-" * 10)
    for ratio in thresholds:
        bench(
            count_in_range2,
            partition_with_error(arr, threshold=0.5, error_rate=ratio),
            min_value=0.5,
            max_value=0.5 - 1e-10,
            info=f"error: {ratio:.0%}",
        )


if __name__ == "__main__":
    main()

Result:

# ---------- As for comparison ----------
count_in_range: 3.518 ms, min_value: 0.5, 

# ---------- Random Data ----------
count_in_range2: 5.958 ms, min_value: 0.0, 
count_in_range2: 15.390 ms, min_value: 0.1, 
count_in_range2: 24.715 ms, min_value: 0.2, 
count_in_range2: 33.749 ms, min_value: 0.3, 
count_in_range2: 40.007 ms, min_value: 0.4, 
count_in_range2: 42.168 ms, min_value: 0.5, 
count_in_range2: 37.427 ms, min_value: 0.6, 
count_in_range2: 28.763 ms, min_value: 0.7, 
count_in_range2: 20.089 ms, min_value: 0.8, 
count_in_range2: 12.638 ms, min_value: 0.9, 
count_in_range2: 5.876 ms, min_value: 1.0, 

# ---------- Partitioned (Yet Still Random) Data ----------
count_in_range2: 6.006 ms, min_value: 0.0, 
count_in_range2: 5.999 ms, min_value: 0.1, 
count_in_range2: 5.953 ms, min_value: 0.2, 
count_in_range2: 5.952 ms, min_value: 0.3, 
count_in_range2: 5.940 ms, min_value: 0.4, 
count_in_range2: 6.870 ms, min_value: 0.5, 
count_in_range2: 5.939 ms, min_value: 0.6, 
count_in_range2: 5.896 ms, min_value: 0.7, 
count_in_range2: 5.899 ms, min_value: 0.8, 
count_in_range2: 5.880 ms, min_value: 0.9, 
count_in_range2: 5.884 ms, min_value: 1.0, 

# ---------- Partitioned Data with Probabilistic Errors ----------
# Note that min_value = 0.5 in all the following.
count_in_range2: 5.939 ms, min_value: 0.5, error: 0%
count_in_range2: 14.015 ms, min_value: 0.5, error: 10%
count_in_range2: 22.599 ms, min_value: 0.5, error: 20%
count_in_range2: 31.763 ms, min_value: 0.5, error: 30%
count_in_range2: 39.391 ms, min_value: 0.5, error: 40%
count_in_range2: 42.227 ms, min_value: 0.5, error: 50%
count_in_range2: 38.748 ms, min_value: 0.5, error: 60%
count_in_range2: 31.758 ms, min_value: 0.5, error: 70%
count_in_range2: 22.600 ms, min_value: 0.5, error: 80%
count_in_range2: 14.090 ms, min_value: 0.5, error: 90%
count_in_range2: 6.027 ms, min_value: 0.5, error: 100%

I am satisfied with this result.
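The shape of the last table can also be reproduced with a toy model of a branch predictor. This is an assumption for illustration only: real CPUs use far more elaborate predictors. The sketch simulates a single 2-bit saturating counter on the outcomes of 0.5 < a for partitioned data with a given error rate; the helpers two_bit_mispredict_rate and partitioned_outcomes are hypothetical names, and the outcome sequences are built directly instead of from actual values.

```python
import numpy as np


def two_bit_mispredict_rate(outcomes):
    """Misprediction rate of a single 2-bit saturating counter (a toy predictor).

    States 0-1 predict "not taken", states 2-3 predict "taken"; the initial
    state 2 (weakly taken) is an arbitrary choice.
    """
    state = 2
    misses = 0
    for taken in outcomes.tolist():  # plain Python bools
        if (state >= 2) != taken:
            misses += 1
        # Move one step toward the actual outcome, saturating at 0 and 3.
        state = min(state + 1, 3) if taken else max(state - 1, 0)
    return misses / len(outcomes)


def partitioned_outcomes(n, error_rate, rng):
    """Outcomes of "0.5 < a" on a half/half partitioned array in which a fraction
    error_rate of each half has been swapped into the other half, then each half
    is shuffled (mirroring partition_with_error above)."""
    half = n // 2
    n_err = int(half * error_rate)
    first = np.zeros(half, dtype=bool)  # mostly values below 0.5 -> mostly False
    first[:n_err] = True
    second = np.ones(half, dtype=bool)  # mostly values above 0.5 -> mostly True
    second[:n_err] = False
    rng.shuffle(first)
    rng.shuffle(second)
    return np.concatenate((first, second))


rng = np.random.default_rng(0)
for error_rate in (0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0):
    rate = two_bit_mispredict_rate(partitioned_outcomes(200_000, error_rate, rng))
    print(f"error: {error_rate:.0%}  mispredicted: {rate:.1%}")
```

In this model the misprediction rate is close to 0% at error rates of 0% and 100%, reaches about 50% at 50%, and is symmetric around it, matching the trend of the measured timings.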


Solution

  • TL;DR: Numba uses LLVM, which is not able to automatically vectorize the code when there is a break. One way to fix this is to compute the operation chunk by chunk.


    Numba is based on the LLVM compiler toolchain to compile the Python code to native code. Numba generates an LLVM intermediate representation (IR) from the Python code and then passes it to LLVM so that LLVM can generate fast native code. All the low-level optimizations are made by LLVM, not by Numba itself. In this case, LLVM is not able to automatically vectorize the code when there is a break. Numba does not do any pattern recognition here, nor does it run any code on the GPU (basic numba.njit code always runs on the CPU).

    Note that "vectorization" in this context means generating SIMD instructions from scalar IR code. The word has a different meaning in the context of NumPy code, where it means calling native functions so as to reduce the interpreter overhead; those native functions do not necessarily use SIMD instructions.
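    To illustrate the NumPy meaning of the word, here is a sketch of the existence check written with whole-array NumPy operations (any_in_range_numpy is a hypothetical name). Each operation runs in native code, but building the boolean mask always scans the whole array and allocates temporary arrays, so this form cannot stop at the first match either.

```python
import numpy as np


def any_in_range_numpy(arr, min_value, max_value):
    """Return True if any element satisfies min_value < a < max_value."""
    # Each comparison produces a full temporary boolean array; no early exit.
    mask = (min_value < arr) & (arr < max_value)
    return bool(mask.any())


arr = np.random.default_rng(0).random(1_000_000)
print(any_in_range_numpy(arr, 0.5, 0.5 - 1e-10))  # False: the range is empty
print(any_in_range_numpy(arr, 0.4, 0.6))          # True
```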


    Under the hood

    I reproduced the issue with Clang, a C++ compiler that also uses the LLVM toolchain. Here is the equivalent C++ code:

    #include <cstdint>
    #include <cstdlib>
    #include <vector>
    
    int64_t count_in_range(const std::vector<double>& arr, double min_value, double max_value)
    {
        int64_t count = 0;
    
        for(int64_t i=0 ; i<arr.size() ; ++i)
        {
            double a = arr[i];
    
            if (min_value < a && a < max_value)
            {
                count += 1;
            }
        }
    
        return count;
    }
    

    This code results in the following assembly main loop:

    .LBB0_6:                                # =>This Inner Loop Header: Depth=1
            vmovupd ymm8, ymmword ptr [rcx + 8*rax]
            vmovupd ymm9, ymmword ptr [rcx + 8*rax + 32]
            vmovupd ymm10, ymmword ptr [rcx + 8*rax + 64]
            vmovupd ymm11, ymmword ptr [rcx + 8*rax + 96]
            vcmpltpd        ymm12, ymm2, ymm8
            vcmpltpd        ymm13, ymm2, ymm9
            vcmpltpd        ymm14, ymm2, ymm10
            vcmpltpd        ymm15, ymm2, ymm11
            vcmpltpd        ymm8, ymm8, ymm4
            vandpd  ymm8, ymm12, ymm8
            vpsubq  ymm3, ymm3, ymm8
            vcmpltpd        ymm8, ymm9, ymm4
            vandpd  ymm8, ymm13, ymm8
            vpsubq  ymm5, ymm5, ymm8
            vcmpltpd        ymm8, ymm10, ymm4
            vandpd  ymm8, ymm14, ymm8
            vpsubq  ymm6, ymm6, ymm8
            vcmpltpd        ymm8, ymm11, ymm4
            vandpd  ymm8, ymm15, ymm8
            vpsubq  ymm7, ymm7, ymm8
            add     rax, 16
            cmp     rsi, rax
            jne     .LBB0_6
    

    The packed instructions vmovupd, vcmpltpd, vandpd, etc., operating on 256-bit ymm registers, show that the assembly code fully uses SIMD instructions.
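    One detail of this loop is worth spelling out: vcmpltpd sets every bit of a lane to 1 when the comparison is true, so that lane read as a 64-bit integer equals -1, and vpsubq subtracts it from an accumulator, which adds 1. The loop keeps four such accumulators (ymm3, ymm5, ymm6, ymm7) of four lanes each, which is why rax advances by 16. The NumPy sketch below only emulates this data flow (simd_style_count is a hypothetical name); it does not generate SIMD code.

```python
import numpy as np

LANES = 4                 # doubles per 256-bit ymm register
REGISTERS = 4             # independent accumulators in the unrolled loop
STEP = LANES * REGISTERS  # 16 items per iteration, like "add rax, 16"


def simd_style_count(arr, min_value, max_value):
    """Count elements with min_value < a < max_value, emulating the data flow
    of the vectorized loop (NumPy emulation only, not actual SIMD code)."""
    n_main = len(arr) - len(arr) % STEP
    acc = np.zeros((REGISTERS, LANES), dtype=np.int64)  # like ymm3, ymm5, ymm6, ymm7
    for i in range(0, n_main, STEP):
        block = arr[i:i + STEP].reshape(REGISTERS, LANES)  # four vmovupd loads
        hits = (min_value < block) & (block < max_value)    # vcmpltpd x2 + vandpd
        mask = -hits.astype(np.int64)  # a true lane is all ones, i.e. -1 as int64
        acc -= mask                    # vpsubq: subtracting -1 adds 1
    count = int(acc.sum())  # horizontal sum of the accumulators after the loop
    for a in arr[n_main:]:  # scalar epilogue for the last len(arr) % 16 items
        if min_value < a < max_value:
            count += 1
    return count


arr = np.random.default_rng(0).random(100_003)
expected = int(np.count_nonzero((0.25 < arr) & (arr < 0.75)))
print(simd_style_count(arr, 0.25, 0.75) == expected)  # True
```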

    If we add a break, this is no longer the case:

    .LBB0_4:                                # =>This Inner Loop Header: Depth=1
            vmovsd  xmm2, qword ptr [rcx + 8*rsi]   # xmm2 = mem[0],zero
            vcmpltpd        xmm3, xmm2, xmm1
            vcmpltpd        xmm2, xmm0, xmm2
            vandpd  xmm2, xmm2, xmm3
            vmovq   rdi, xmm2
            sub     rax, rdi
            test    dil, 1
            jne     .LBB0_2
            lea     rdi, [rsi + 1]
            cmp     rdx, rsi
            mov     rsi, rdi
            jne     .LBB0_4
    

    Here, vmovsd loads a single scalar value per iteration (and rsi is incremented by 1 per iteration). This latter code is significantly less efficient: it processes only one item per iteration, as opposed to 16 items for the previous code.

    We can use the compilation flag -Rpass-missed=loop-vectorize to check if the loop is indeed not vectorized. Clang explicitly reports:

    remark: loop not vectorized [-Rpass-missed=loop-vectorize]

    To know the reason, we can use the flag -Rpass-analysis=loop-vectorize:

    loop not vectorized: could not determine number of loop iterations [-Rpass-analysis=loop-vectorize]

    Thus, we can conclude that the LLVM optimizer does not support this code pattern.


    Solution

    One way to avoid this issue is to operate on chunks. The computation of each chunk can be fully vectorized by Clang, and you can still exit early after the first chunk that contains a matching value.

    Here is some untested code:

    @numba.njit
    def count_in_range_faster(arr, min_value, max_value):
        count = 0
        for i in range(0, arr.size, 16):
            if arr.size - i >= 16:
                # Optimized SIMD-friendly computation of 1 chunk of size 16
                tmp_view = arr[i:i+16]
                for j in range(0, 16):
                    if min_value < tmp_view[j] < max_value:
                        count += 1
                if count > 0:
                    return 1
            else:
                # Fallback implementation (variable-sized chunk)
                for j in range(i, arr.size):
                    if min_value < arr[j] < max_value:
                        count += 1
                if count > 0:
                    return 1
        return 0
    

    The equivalent C++ code is properly vectorized. One needs to check that this is also true for the Numba code using count_in_range_faster.inspect_llvm(), but the following timings show that the above implementation is faster than the two others.
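    To perform that check, one could scan the IR returned by inspect_llvm() (a Numba dispatcher method that, called without arguments, returns a dict mapping each compiled signature to its LLVM IR) for LLVM vector types such as <4 x double>. The helper uses_vector_doubles below is a hypothetical heuristic, assuming the returned IR reflects the optimized module: finding a vector type suggests, but does not prove, that the hot loop was vectorized, and the function must have been called (and thus compiled) at least once beforehand.

```python
import re

# LLVM vector-of-double types, e.g. "<4 x double>" (AVX) or "<2 x double>" (SSE2).
VECTOR_DOUBLE = re.compile(r"<\d+ x double>")


def uses_vector_doubles(dispatcher):
    """Heuristic: True if the LLVM IR of any compiled specialization mentions a
    vector-of-double type, False if none does, None if `dispatcher` is not a
    Numba dispatcher (i.e. has no inspect_llvm method)."""
    inspect_llvm = getattr(dispatcher, "inspect_llvm", None)
    if inspect_llvm is None:
        return None
    ir_by_signature = inspect_llvm()  # {signature: LLVM IR string}
    return any(VECTOR_DOUBLE.search(ir) for ir in ir_by_signature.values())
```

    For example, after count_in_range_faster and count_in_range2 have each been run once, their results from uses_vector_doubles could be compared.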


    Performance results

    Here are results on a machine with a Xeon W-2255 CPU using Numba 0.56.0:

    count_in_range:          7.112 ms
    count_in_range2:        35.317 ms
    count_in_range_faster:   5.827 ms     <----------
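    A variant of the same chunked idea can be sketched as follows; it is also unverified with respect to vectorization, and the name any_in_range_chunked and the CHUNK constant are hypothetical. It returns a boolean instead of 1/0, keeps the per-chunk loop free of any break so its trip count is known, and handles the remainder after the main loop instead of testing the chunk size on every iteration. It falls back to plain Python when Numba is not installed, so the logic can at least be checked for correctness.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # plain-Python fallback so the sketch still runs (slowly)
    def njit(func):
        return func

CHUNK = 16  # fixed chunk length, so the inner loop has a known trip count


@njit
def any_in_range_chunked(arr, min_value, max_value):
    n = arr.size
    n_main = n - n % CHUNK
    for i in range(0, n_main, CHUNK):
        count = 0
        # No break inside this loop: it is the part the compiler may vectorize.
        for j in range(i, i + CHUNK):
            if min_value < arr[j] < max_value:
                count += 1
        if count > 0:  # early exit, checked once per chunk
            return True
    # Remainder (fewer than CHUNK elements).
    for j in range(n_main, n):
        if min_value < arr[j] < max_value:
            return True
    return False


arr = np.random.default_rng(0).random(100_003)
print(any_in_range_chunked(arr, 0.5, 0.5 - 1e-10))  # False: the range is empty
print(any_in_range_chunked(arr, 0.4, 0.6))          # True
```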