
Direct access to bitarray from cython


I can access bitarray bits with indexing syntax:

from bitarray import bitarray

b = bitarray(10)
b[5]

How would I access an element directly?

This would be similar to the way I can access array.array elements directly through their data buffer:

ary.data.as_ints[5]

instead of:

ary[5]

I'm asking because, when I tried this with array, I got a 20-30x speedup in some scenarios.


I found what I need to access, but I don't know how.

In bitarray.h, look at getbit() and setbit().

How can I access them from Cython?


Current speeds:

Shape: (10000, 10000)
VSize: 100.00Mil
Mem: 12207.03kb, 11.92mb
                
----------------------
sa[5,5]=1
108 ns +- 0.451 ns per loop (mean +- std. dev. of 7 runs, 10000000 loops each)
sa[5,5]
146 ns +- 37.1 ns per loop (mean +- std. dev. of 7 runs, 10000000 loops each)
sa[100:120,100:120]
34.8 µs +- 7.39 µs per loop (mean +- std. dev. of 7 runs, 10000 loops each)
sa[:100,:100]
614 µs +- 135 µs per loop (mean +- std. dev. of 7 runs, 1000 loops each)
sa[[0,1,2],[0,1,2]]
1.11 µs +- 301 ns per loop (mean +- std. dev. of 7 runs, 1000000 loops each)
sa.sum()
6.74 ms +- 1.82 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa.sum(axis=0)
9.92 ms +- 2.49 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa.sum(axis=1)
646 ms +- 42.4 ms per loop (mean +- std. dev. of 7 runs, 1 loop each)
sa.mean()
5.17 ms +- 160 µs per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa.mean(axis=0)
12.8 ms +- 2.5 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa.mean(axis=1)
730 ms +- 25.1 ms per loop (mean +- std. dev. of 7 runs, 1 loop each)
sa[[9269, 5484, 2001, 8881, 30, 9567, 7654, 3034, 4901, 552],:],
6.87 ms +- 1.2 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa[:,[1417, 157, 9793, 1300, 2339, 2439, 2925, 3980, 4550, 5100]],
9.88 ms +- 1.56 ms per loop (mean +- std. dev. of 7 runs, 100 loops each)
sa[[9269, 5484, 2001, 8881, 30, 9567, 7654, 3034, 4901, 552],[1417, 157, 9793, 1300, 2339, 2439, 2925, 3980, 4550, 5100]],
6.59 µs +- 1.78 µs per loop (mean +- std. dev. of 7 runs, 100000 loops each)
sa[[9269, 5484, 2001, 8881, 30, 9567, 7654, 3034, 4901, 552],:].sum(axis=1),
466 ms +- 121 ms per loop (mean +- std. dev. of 7 runs, 1 loop each)
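The Mem figure above is consistent with the matrix storing exactly one bit per cell (10000 x 10000 cells, 8 bits per byte, 1024 bytes per KiB):

```latex
10000 \times 10000 = 10^{8}\ \text{bits}
  \;\Rightarrow\; \frac{10^{8}}{8} = 1.25 \times 10^{7}\ \text{bytes}
  \;\Rightarrow\; \frac{1.25 \times 10^{7}}{1024} \approx 12207.03\ \text{KiB}
  \;\Rightarrow\; \frac{12207.03}{1024} \approx 11.92\ \text{MiB}
```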

Solution

  • I'd recommend using typed memoryviews (which let you access the underlying data in 8-bit chunks) and then using bitwise-and operations to extract individual bits. That's definitely the easiest and most "native" way to do it in Cython.

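    cimport cython
    
    @cython.boundscheck(False)
    @cython.wraparound(False)
    @cython.cdivision(True)
    def sum_bits1(ba):
        # Byte view of the bitarray's buffer (bitarray supports the buffer protocol)
        cdef const unsigned char[::1] ba_view = ba
        cdef Py_ssize_t count = 0
        cdef Py_ssize_t n, idx
        cdef unsigned char mask
        for n in range(len(ba)):
            idx = n // 8                   # byte that holds bit n
            # Assumes the default big-endian bit order (bit 0 = MSB of byte 0);
            # for a little-endian bitarray use 1 << (n % 8) instead
            mask = 0x80 >> (n % 8)
            if ba_view[idx] & mask:
                count += 1
        return count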
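
    To make the byte/mask arithmetic concrete, here is a pure-Python model of it. It is not bitarray's own code, and the helper names bit_mask, get_bit, set_bit and count_bits are hypothetical. It assumes bitarray's default big-endian bit order, where bit n is selected by 0x80 >> (n % 8) in byte n // 8, while a little-endian bitarray uses 1 << (n % 8). Reading with the wrong order still gives the right total when you count every bit of whole bytes, which is why the mismatch is easy to miss.

```python
# Pure-Python model of the byte/mask arithmetic used by sum_bits1 and by getbit/setbit.
# Assumption: buf is the raw byte buffer of a bitarray with the given bit-endianness.


def bit_mask(n: int, big_endian: bool = True) -> int:
    """Mask selecting bit n inside its byte buf[n // 8]."""
    if big_endian:
        return 0x80 >> (n % 8)  # bit 0 is the most significant bit of byte 0
    return 1 << (n % 8)         # bit 0 is the least significant bit of byte 0


def get_bit(buf, n: int, big_endian: bool = True) -> int:
    """Return bit n (0 or 1) without any bounds check beyond Python's own."""
    return 1 if buf[n // 8] & bit_mask(n, big_endian) else 0


def set_bit(buf: bytearray, n: int, value: int, big_endian: bool = True) -> None:
    """Set bit n to 1 if value is truthy, otherwise clear it."""
    if value:
        buf[n // 8] |= bit_mask(n, big_endian)
    else:
        buf[n // 8] &= ~bit_mask(n, big_endian) & 0xFF


def count_bits(buf, nbits: int, big_endian: bool = True) -> int:
    """Count set bits among the first nbits bits, ignoring trailing padding bits."""
    return sum(get_bit(buf, n, big_endian) for n in range(nbits))


# "110010" packed big-endian, as bitarray("110010") stores it: 6 real bits + 2 padding bits
buf = bytearray([0b11001000])
set_bit(buf, 7, 1)                              # pretend a padding bit is set
print(count_bits(buf, 6))                       # 3 (padding bit ignored)
print(count_bits(buf, 6, big_endian=False))     # 2 (wrong bit order reads other bits)
```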
    

    If you want to use the getbit and setbit functions defined in "bitarray.h", declare them as cdef extern functions. You need to find the path to "bitarray.h"; it is probably somewhere in your local pip install directory. I've put the full path in the file, but a better solution is to specify an include path in setup.py.

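    cdef extern from "<path to home>/.local/lib/python3.8/site-packages/bitarray/bitarray.h":
        ctypedef struct bitarrayobject:
            pass # we don't need to know the details
        
        ctypedef class bitarray.bitarray [object bitarrayobject]:
            pass
        
        # Static inline functions in bitarray.h; the index is a Py_ssize_t.
        # They do not bounds-check, so keep 0 <= i < len(ba).
        int getbit(bitarray, Py_ssize_t)
        void setbit(bitarray, Py_ssize_t, int)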
            
    def sum_bits2(bitarray ba):
        cdef int count = 0
        cdef Py_ssize_t n
        for n in range(len(ba)):
            if getbit(ba, n):
                count += 1
        return count
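
    As an alternative to the hard-coded path, here is a minimal setup.py sketch of the include-path approach mentioned above. It assumes the Cython code lives in a hypothetical bitsum.pyx and that bitarray.h sits in the installed bitarray package directory (as in the path used above). With the include directory set, the extern block can simply say cdef extern from "bitarray.h":.

```python
# setup.py (sketch): the module name "bitsum" and source "bitsum.pyx" are hypothetical
import os

import bitarray
from Cython.Build import cythonize
from setuptools import Extension, setup

# Directory of the installed bitarray package, which is where bitarray.h was found
BITARRAY_DIR = os.path.dirname(bitarray.__file__)

extensions = [
    Extension(
        "bitsum",
        sources=["bitsum.pyx"],
        include_dirs=[BITARRAY_DIR],  # lets the .pyx use: cdef extern from "bitarray.h"
    )
]

setup(ext_modules=cythonize(extensions))
```

    Build it in place with python setup.py build_ext --inplace.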
    

    To test them (and compare against a simple pure-Python version):

    def sum_bits_naive(ba):
        count = 0
        for n in range(len(ba)):
            if ba[n]:
                count += 1
        return count
    
    def test_funcs():
        from bitarray import bitarray
        
        ba = bitarray("110010"*10000)
        print(sum_bits1(ba), sum_bits2(ba), sum_bits_naive(ba))
        from timeit import timeit
        globs = dict(globals())
        globs.update(locals())
        print(timeit("sum_bits1(ba)", globals=globs, number=1000))
        print(timeit("sum_bits2(ba)", globals=globs, number=1000))
        print(timeit("sum_bits_naive(ba)", globals=globs, number=1000))
    

    gives (timeit totals in seconds for 1000 calls each):

    30000 30000 30000
    0.069798200041987
    0.09307677199831232
    1.3518586970167235
    

    i.e. the memoryview version is the fastest, about 19x faster than the pure-Python loop, while the getbit version is about 14.5x faster.