I want to njit a function with numba in which pnt_group_ids_ can have one of two types: an int64 scalar (np.int64) or a 1-D C-contiguous int64 array (int64[::1]):
import numpy as np
import numba as nb
sorted_fb_i = np.array([1, 3, 4, 2, 5], np.int64)
fb_groups_ids = nb.typed.List([np.array([4, 2], np.int64), np.array([1, 3, 5], np.int64)])
moved_fb_group_ids = nb.typed.List.empty_list(nb.types.Array(dtype=nb.int64, ndim=1, layout="C"))
ind = 0
@nb.njit
def points_group_ids(sorted_fb_i, fb_groups_ids, moved_fb_group_ids, ind):
    pnt_group_ids_ = sorted_fb_i[ind]
    for i in range(len(fb_groups_ids)):
        if sorted_fb_i[ind] in fb_groups_ids[i]:
            pnt_group_ids_ = fb_groups_ids[i]
            moved_fb_group_ids.append(fb_groups_ids.pop(i))
            break
    return pnt_group_ids_, fb_groups_ids, moved_fb_group_ids

# compilation happens on the first call, which is where the error is raised
points_group_ids(sorted_fb_i, fb_groups_ids, moved_fb_group_ids, ind)
which raises the following error:
Cannot unify array(int64, 1d, C) and int64 for 'pnt_group_ids_.2'
Is there any way to write a signature for this that handles both types, something like:
((int64, int64[::1]), ListType(int64[::1]), ListType(int64[::1]))(int64[::1], ListType(int64[::1]), ListType(int64[::1]), int64)
If this cannot be handled with signatures, the relevant line can be replaced with:
pnt_group_ids_ = np.array([sorted_fb_i[ind]], np.int64)
That works. But how do I write the signature when there are multiple inputs and multiple outputs? If I use a signature written like the one above, with just one type per output, I now get the following error:
TypeError: 'tuple' object is not callable
This function will be called in a loop, so moved_fb_group_ids, which starts as an empty numba typed list (it has to be created with an explicit item type, otherwise numba raises an error), will gradually be filled while fb_groups_ids becomes empty. Will fb_groups_ids being empty cause an error?
The main goal of this question is how to write signatures for this function, with the input and output types together, when there are multiple inputs and multiple outputs (I know it is usually recommended to let numba infer them). If possible, I would prefer a signature that handles both types without changing the code.
Since a single number can be treated as an array with one element, a simple solution is to convert your single number into an array by slicing instead of indexing:
pnt_group_ids_ = sorted_fb_i[ind:ind+1]
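The difference between indexing and slicing can be checked in plain NumPy. This sketch uses the sample array from the question:

```python
import numpy as np

sorted_fb_i = np.array([1, 3, 4, 2, 5], np.int64)
ind = 0

scalar = sorted_fb_i[ind]            # integer index -> np.int64 scalar
one_elem = sorted_fb_i[ind:ind + 1]  # basic slice -> 1-element array

print(type(scalar), scalar)
print(one_elem, one_elem.shape, one_elem.flags["C_CONTIGUOUS"])
# The slice is a view into sorted_fb_i, not a copy
print(np.shares_memory(one_elem, sorted_fb_i))
```

Because the slice is a view, writing into the returned pnt_group_ids_ would also modify sorted_fb_i. The np.array([sorted_fb_i[ind]], np.int64) variant from the question makes a copy instead, which may be preferable if the caller mutates the result.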
@nb.njit("Tuple((int64[::1],ListType(int64[::1]),ListType(int64[::1])))(int64[::1], ListType(int64[::1]), ListType(int64[::1]), int64)")
def points_group_ids(sorted_fb_i, fb_groups_ids, moved_fb_group_ids, ind):
    pnt_group_ids_ = sorted_fb_i[ind:ind+1]
    for i in range(len(fb_groups_ids)):
        if sorted_fb_i[ind] in fb_groups_ids[i]:
            pnt_group_ids_ = fb_groups_ids[i]
            moved_fb_group_ids.append(fb_groups_ids.pop(i))
            break
    return pnt_group_ids_, fb_groups_ids, moved_fb_group_ids
and it works, though I wrote this without any context about what the function is for.