Search code examples
Tags: python, dictionary, numba

How can I pass a read-only dict as input to a Numba-compiled function?


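I want to compile `process_uids` below with `@nb.njit(parallel=True)`. The dict `XYtoBCID` is only read inside the function, and its keys are the bytes of a `uint16` `(x, y)` pair. I tried creating it with `nb.typed.Dict.empty(nb.types.uint8[:], nb.types.int32)`, but that does not work.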

import struct
import scipy as sp
import numpy as np
import numba as nb

def process_uids(uid, data, XYtoBCID):
    """Look up every position of ``uid`` in ``data`` and mark the mapped barcodes.

    Pure-Python version: it is not compiled with Numba because its keys are
    ``bytes`` objects. Keys are built with the same layout as the ones stored
    by ``main`` -- the native-endian bytes of a ``uint16`` ``(row, col)`` pair.

    Parameters
    ----------
    uid : scalar
        Label to search for in ``data``.
    data : numpy.ndarray
        2-D label array; every position where ``data == uid`` is looked up.
    XYtoBCID : dict[bytes, int]
        Maps position keys to 1-based barcode ids. Only read, never modified.

    Returns
    -------
    tuple
        ``(hitCnt, misCnt, uid_row)``: the number of positions found / not
        found in ``XYtoBCID``, and a boolean array of length ``data.size``
        with ``uid_row[bcid - 1]`` set for every barcode id that was hit.
    """
    rows, cols = np.where(data == uid)
    uid_row = np.zeros(data.size, dtype=bool)
    hitCnt, misCnt = 0, 0
    for ptX, ptY in zip(rows, cols):
        # bytes are immutable and tobytes() returns a copy, so no read-only
        # flag on the temporary array is needed to use it as a dict key.
        ptStr = np.array([ptX, ptY], dtype=np.uint16).tobytes()
        bcid = XYtoBCID.get(ptStr)  # single lookup instead of `in` + `[]`
        if bcid is None:
            misCnt += 1
        else:
            uid_row[bcid - 1] = True
            hitCnt += 1
    return hitCnt, misCnt, uid_row

def main():
    """Build a random label grid and position lookup, then print hit/miss counts.

    ``data`` is a 20x20 grid of labels 0-3; label 0 is treated as background
    and skipped. ``XYtoBCID`` maps the bytes of a ``uint16`` ``(row, col)``
    pair to a barcode id in 1..300 for 300 of the 400 positions (100 randomly
    chosen positions are zeroed and left out).
    """
    # Import the submodule explicitly: ``import scipy`` alone may not load
    # ``scipy.sparse`` depending on the SciPy version.
    from scipy import sparse

    np.random.seed(42)
    data = np.random.choice([0, 1, 2, 3], size=(20, 20))
    print(data)
    random_values = np.random.choice(np.arange(1, 301), size=(20, 20), replace=True)
    indices_to_replace = np.random.choice(range(20 * 20), size=100, replace=False)
    random_values_flat = random_values.flatten()
    random_values_flat[indices_to_replace] = 0
    random_values = random_values_flat.reshape((20, 20))

    # Key layout must match what process_uids builds for its lookups:
    # native-endian bytes of the uint16 (row, col) pair.
    XYtoBCID = {
        np.array(pos, dtype=np.uint16).tobytes(): value
        for pos, value in np.ndenumerate(random_values)
        if value > 0
    }

    values = np.unique(data)
    # Filter out background explicitly instead of dropping values[0], which
    # would skip a real label whenever data contains no zeros.
    uids = values[values != 0]
    mtxCellBarcode = sparse.lil_matrix((values[-1], data.size), dtype=bool)
    hitCnt, misCnt = 0, 0
    for uid in uids:
        hit, miss, uid_row = process_uids(uid, data, XYtoBCID)
        hitCnt += hit
        misCnt += miss
        mtxCellBarcode[uid - 1, :] = uid_row
    print(f'Hit:{hitCnt}, Miss:{misCnt}.')

# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Solution

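  • Here is a modified version that compiles with Numba. Instead of byte-string keys, it uses `nb.typed.Dict` with tuple keys of two `uint16` values (`nb.types.Tuple`). A tuple of integers is hashable, whereas an array type such as `uint8[:]` cannot serve as a typed-dict key: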

    import numba as nb
    import numpy as np
    import scipy as sp
    
    # Typed-dict key: a (row, col) position as a pair of uint16 values. A tuple
    # of integers is hashable, unlike an array key type such as uint8[:].
    dict_key_type = nb.types.UniTuple(nb.types.uint16, 2)
    # Typed-dict value: the barcode id stored for that position.
    dict_value_type = nb.types.int32
    
    
    @nb.njit(parallel=True)
    def process_uids(uid, data, XYtoBCID):
        """Look up every position of ``uid`` in ``XYtoBCID`` (Numba-compiled).

        Each ``(row, col)`` where ``data == uid`` is converted to a
        ``(uint16, uint16)`` tuple, matching the dict's key type, and looked
        up. The dict is only read here, never modified.

        Returns ``(hits, misses, uid_row)``: counts of positions found / not
        found, and a uint8 array of length ``data.size`` with
        ``uid_row[bcid - 1]`` set for every barcode id that was found.
        """
        where_uid = np.where(data == uid)
        rows = where_uid[0]
        cols = where_uid[1]
        uid_row = np.zeros(data.size, dtype=nb.types.uint8)
        hits = 0
        misses = 0
        # Iterations run in parallel; the `+= 1` counters are reduction
        # variables. Writes into uid_row only ever store the same value (True).
        for k in nb.prange(rows.size):
            key = (np.uint16(rows[k]), np.uint16(cols[k]))
            if key in XYtoBCID:
                bcid = XYtoBCID[key]
                uid_row[bcid - 1] = True
                hits += 1
            else:
                misses += 1
        return hits, misses, uid_row
    
    
    
    def work_numba(data, XYtoBCID):
        """Run ``process_uids`` for every non-background label in ``data``.

        Also fills a boolean ``(max_label, data.size)`` sparse matrix with one
        row per label. The matrix is built as part of the workload but is not
        returned.

        Returns ``(hitCnt, misCnt)`` summed over all labels.
        """
        # Import the submodule explicitly: ``import scipy`` alone may not load
        # ``scipy.sparse`` depending on the SciPy version.
        from scipy import sparse

        values = np.unique(data)
        # Label 0 is background. Filter it out instead of dropping values[0],
        # which would skip a real label whenever data contains no zeros.
        uids = values[values != 0]
        mtxCellBarcode = sparse.lil_matrix((values[-1], data.size), dtype=bool)
        hitCnt, misCnt = 0, 0
        for uid in uids:
            hit, miss, uid_row = process_uids(uid, data, XYtoBCID)
            hitCnt += hit
            misCnt += miss
            mtxCellBarcode[uid - 1, :] = uid_row
        return hitCnt, misCnt
    
    
    def main():
        """Build reproducible random inputs and print the hit/miss counts.

        ``data`` is a 20x20 grid of labels 0-3. ``XYtoBCID`` is a Numba typed
        dict mapping each ``(uint16 row, uint16 col)`` position to an int32
        barcode id in 1..300; 100 randomly chosen positions are left out.
        """
        np.random.seed(42)
        data = np.random.choice([0, 1, 2, 3], size=(20, 20))
        barcode_ids = np.random.choice(
            np.arange(1, 301), size=(20, 20), replace=True
        ).astype(np.int32)
        dropped = np.random.choice(range(20 * 20), size=100, replace=False)
        # Zero the dropped positions (flat indices); zero means "no barcode".
        np.put(barcode_ids, dropped, 0)

        XYtoBCID = nb.typed.Dict.empty(dict_key_type, dict_value_type)
        for (row, col), bcid in np.ndenumerate(barcode_ids):
            if bcid > 0:
                XYtoBCID[(np.uint16(row), np.uint16(col))] = bcid

        hitCnt, misCnt = work_numba(data, XYtoBCID)
        print(f"Hit:{hitCnt}, Miss:{misCnt}")
    
    
    # Run the benchmark only when executed as a script, not when imported.
    if __name__ == "__main__":
        main()
    

    Running this code prints:

    Hit:224, Miss:81