Tags: python, arrays, list, performance, numba

Why is element access for typed lists so much slower than for arrays with Numba?


I wonder why accessing elements is so much slower for typed lists than for NumPy arrays when using Numba. I have the minimal example shown below. The indices are generated randomly to prevent any compiler optimization behind the scenes. It seems that, after correcting for the time it takes to generate all these random numbers, each element access is almost instant (<1 ns) in the case of NumPy arrays, whereas for the typed list each access takes up to 100 ns. One might expect typed lists to be a bit slower, but this seems like too large a difference to me, and it could significantly slow down code that accesses list elements a lot. Unfortunately, I am not a computer science expert, so I probably lack some basic background knowledge on how the access operation works for these two data structures. So, do you have any idea why there is such a significant difference in access speed?


import numpy as np
import numba as nb

@nb.njit
def only_rand(N):
    # Baseline: random index generation only, no element accesses.
    for _ in range(10000):
        i = np.random.randint(N)
        j = np.random.randint(N)

@nb.njit
def foo(pos, N):
    # Same index generation plus six element accesses per iteration;
    # note that dx, dy and dz are never read afterwards.
    for _ in range(10000):
        i = np.random.randint(N)
        j = np.random.randint(N)

        dx = pos[i][0] - pos[j][0]
        dy = pos[i][1] - pos[j][1]
        dz = pos[i][2] - pos[j][2]

N = 100

Array = np.random.rand(N,3)
List = nb.typed.List(Array)

print('Random number generation:')
%timeit only_rand(N)
print('Numpy Array:')
%timeit foo(Array, N)
print('Typed List:')
%timeit foo(List, N)

Out:

Random number generation:
133 µs ± 4.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Numpy Array:
132 µs ± 881 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Typed List:
947 µs ± 74 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Solution

  • One problem is that the benchmark is flawed. The Numba JIT compiler can (partially) see that your computation is mostly useless, since most of it has no visible effect: dx, dy and dz are never read, so their computation (e.g. pos[i][0] - pos[j][0]) can simply be dropped. At first glance the same seems to apply to i and j, but it does not: np.random.randint modifies an internal seed, which is a side effect. This side effect forces the compiler to still compute that part of the loop.

    However, the above point aside, the list-based implementation is indeed slower once the benchmark is fixed (see the sketch below). This comes from the reference counting of the temporary lists, and from the fact that the assembly code is less well optimized by the JIT (lists tend to generate more complex code that is harder to optimize).
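
    A minimal sketch of a fixed benchmark (foo_fixed is my name, not part of the original code): accumulating the differences into a value that is returned gives the element accesses a visible effect, so the JIT can no longer drop them:

      import numpy as np
      import numba as nb

      @nb.njit
      def foo_fixed(pos, N):
          # Accumulating into a returned value makes the accesses
          # observable, so dead-code elimination cannot remove them.
          acc = 0.0
          for _ in range(10000):
              i = np.random.randint(N)
              j = np.random.randint(N)
              acc += pos[i][0] - pos[j][0]
              acc += pos[i][1] - pos[j][1]
              acc += pos[i][2] - pos[j][2]
          return acc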


    In-depth analysis:

    To see that the JIT optimizes the accesses away, you can simply increase the value of N a lot: with a huge N, the random accesses would cause cache misses and become slower if they were actually performed. Here are timings on my machine:

    With N=100:
      Random number generation:
      109 µs ± 13.1 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
      Numpy Array:
      113 µs ± 22.7 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
      Typed List:
      806 µs ± 197 µs per loop (mean ± std. dev. of 7 runs, 250 loops each)
    
    With N=1_000_000:
      Random number generation:
      64.7 µs ± 13.3 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
      Numpy Array:
      68.6 µs ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
      Typed List:
      804 µs ± 215 µs per loop (mean ± std. dev. of 7 runs, 250 loops each)
    
    With N=10_000_000:
      Random number generation:
      185 µs ± 13.7 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
      Numpy Array:
      190 µs ± 23.2 µs per loop (mean ± std. dev. of 7 runs, 2000 loops each)
      Typed List:
      839 µs ± 200 µs per loop (mean ± std. dev. of 7 runs, 250 loops each)
    

    Note that the timings are not very dependent on N.

    The assembly code of the two implementations is pretty huge, but one can see that the main loops are similar and in both cases contain calls to numba_rnd_shuffle, which are not optimized away (due to the side effect of np.random.randint). Here is an example:

    .LBB0_20: <----------\
            cmpl    $624, %eax
            jae     .LBB0_21
    .LBB0_22: <----------\
            movl    %eax, %ecx
            movl    4(%rsi,%rcx,4), %ebp
            leal    1(%rax), %ecx
            movl    %ecx, (%rsi)
            movl    %ebp, %edx
            shrl    $11, %edx
            xorl    %ebp, %edx
            movl    %edx, %ebp
            shll    $7, %ebp
            andl    $-1658038656, %ebp
            xorl    %edx, %ebp
            movl    %ebp, %edx
            shll    $15, %edx
            andl    $-272236544, %edx
            xorl    %ebp, %edx
            movl    %edx, %ebp
            shrl    $18, %ebp
            xorl    %edx, %ebp
            andl    %edi, %ebp
            cmpl    $623, %eax
            jae     .LBB0_23
    .LBB0_24: <----------\
            movl    %ecx, %eax
            movl    4(%rsi,%rax,4), %eax
            incl    %ecx
            movl    %ecx, (%rsi)
            movl    %eax, %edx
            shrl    $11, %edx
            xorl    %eax, %edx
            movl    %edx, %eax
            shll    $7, %eax
            andl    $-1658038656, %eax
            xorl    %edx, %eax
            movl    %eax, %edx
            shll    $15, %edx
            andl    $-272236544, %edx
            xorl    %eax, %edx
            movl    %edx, %eax
            shrl    $18, %eax
            xorl    %edx, %eax
            shlq    $32, %rbp
            orq     %rax, %rbp
            movl    %ecx, %eax
            cmpq    %r14, %rbp
            jge     .LBB0_20 ---------->
            jmp     .LBB0_12
    .LBB0_21:
            movq    %rsi, %rcx
            movabsq $numba_rnd_shuffle, %rax
            callq   *%rax
            movl    $0, (%rsi)
            xorl    %eax, %eax
            jmp     .LBB0_22 ---------->
    .LBB0_23:
            movq    %rsi, %rcx
            movabsq $numba_rnd_shuffle, %rax
            callq   *%rax
            movl    $0, (%rsi)
            xorl    %ecx, %ecx
            jmp     .LBB0_24 ---------->
            .p2align        4, 0x90
    

    The thing is that, at the end of each iteration of the list version, the following assembly code is repeated 6 times (once for each of the six pos[i]/pos[j] accesses):

            movabsq $numba_list_size_address, %rdi
    
            movq    %r13, %rcx
            movabsq $NRT_incref, %rax
            callq   *%rax                  # NRT_incref(ptrVar);
    
            movq    %r15, %rcx
            callq   *%rdi                  # tmp1 = numba_list_size_address(listVar);
    
            movq    %rbp, %r12
            sarq    $63, %r12
            movq    (%rax), %rbx           # tmp2 = fancy_operation(*tmp1)
            andq    %r12, %rbx
            addq    %rbp, %rbx
            js      .LBB0_34               # Conditional goto to the end (overflow check?)
    
            movq    %r15, %rcx
            callq   *%rdi                  # tmp3 = numba_list_size_address(listVar);
    
            movq    %r15, %rdi
            movq    (%rax), %r15
            movq    %r13, %rcx
            movabsq $NRT_decref, %rax
            callq   *%rax                  # NRT_decref(ptrVar);
    
            cmpq    %r15, %rbx             # if(*tmp3 >= tmp2)
            jge     .LBB0_33               #     goto end;
    
            movq    %rdi, %rcx
            movabsq $numba_list_base_ptr, %rax
            callq   *%rax                  # numba_list_base_ptr(listVar);
    

    One can see the reference-counting calls as well as the list-related functions. This portion of the assembly code comes from the expressions pos[i] and pos[j]. The reference counting of the list objects is not optimized away by the JIT, and the same appears to hold for the related checks.
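
    For reference, you can dump the assembly Numba generates yourself with the standard dispatcher API; a generic inspection snippet (the exact output depends on your machine and Numba version):

      # Print the generated assembly for every compiled signature of foo.
      for sig, asm in foo.inspect_asm().items():
          print(sig)
          print(asm)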

    I guess this is because Numba function calls cannot be optimized in this case, or because the JIT assumes they are not expensive enough to be worth it. The code of the list-related functions can be found here. I find it weird that the JIT does not optimize the list-related function calls since they are marked as alwaysinline and readonly by Numba... Anyway, I think it is a missed optimization and it can be improved.

    I submitted an issue to the Numba developers here.
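
    In the meantime, a possible workaround sketch (my suggestion, assuming the typed list holds equally-sized 1D arrays as in the question): convert the list back to a contiguous 2D array once, outside the hot loop, so element accesses bypass the list machinery entirely:

      import numpy as np
      # Stack the rows of the typed list into one contiguous 2D array,
      # done once before calling the JIT-compiled function.
      pos_array = np.vstack(list(List))
      foo(pos_array, N)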