Tags: python, numpy, numba, bit-packing

How do I optimise numpy.packbits with numba?


I'm trying to optimise numpy.packbits:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def _numba_pack(arr, div, su):
    for i in prange(div):
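        # Pack 8 bools into one byte, most-significant bit first.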
        s = 0
        for j in range(i*8, i*8+8):
            s = 2*s + arr[j]
        su[i] = s
        
def numba_packbits(arr):
    div, mod = np.divmod(arr.size, 8)
    su = np.zeros(div + (mod>0), dtype=np.uint8)
    _numba_pack(arr[:div*8], div, su)
    if mod > 0:
        su[-1] = sum(x*y for x,y in zip(arr[div*8:], (128, 64, 32, 16, 8, 4, 2, 1)))
    return su

>>> X = np.random.randint(2, size=99, dtype=bool)
>>> print(numba_packbits(X))
[ 75  24  79  61 209 189 203 187  47 226 170  61   0]
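
A quick correctness check against NumPy (np.packbits defaults to bitorder='big', which is the weighting used above):

>>> assert np.array_equal(numba_packbits(X), np.packbits(X))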

It appears to be 2 - 2.5 times slower than np.packbits(X). How is this implemented internally in numpy? Could this be improved with numba?

I'm working with numpy==1.21.2 and numba==0.53.1, installed via conda. My platform is:

[screenshot: platform details]

Results:

import benchit
from numpy import packbits
%matplotlib inline
benchit.setparams(rep=5)

sizes = [100000, 300000, 1000000, 3000000, 10000000, 30000000]
N = sizes[-1]
arr = np.random.randint(2, size=N, dtype=bool)
fns = [numba_packbits, packbits]

in_ = {s/1000000: (arr[:s], ) for s in sizes}
t = benchit.timings(fns, in_, multivar=True, input_name='Millions of bits')
t.plot(logx=True, figsize=(12, 6), fontsize=14)

[plot: timings of numba_packbits vs np.packbits, in millions of bits]

Update

Following Jérôme's answer:

@njit('void(bool_[::1], uint8[::1])', inline='never')
def _numba_pack_x64_byJérôme(arr, su):
    for i in range(64):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
       
@njit(parallel=True)
def _numba_pack_byJérôme(arr, div, su):
    for i in prange(div//64):
        _numba_pack_x64_byJérôme(arr[i*64*8:(i+1)*64*8], su[i*64:(i+1)*64])
    for i in range(div//64*64, div):
        j = i * 8
        su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
        
def numba_packbits_byJérôme(arr):
    div, mod = np.divmod(arr.size, 8)
    su = np.zeros(div + (mod>0), dtype=np.uint8)
    _numba_pack_byJérôme(arr[:div*8], div, su)
    if mod > 0:
        su[-1] = sum(x*y for x,y in zip(arr[div*8:], (128, 64, 32, 16, 8, 4, 2, 1)))
    return su

Usage:

>>> print(numba_packbits_byJérôme(X))
[ 75  24  79  61 209 189 203 187  47 226 170  61   0]

Results:

[plot: timings with numba_packbits_byJérôme added]


Solution

  • There are several issues with the Numba implementation. One of them is that parallel loops break the constant propagation optimization in LLVM-Lite (the JIT compiler used by Numba). This causes critical information like array strides not to be propagated, resulting in a slow scalar implementation instead of a SIMD one, plus additional unneeded instructions to compute the offsets. The same issue can also be seen in C code. Numpy added specific macros to help compilers automatically vectorize the code (i.e. use SIMD instructions) when the stride of the working dimension is actually 1.

    A solution to overcome the constant propagation issue is to call another Numba function. This function must not be inlined. Its signature should be provided manually so the compiler knows at compilation time that the stride of the array is 1 and can generate faster code. Finally, the function should work on fixed-size chunks, because function calls are expensive and fixed sizes let the compiler vectorize the code. Unrolling the loop with shifts also produces faster code (although it is uglier). Here is an example:

    @njit('void(bool_[::1], uint8[::1])', inline='never')
    def _numba_pack_x64(arr, su):
        for i in range(64):
            j = i * 8
            su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
    
    @njit('void(bool_[::1], int_, uint8[::1])', parallel=True)
    def _numba_pack(arr, div, su):
        for i in prange(div//64):
            # Each chunk covers 64 output bytes, i.e. 64*8 = 512 input bits.
            _numba_pack_x64(arr[i*64*8:(i+1)*64*8], su[i*64:(i+1)*64])
        for i in range(div//64*64, div):
            j = i * 8
            su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]
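
    The numba_packbits wrapper from the question can be reused unchanged to drive these two kernels:

    def numba_packbits(arr):
        div, mod = np.divmod(arr.size, 8)
        su = np.zeros(div + (mod > 0), dtype=np.uint8)
        _numba_pack(arr[:div*8], div, su)
        if mod > 0:
            su[-1] = sum(x*y for x,y in zip(arr[div*8:], (128, 64, 32, 16, 8, 4, 2, 1)))
        return su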
    

    Benchmark

    Here are performance results on my 6-core machine (i5-9600KF) with a billion random items as input:

    Initial Numba (seq):    189 ms  (x0.7)
    Initial Numba (par):    141 ms  (x1.0)
    Numpy (seq):             98 ms  (x1.4)
    Optimized Numba (par):   35 ms  (x4.0)
    Theoretical optimal:     27 ms  (x5.2)  [fully memory-bound case]
    

    This new implementation is 4 times faster than the initial parallel implementation and about 3 times faster than Numpy.
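
    As a rough sanity check of the memory-bound estimate (the ~42 GB/s bandwidth below is an assumption about this machine, not a measured figure): packing a billion bool bytes reads ~1 GB and writes ~0.125 GB, so:

    read_bytes = 1_000_000_000         # one bool byte per input bit
    write_bytes = 1_000_000_000 // 8   # one output byte per 8 input bits
    bandwidth = 42e9                   # assumed effective DRAM bandwidth (bytes/s)
    print(1000 * (read_bytes + write_bytes) / bandwidth)   # ~26.8 ms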


    Delving into the generated assembly code

    When parallel=False is set and prange is replaced with range, the following assembly code is generated on my Intel processor, which supports AVX2:

    .LBB0_7:
        vmovdqu 112(%rdx,%rax,8), %xmm1
        vmovdqa 384(%rsp), %xmm3
        vpshufb %xmm3, %xmm1, %xmm0
        vmovdqu 96(%rdx,%rax,8), %xmm2
        vpshufb %xmm3, %xmm2, %xmm3
        vpunpcklwd  %xmm0, %xmm3, %xmm3
        vmovdqu 80(%rdx,%rax,8), %xmm15
        vmovdqa 368(%rsp), %xmm5
        vpshufb %xmm5, %xmm15, %xmm4
        vmovdqu 64(%rdx,%rax,8), %xmm0
        [...] <------------------------------  ~180 other instructions discarded
        vpcmpeqb    %xmm3, %xmm11, %xmm2
        vpandn  %xmm8, %xmm2, %xmm2
        vpor    %xmm2, %xmm1, %xmm1
        vpcmpeqb    %xmm3, %xmm0, %xmm0
        vpaddb  %xmm0, %xmm1, %xmm0
        vpsubb  %xmm4, %xmm0, %xmm0
        vmovdqu %xmm0, (%r11,%rax)
        addq    $16, %rax
        cmpq    %rax, %rsi
        jne .LBB0_7
    

    The code is not very good: it uses many unneeded instructions (like SIMD comparison instructions, probably due to implicit casts from boolean types), a lot of registers are temporarily stored to memory (register spilling), and it uses 128-bit SIMD vectors instead of the 256-bit AVX ones supported by my machine. That being said, the code is vectorized and each loop iteration writes 16 bytes at once without any conditional branches (except the one of the loop), so the resulting performance is not so bad.
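
    For reference, this listing can be reproduced from Python with inspect_asm(), which Numba exposes on compiled dispatchers (a minimal sketch; the sequential variant below is the unrolled kernel with range instead of prange):

    from numba import njit
    import numpy as np

    @njit('void(bool_[::1], int_, uint8[::1])')
    def _numba_pack_seq(arr, div, su):
        for i in range(div):
            j = i * 8
            su[i] = (arr[j]<<7)|(arr[j+1]<<6)|(arr[j+2]<<5)|(arr[j+3]<<4)|(arr[j+4]<<3)|(arr[j+5]<<2)|(arr[j+6]<<1)|arr[j+7]

    # One assembly dump per compiled signature:
    for sig, asm in _numba_pack_seq.inspect_asm().items():
        print(asm)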

    In fact, the Numpy code is much smaller and more efficient. This is why it is about 2 times faster than the sequential Numba code on my machine with big inputs. Here is the hot assembly loop:

    4e8:
        mov      (%rdx,%rax,8),%rcx
        bswap    %rcx
        mov      %rcx,0x20(%rsp)
        mov      0x8(%rdx,%rax,8),%rcx
        add      $0x2,%rax
        movq     0x20(%rsp),%xmm0
        bswap    %rcx
        mov      %rcx,0x20(%rsp)
        movhps   0x20(%rsp),%xmm0
        pcmpeqb  %xmm1,%xmm0
        pcmpeqb  %xmm1,%xmm0
        pmovmskb %xmm0,%ecx
        mov      %cl,(%rsi)
        movzbl   %ch,%ecx
        mov      %cl,(%rsi,%r13,1)
        add      %r9,%rsi
        cmp      %rax,%r8
        jg       4e8
    

    It reads values in chunks of 8 bytes and computes them partially using 128-bit SSE instructions; 2 bytes are written per iteration. That being said, it is not optimal either, because 256-bit SIMD instructions are not used, and I think the code can be optimized further.
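
    To make the trick concrete, here is a rough Python rendering (an illustration of the idea, not NumPy's actual code): the bswap reverses the 8 bool bytes so that gathering each byte's most-significant bit, which is what pmovmskb does once the double pcmpeqb has turned every nonzero byte into 0xFF, directly yields the big-endian packed byte:

    word = X[:8].view(np.uint8)    # 8 bool bytes
    rev = word[::-1]               # the bswap: the last bool lands in lane 0
    mask = rev != 0                # the double pcmpeqb: 0xFF where nonzero
    byte = sum(int(b) << k for k, b in enumerate(mask))   # pmovmskb
    assert byte == np.packbits(X[:8])[0]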

    When the initial parallel code is used, here is the assembly code of the hot loop:

    .LBB3_4:
         movq %r9, %rax
         leaq (%r10,%r14), %r9
         movq %r15, %rsi
         sarq $63, %rsi
         andq %rdx, %rsi
         addq %r11, %rsi
         cmpb $0, (%r14,%rsi)
         setne     %cl
         addb %cl, %cl
         [...] <---------------  56 instructions (with few 256-bit AVX ones)
         orb  %bl, %cl
         orb  %al, %cl
         orb  %dl, %cl
         movq %rbp, %rdx
         movb %cl, (%r8,%r15)
         incq %r15
         decq %rdi
         addq $8, %r14
         cmpq $1, %rdi
         jg   .LBB3_4
    

    The above code is mostly not vectorized and is quite inefficient. It uses a lot of instructions (including quite slow ones like setne/cmovlq/cmpb to do many conditional stores) for each iteration, just to write 1 byte at a time. Numpy executes about 8 times fewer instructions for the same amount of written bytes. The inefficiency of this code is mitigated by the use of multiple threads; in the end, the parallel version can be a bit faster on machines with many cores (e.g. >= 6).

    The improved implementation provided at the beginning of this answer generates code similar to that of the sequential implementation above, but runs it on multiple threads (so still far from optimal, but much better).
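
    As a side note, the number of threads used by the parallel version can be pinned with Numba's set_num_threads if needed:

    from numba import set_num_threads
    set_num_threads(6)   # e.g. match the 6 cores of the benchmark machine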