Tags: python, numpy, performance, cython

Speed up Cython implementation of dot product multiplication


I'm trying to learn Cython by attempting to outperform NumPy at the dot product operation np.dot(a, b). But my implementation is about 4x slower.

So, this is my Cython implementation in hello.pyx:

cimport numpy as cnp
cnp.import_array()

cpdef double dot_product(double[::1] vect1, double[::1] vect2):
    cdef int size = vect1.shape[0]
    cdef double result = 0
    cdef int i = 0
    while i < size:
        result += vect1[i] * vect2[i]
        i += 1
    return result
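For reference, here is a plain-Python sketch of the same loop (not part of the original post) that makes it easy to sanity-check the Cython result against np.dot; it is of course far slower than either:

```python
# Pure-Python reference implementation, for correctness checks only.
def dot_product_ref(vect1, vect2):
    assert len(vect1) == len(vect2)
    result = 0.0
    for x, y in zip(vect1, vect2):
        result += x * y
    return result

# Same inputs as the benchmark below, but with n = 10 for readability.
a = [float(i) for i in range(10)]
b = [x / 2 for x in a]
print(dot_product_ref(a, b))  # → 142.5 (i.e. 0.5 * sum of squares 0..9)
```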

This is my .py test file:

import timeit

setup = '''
import numpy as np
import hello

n = 10000
a = np.array([float(i) for i in range(n)])
b = np.array([i/2 for i in a])
'''
lf_code = 'res_lf = hello.dot_product(a, b)'
np_code = 'res_np = np.dot(a,b)'
n = 100
lf_time = timeit.timeit(lf_code, setup=setup, number=n) * 100
np_time = timeit.timeit(np_code, setup=setup, number=n) * 100

print(f'Lightning fast time: {lf_time}.')
print(f'Numpy time: {np_time}.')

Console output:

Lightning fast time: 0.12186000000156127.
Numpy time: 0.028800000001183435.
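A side note on the benchmark itself: timeit.timeit(..., number=n) returns the total time in seconds for n executions, so the `* 100` above just scales both totals by the same constant. That keeps the comparison fair, but the absolute numbers are hard to interpret. A sketch of converting to a per-call figure instead (the timed statement here is an arbitrary placeholder):

```python
import timeit

n = 100
# Total seconds for n runs of the statement.
total = timeit.timeit('sum(range(1000))', number=n)
# Per-call time: divide by the number of runs, then scale to microseconds.
per_call_us = total / n * 1e6
print(f'{per_call_us:.2f} us per call')
```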

Command to build hello.pyx:

python setup.py build_ext --inplace

setup.py file:

from distutils.core import Extension, setup
from Cython.Build import cythonize
import numpy as np

# define an extension that will be cythonized and compiled
ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()])
setup(ext_modules=cythonize(ext))

Processor: i7-7700T @ 2.90 GHz


Solution

  • The problem mainly comes from the lack of SIMD instructions (due to both the bounds checking and the inefficient default compiler flags) compared to Numpy (which uses OpenBLAS on most platforms by default).

    To fix that, you should first add the following line at the beginning of the hello.pyx file:

    #cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False
    

    Then, you should use this new setup.py file:

    from distutils.core import Extension, setup
    from Cython.Build import cythonize
    import numpy as np
    
    # define an extension that will be cythonized and compiled
    ext = Extension(name="hello", sources=["hello.pyx"], include_dirs=[np.get_include()], extra_compile_args=['-O3', '-mavx', '-ffast-math'])
    setup(ext_modules=cythonize(ext))
    

    Note that the flags depend on the compiler. That being said, both Clang and GCC support them (and probably ICC too). -O3 tells Clang and GCC to use more aggressive optimizations, such as automatic vectorization of the code. -mavx tells them to use the AVX instruction set (which is only available on relatively recent x86-64 processors). -ffast-math tells them to assume that floating-point operations are associative (which is not actually the case) and that you only use finite/ordinary numbers (no NaNs or infinities). If these assumptions are not fulfilled, the program can produce wrong results or even crash at runtime, so be careful with such flags.
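    The non-associativity that -ffast-math ignores is easy to demonstrate in plain Python (which uses the same IEEE-754 doubles the compiled loop operates on); reassociating a sum can change its result:

```python
# IEEE-754 addition is not associative, so a compiler that reorders
# additions (as -ffast-math permits) can produce a different result.
lhs = (0.1 + 0.2) + 0.3
rhs = 0.1 + (0.2 + 0.3)
print(lhs == rhs)  # → False
print(lhs, rhs)    # → 0.6000000000000001 0.6
```

    For a dot product this usually only perturbs the last bits of the result, which is why the reordered SIMD versions below still agree closely with np.dot.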

    Note that OpenBLAS automatically selects the instruction set based on your machine and AFAIK it does not use -ffast-math but a safer (low-level) alternative.


    Results:

    Here are results on my machine:

    Before optimization:
      Lightning fast time: 0.10018469997703505.
      Numpy time: 0.024747799989199848.
    
    After (with GCC):
      Lightning fast time: 0.02865879996534204.
      Numpy time: 0.02456870001878997.
    
    After (with Clang):
      Lightning fast time: 0.01965239998753532.
      Numpy time: 0.024799799984975834.
    

    The code produced by Clang is faster than Numpy on my machine.


    Under the hood

    An analysis of the assembly code executed by the processor on my machine shows that the code uses only slow scalar instructions, contains unnecessary bounds checks, and is mainly limited by the result += ... operation (because of a loop-carried dependency).

    162e3:┌─→movsd  xmm0,QWORD PTR [rbx+rax*8]  # Load 1 item
    162e8:│  mulsd  xmm0,QWORD PTR [rsi+rax*8]  # Load 1 item
    162ed:│  addsd  xmm1,xmm0                   # Main bottleneck (accumulation)
    162f1:│  cmp    rdi,rax
    162f4:│↓ je     163f8                       # Bound checking conditional jump
    162fa:│  cmp    rdx,rax
    162fd:│↓ je     16308                       # Bound checking conditional jump
    162ff:│  add    rax,0x1
    16303:├──cmp    rcx,rax
    16306:└──jne    162e3
    

    Once optimized, the result is:

    13720:┌─→vmovupd      ymm3,YMMWORD PTR [r13+rax*1+0x0]    # Load 4 items
    13727:│  vmulpd       ymm0,ymm3,YMMWORD PTR [rcx+rax*1]   # Load 4 items
    1372c:│  add          rax,0x20
    13730:│  vaddpd       ymm1,ymm1,ymm0        # Still a bottleneck (but better)
    13734:├──cmp          rdx,rax
    13737:└──jne          13720
    

    The result += ... operation is still the bottleneck in the optimized version, but this is much better since the loop works on 4 items at once. To remove the bottleneck, the loop must be partially unrolled. However, GCC (which is the default compiler on my machine) is not able to do that properly, even when asked to with -funroll-loops, due to the loop-carried dependency. This is why OpenBLAS should be a bit faster than the code produced by GCC.

    Fortunately, Clang is able to do that by default. Here is the code produced by Clang:

    59e0:┌─→vmovupd      ymm4,YMMWORD PTR [rax+rdi*8]       # load 4 items
    59e5:│  vmovupd      ymm5,YMMWORD PTR [rax+rdi*8+0x20]  # load 4 items
    59eb:│  vmovupd      ymm6,YMMWORD PTR [rax+rdi*8+0x40]  # load 4 items
    59f1:│  vmovupd      ymm7,YMMWORD PTR [rax+rdi*8+0x60]  # load 4 items
    59f7:│  vmulpd       ymm4,ymm4,YMMWORD PTR [rbx+rdi*8]
    59fc:│  vaddpd       ymm0,ymm4,ymm0
    5a00:│  vmulpd       ymm4,ymm5,YMMWORD PTR [rbx+rdi*8+0x20]
    5a06:│  vaddpd       ymm1,ymm4,ymm1
    5a0a:│  vmulpd       ymm4,ymm6,YMMWORD PTR [rbx+rdi*8+0x40]
    5a10:│  vmulpd       ymm5,ymm7,YMMWORD PTR [rbx+rdi*8+0x60]
    5a16:│  vaddpd       ymm2,ymm4,ymm2
    5a1a:│  vaddpd       ymm3,ymm5,ymm3
    5a1e:│  add          rdi,0x10
    5a22:├──cmp          rsi,rdi
    5a25:└──jne          59e0
    

    The code is still not optimal (it should unroll the loop at least 6 times, given the latency of the vaddpd instruction), but it is very good.
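    The effect of those four independent accumulators (ymm0..ymm3 in the Clang listing) can be sketched in plain Python. This is a toy illustration of how the dependency chain is split, not the compiled code; in interpreted Python the unrolling brings no speedup:

```python
import math

def dot_unrolled4(a, b):
    # Four independent partial sums: each accumulation only depends on
    # its own previous value, so a CPU can keep 4 add chains in flight.
    s0 = s1 = s2 = s3 = 0.0
    n = len(a) - len(a) % 4
    for i in range(0, n, 4):
        s0 += a[i]     * b[i]
        s1 += a[i + 1] * b[i + 1]
        s2 += a[i + 2] * b[i + 2]
        s3 += a[i + 3] * b[i + 3]
    for i in range(n, len(a)):  # scalar remainder loop
        s0 += a[i] * b[i]
    return (s0 + s1) + (s2 + s3)

a = [float(i) for i in range(10)]
b = [x / 2 for x in a]
print(math.isclose(dot_unrolled4(a, b), 142.5))  # → True
```

    Note the final combine reorders the additions relative to a sequential loop, which is exactly the kind of reassociation -ffast-math licenses the compiler to perform.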