This is my first time posting here. I'm trying to delete a row from a NumPy array inside a numba jitclass. I wrote the following code to find any row containing 3 so that I can remove it:
>>> a = np.array([[1,2,3,4],[5,6,7,8]])
>>> a
array([[1, 2, 3, 4],
       [5, 6, 7, 8]])
>>> i = np.where(a==3)
>>> i
(array([0]), array([2]))
I cannot use numpy.delete() since it is not supported by numba, and I cannot assign None to the row either. All I could do is assign 0s to the row:
>>> a[i[0]] = 0
>>> a
array([[0, 0, 0, 0],
       [5, 6, 7, 8]])
But I want to remove the row completely.
Any help will be appreciated.
Thank you very much.
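For reference, outside of numba the intended result is exactly what plain NumPy's np.delete produces with axis=0. The short sketch below is meant for regular, non-jitted Python only: it takes the row indices from np.where (the first array of the returned tuple) and passes them through np.unique so each row index appears only once.

```python
import numpy as np

a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])

# On a 2-D array, np.where(condition) returns (row_indices, col_indices)
rows, cols = np.where(a == 3)

# Drop every row that contains at least one 3 (plain NumPy, not numba)
result = np.delete(a, np.unique(rows), axis=0)
print(result)  # [[5 6 7 8]]
```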
This is in fact not an easy task, since numba has the following restrictions:
- np.delete is not supported
- the axis keyword in np.all and np.any is not supported
- np.zeros(shape, dtype=np.bool) and similar functions are not supported

But there are still several approaches you can take to solve your problem. I tested a few, and creating a boolean mask seems to be the fastest and cleanest way.
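Because the axis keyword of np.any is on that list, the usual np.any(arr == num, axis=1) row check has to be written by hand inside numba. A minimal sketch with explicit loops, assuming a 2-D array; rows_containing is a hypothetical helper name, and the try/except only supplies a do-nothing njit so the example also runs where numba isn't installed:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback: run the function un-jitted if numba is not available
    def njit(func):
        return func


@njit
def rows_containing(arr, num):
    """Hand-written equivalent of np.any(arr == num, axis=1) for a 2-D array."""
    n_rows, n_cols = arr.shape
    # All-False boolean array, built without passing a bool dtype to np.zeros
    flags = np.zeros(n_rows, dtype=np.int64) != 0
    for i in range(n_rows):
        for j in range(n_cols):
            if arr[i, j] == num:
                flags[i] = True
                break  # one match is enough for this row
    return flags


a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
print(rows_containing(a, 3))      # [ True False]
print(a[~rows_containing(a, 3)])  # [[5 6 7 8]]
```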
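import numpy as np
import numba as nb

@nb.njit
def delete_workaround(arr, num):
    # All-True mask, one entry per row (int64 workaround for the bool-dtype restriction)
    mask = np.zeros(arr.shape[0], dtype=np.int64) == 0
    # np.where(...)[0] holds the row index of every match; mark those rows False
    mask[np.where(arr == num)[0]] = False
    # Boolean indexing keeps only the rows still marked True
    return arr[mask]

a = np.array([[1,2,3,4],[5,6,7,8]])
delete_workaround(a, 3)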
This solution also has the huge advantage of preserving your array dimensions: the result is always 2-D, even when only one row or an empty array is returned. This is important for jitclasses, since jitclass attributes are declared with fixed types, including the number of dimensions.
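The same keep-or-drop idea can also be spelled out entirely with explicit loops, which makes the shape guarantee visible: the output is preallocated as (kept_rows, arr.shape[1]), so it stays 2-D even when no rows survive. This is only a sketch, not benchmarked against the mask version; delete_rows_two_pass is a hypothetical name, and the try/except fallback exists only so it also runs without numba:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback: run the function un-jitted if numba is not available
    def njit(func):
        return func


@njit
def delete_rows_two_pass(arr, num):
    n_rows, n_cols = arr.shape
    # Pass 1: flag the rows that do not contain num and count them
    keep = np.zeros(n_rows, dtype=np.int64)
    n_keep = 0
    for i in range(n_rows):
        hit = False
        for j in range(n_cols):
            if arr[i, j] == num:
                hit = True
                break
        if not hit:
            keep[i] = 1
            n_keep += 1
    # Pass 2: copy the kept rows into a preallocated 2-D result
    out = np.empty((n_keep, n_cols), dtype=arr.dtype)
    k = 0
    for i in range(n_rows):
        if keep[i] == 1:
            out[k, :] = arr[i, :]
            k += 1
    return out


a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
print(delete_rows_two_pass(a, 3))                                 # [[5 6 7 8]]
print(delete_rows_two_pass(np.array([[3, 1], [2, 3]]), 3).shape)  # (0, 2)
```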
Since you requested it, here is a solution that converts the array to a list and back. Since reflected lists are not yet supported with all Python methods in numba, you'll have to use a (non-jitted) wrapper for some parts of the function:
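@nb.njit
def delete_lrow(arr_list, num):
    # Collect the indices of all rows that do NOT contain num
    idx_list = []
    for i in range(len(arr_list)):
        if (arr_list[i] != num).all():
            idx_list.append(i)
    # Build a new list holding only the kept rows
    res_list = [arr_list[i] for i in idx_list]
    return res_list

def wrap_list_del(arr, num):
    # Plain Python wrapper: split the 2-D array into a list of 1-D rows,
    # filter them in numba, then stack the result back into an array
    arr_list = list(arr)
    return np.array(delete_lrow(arr_list, num))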
arr = np.array([[1,2,3,4],[5,6,7,8],[10,11,5,13],[10,11,3,13],[10,11,99,13]])
arr2 = np.random.randint(0, 256, 100000*4).reshape(-1, 4)
%timeit delete_workaround(arr, 3)
# 1.36 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
%timeit wrap_list_del(arr, 3)
# 69.3 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit delete_workaround(arr2, 3)
# 1.9 ms ± 68.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit wrap_list_del(arr2, 3)
# 1.05 s ± 103 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
So sticking with arrays if you already have arrays (and even if you don't, as long as your data has a consistent type) is about 50 times faster for the small array and about 550 times faster for the larger one (100,000 rows).
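The speedup factors follow directly from the %timeit means reported above. A quick check of the arithmetic, using only the means and ignoring the standard deviations:

```python
# %timeit means from the benchmark output, converted to microseconds
small_mask, small_list = 1.36, 69.3    # 5-row array
large_mask, large_list = 1.9e3, 1.05e6  # 100,000-row array

print(round(small_list / small_mask))  # 51
print(round(large_list / large_mask))  # 553
```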
This is something to remember: NumPy arrays are there for working with numerical data, and NumPy is heavily optimized for exactly that! There is absolutely no use in converting arrays of numerical data to another "format" if the data type (dtype) is constant and nothing special requires it (I've barely ever encountered such a situation).
And this is especially true for numba-optimized code! Numba relies heavily on NumPy and on constant dtypes/shapes etc., even more so if you want to work with jitclasses.