I tried `typed.Dict.empty(types.uint8[:], types.int32)`, but it does not work.
import struct
import scipy as sp
import numpy as np
import numba as nb
def process_uids(uid, data, XYtoBCID):
    """Look up every position of `uid` in `data` in a bytes-keyed table.

    This is the plain-Python variant: the keys are ``bytes``, looked up in an
    ordinary ``dict``, so it is not numba-compiled.

    Parameters
    ----------
    uid : scalar
        Label to search for in `data`.
    data : ndarray
        Label image. Only the first two coordinate arrays returned by
        ``np.where`` are used as the position, so presumably 2-D.
    XYtoBCID : dict
        Maps the native-endian bytes of a ``uint16`` ``(row, col)`` pair to
        an integer id that is used as a 1-based index into the result row.

    Returns
    -------
    hitCnt, misCnt : int
        Number of matching positions found / not found in `XYtoBCID`.
    uid_row : ndarray of bool, shape (data.size,)
        True at ``bcid - 1`` for every id that was hit.
    """
    indices = np.where(data == uid)
    # Build every key with one uint16 cast instead of a temporary array per
    # position. Row k of `coords` has exactly the bytes that
    # np.array([row, col], dtype=np.uint16).tobytes() would produce.
    coords = np.column_stack((indices[0], indices[1])).astype(np.uint16)
    uid_row = np.zeros(data.size, dtype=bool)
    hitCnt, misCnt = 0, 0
    for point in coords:
        bcid = XYtoBCID.get(point.tobytes())  # single dict probe
        if bcid is None:
            misCnt += 1
        else:
            uid_row[bcid - 1] = True
            hitCnt += 1
    return hitCnt, misCnt, uid_row
def main():
    """Pure-Python demo: look up label positions in a bytes-keyed table.

    Builds a 20x20 label image `data` (labels 0-3) and a 20x20 grid of ids
    drawn from 1..300 with 100 distinct cells set to 0. Every non-zero grid
    cell is stored in `XYtoBCID`, keyed by the bytes of its uint16
    ``(row, col)``. The function then prints the total hit and miss counts
    of ``process_uids`` over every positive label.
    """
    # Local import: `import scipy` alone is not guaranteed to expose the
    # `sparse` subpackage (older SciPy releases do not load it lazily).
    from scipy import sparse

    np.random.seed(42)
    data = np.random.choice([0, 1, 2, 3], size=(20, 20))
    print(data)
    random_values = np.random.choice(np.arange(1, 301), size=(20, 20), replace=True)
    indices_to_replace = np.random.choice(range(20 * 20), size=100, replace=False)
    # Zero out the chosen cells in place (same result as flatten/assign/reshape).
    random_values.flat[indices_to_replace] = 0

    # Key: native-endian bytes of the uint16 (row, col) pair.
    XYtoBCID = {}
    for (posX, posY), value in np.ndenumerate(random_values):
        if value > 0:
            XYtoBCID[np.array([posX, posY], dtype=np.uint16).tobytes()] = value

    # Select positive labels explicitly. Dropping values[0] instead would
    # silently skip the smallest real label whenever 0 is absent.
    uids = np.unique(data)
    uids = uids[uids > 0]
    n_rows = int(uids[-1]) if uids.size else 0
    mtxCellBarcode = sparse.lil_matrix((n_rows, data.size), dtype=bool)
    hitCnt, misCnt = 0, 0
    for uid in uids:
        hit, miss, uid_row = process_uids(uid, data, XYtoBCID)
        hitCnt += hit
        misCnt += miss
        mtxCellBarcode[uid - 1, :] = uid_row  # label uid -> row uid - 1
    print(f'Hit:{hitCnt}, Miss:{misCnt}.')


if __name__ == "__main__":
    main()
Here is a modified version that compiles with numba. It uses an `nb.types.Tuple` of two `uint16` coordinates as the key type of an `nb.typed.Dict`:
import numba as nb
import numpy as np
import scipy as sp
# numba types for the position table: keys are (row, col) pairs of uint16
# coordinates, values are int32. A homogeneous pair is spelled as UniTuple,
# which is the type nb.types.Tuple((uint16, uint16)) resolves to anyway.
dict_key_type = nb.types.UniTuple(nb.types.uint16, 2)
dict_value_type = nb.types.int32
@nb.njit(parallel=True)
def process_uids(uid, data, XYtoBCID):
    """Look up every position of `uid` in `data` in a tuple-keyed typed Dict.

    Compiled with numba in parallel mode, so the ``prange`` loop below may
    run its iterations on several threads.

    Parameters
    ----------
    uid : scalar
        Label to search for in `data`.
    data : ndarray
        Label image. Only ``indices[0]`` and ``indices[1]`` from
        ``np.where`` are used, so presumably 2-D (TODO confirm for other
        ranks).
    XYtoBCID : numba.typed.Dict
        As built by ``main``: keys are ``(uint16, uint16)`` tuples
        (``dict_key_type``), and values are int32 ids (``dict_value_type``)
        used as 1-based indices into the result row.

    Returns
    -------
    hitCnt, misCnt : int
        Number of positions found / not found in `XYtoBCID`.
    uid_row : ndarray of uint8, shape (data.size,)
        Set to 1 at ``bcid - 1`` for every id that was hit.
    """
    indices = np.where(data == uid)
    # NOTE(review): uint8 here, whereas the pure-Python variant returns a
    # bool array. The caller stores this row in a bool lil_matrix, so
    # confirm the dtype difference is intended.
    uid_row = np.zeros(data.size, dtype=nb.types.uint8)
    # Updated with `+=` inside prange, which numba treats as reductions.
    hitCnt, misCnt = 0, 0
    for idx in nb.prange(indices[0].size):
        # Cast to uint16 so the key matches the Dict's declared key type.
        t = (np.uint16(indices[0][idx]), np.uint16(indices[1][idx]))
        if t in XYtoBCID:
            bcid = XYtoBCID[t]
            # Parallel iterations may hit the same bcid; every write stores
            # the same value, so the overlap does not change the result.
            uid_row[bcid - 1] = True
            hitCnt += 1
        else:
            misCnt += 1
    return hitCnt, misCnt, uid_row
def work_numba(data, XYtoBCID):
    """Run ``process_uids`` for every positive label in `data` and total the counts.

    Each label's ``uid_row`` is stored as row ``uid - 1`` of a boolean
    ``lil_matrix`` of shape ``(max label, data.size)``. The matrix is built
    as part of the work but is not returned.

    Parameters
    ----------
    data : ndarray
        Label image. Label 0 (and any non-positive value) is skipped.
    XYtoBCID : numba.typed.Dict
        Passed through unchanged to ``process_uids``.

    Returns
    -------
    (hitCnt, misCnt) : tuple of int
        Sums of the per-label hit and miss counts. Returns ``(0, 0)`` when
        `data` has no positive labels, including when `data` is empty.
    """
    # Local import: `import scipy` alone is not guaranteed to expose the
    # `sparse` subpackage (older SciPy releases do not load it lazily).
    from scipy import sparse

    uids = np.unique(data)
    # Select positive labels explicitly instead of dropping values[0]: the
    # slice silently skipped the smallest real label whenever 0 was absent,
    # and values[-1] raised IndexError on empty input.
    uids = uids[uids > 0]
    n_rows = int(uids[-1]) if uids.size else 0
    mtxCellBarcode = sparse.lil_matrix((n_rows, data.size), dtype=bool)
    hitCnt, misCnt = 0, 0
    for uid in uids:
        hit, miss, uid_row = process_uids(uid, data, XYtoBCID)
        hitCnt += hit
        misCnt += miss
        mtxCellBarcode[uid - 1, :] = uid_row  # label uid -> row uid - 1
    return hitCnt, misCnt
def main():
    """Numba demo: fill a typed Dict keyed by (row, col) and report lookups.

    Uses the same seeded random inputs as the pure-Python variant: a 20x20
    label image and a 20x20 int32 grid of ids in 1..300 with 100 cells
    zeroed. Every non-zero grid cell goes into a ``numba.typed.Dict`` keyed
    by its ``(uint16 row, uint16 col)``, then the hit/miss totals from
    ``work_numba`` are printed.
    """
    np.random.seed(42)
    grid_shape = (20, 20)
    data = np.random.choice([0, 1, 2, 3], size=grid_shape)
    random_values = np.random.choice(
        np.arange(1, 301), size=grid_shape, replace=True
    ).astype(np.int32)
    indices_to_replace = np.random.choice(range(20 * 20), size=100, replace=False)
    # Zero the selected cells in place; same values as flatten/assign/reshape.
    random_values.flat[indices_to_replace] = 0

    XYtoBCID = nb.typed.Dict.empty(dict_key_type, dict_value_type)
    # ndenumerate walks in C order, matching the flattened iteration order.
    for (posX, posY), value in np.ndenumerate(random_values):
        if value > 0:
            XYtoBCID[(np.uint16(posX), np.uint16(posY))] = value

    hitCnt, misCnt = work_numba(data, XYtoBCID)
    print(f"Hit:{hitCnt}, Miss:{misCnt}")


if __name__ == "__main__":
    main()
Running this code prints:

    Hit:224, Miss:81