Tags: python, numpy, bitwise-operators, binary-operators

Fast Bitwise Sum in Python


Is there an efficient way to calculate the sum of bits in each column of a bit-packed array in Python?

Example (Python 3.7 and NumPy 1.20.1):

  1. Create a NumPy array with values 0 or 1:
import numpy as np

array = np.array(
    [
     [1, 0, 1],   
     [1, 1, 1], 
     [0, 0, 1],    
    ]
)
  2. Compress it with np.packbits:
pack_array = np.packbits(array, axis=1)
  3. Expected result: the sum of bits at each position (column), obtained without np.unpackbits and matching array.sum(axis=0):
array([2, 1, 3])
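
Note that np.packbits pads each row with zero bits up to a whole byte (most significant bit first), so here:

pack_array
# array([[160],
#        [224],
#        [ 32]], dtype=uint8)
# 160 = 0b10100000, 224 = 0b11100000, 32 = 0b00100000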

I have only found a very slow solution:

# Build one candidate per column: an identity matrix, packed row by row
dim = array.shape[1]
candidates = np.zeros((dim, dim), dtype=np.uint8)
np.fill_diagonal(candidates, 1)

pack_candidates = np.packbits(candidates, axis=1)

# For each packed candidate, count the rows of pack_array that contain its bit
np.apply_along_axis(lambda c: np.sum((np.bitwise_and(pack_array, c) == c).all(axis=1)), 1, pack_candidates)

Solution

  • Using np.unpackbits can be problematic if the input array is big, since the resulting array can be too large to fit in RAM, and even if it does fit, this is far from efficient: the huge array has to be written to and read from (slow) main memory. The same applies to CPU caches: smaller arrays can generally be computed faster. Moreover, np.unpackbits has a quite big overhead for small arrays.
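
    For reference, here is a minimal sketch of such an unpackbits-based approach (presumably what numpy_sumbits in the benchmarks below refers to; the exact code from @mathfux may differ):

    import numpy as np

    def numpy_sumbits(packed, m):
        # Unpack back to one value per bit, dropping the per-row padding,
        # then sum each column
        unpacked = np.unpackbits(packed, axis=1, count=m)
        return unpacked.sum(axis=0, dtype=np.int32)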

    AFAIK, it is not possible to do this operation very efficiently in NumPy while using a small amount of RAM (i.e. the efficient NumPy route is np.unpackbits, as pointed out by @mathfux). However, Numba can be used to speed up this computation, especially for small arrays. Here is the code:

    import numpy as np
    import numba as nb

    @nb.njit('int32[::1](uint8[:,::1], int_)')
    def bitSum(packed, m):
        # m is the number of original columns (bits per row before packing)
        n = packed.shape[0]
        assert packed.shape[1]*8-7 <= m <= packed.shape[1]*8
        res = np.zeros(m, dtype=np.int32)
        for i in range(n):
            for j in range(m):
                # Test bit j of row i (np.packbits stores bits MSB-first)
                res[j] += bool(packed[i, j//8] & (128>>(j%8)))
        return res
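
    For example, on the packed array from the question (with pack_array and array as defined above):

    m = array.shape[1]      # number of original columns (3 here)
    bitSum(pack_array, m)   # -> array([2, 1, 3], dtype=int32)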
    

    If you want a faster implementation, you can optimize the code by working on fixed-size tiles. However, this also makes the code more complex. Here is the resulting code:

    @nb.njit('int32[::1](uint8[:,::1], int_)')
    def bitSumOpt(packed, m):
        n = packed.shape[0]
        assert packed.shape[1]*8-7 <= m <= packed.shape[1]*8
        res = np.zeros(m, dtype=np.int32)
        for i in range(0, n, 4):
            for j in range(0, m, 8):
                if i+3 < n and j+7 < m:
                    # Highly-optimized 4x8 tile computation
                    k = j//8
                    b0, b1, b2, b3 = packed[i,k], packed[i+1,k], packed[i+2,k], packed[i+3,k]
                    for j2 in range(8):
                        shift = 7 - j2
                        mask = 1 << shift
                        res[j+j2] += ((b0 & mask) + (b1 & mask) + (b2 & mask) + (b3 & mask)) >> shift
                else:
                    # Slow fallback computation
                    for i2 in range(i, min(i+4, n)):
                        for j2 in range(j, min(j+8, m)):
                            res[j2] += bool(packed[i2, j2//8] & (128>>(j2%8)))
        return res
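
    Both versions can be checked against the plain NumPy column sum (a quick sanity check, assuming the definitions above):

    big = np.random.randint(0, 2, size=(100, 100))
    packed = np.packbits(big, axis=1)
    assert np.array_equal(bitSum(packed, big.shape[1]), big.sum(axis=0))
    assert np.array_equal(bitSumOpt(packed, big.shape[1]), big.sum(axis=0))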
    

    Here are performance results on my machine:

    On the example array:
    Initial code:    62.90 us   (x1)
    numpy_sumbits:    4.37 us   (x14)
    bitSumOpt:        0.84 us   (x75)
    bitSum:           0.77 us   (x82)
    
    On a random 2000x2000 array:
    Initial code:  1203.8  ms   (x1)
    numpy_sumbits:    3.9  ms   (x308)
    bitSum:           2.7  ms   (x446)
    bitSumOpt:        1.5  ms   (x802)
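
    These timings could be reproduced with a harness along these lines (a minimal sketch; the original measurements may have used a different setup):

    import timeit

    big = np.random.randint(0, 2, size=(2000, 2000))
    packed = np.packbits(big, axis=1)
    # The explicit njit signature makes compilation happen at definition time,
    # so the timing below measures execution only
    print(timeit.timeit(lambda: bitSumOpt(packed, big.shape[1]), number=10) / 10)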
    

    The memory footprint of the Numba implementations is also much better (at least 8 times smaller, since the bits are never unpacked to one value each).