
Delete a row in numpy.array in numba


This is my first post here. I'm trying to delete a row from a numpy array inside a numba jitclass. As a first step, I wrote the following code to locate any row containing 3:

>>> import numpy as np
>>> a = np.array([[1,2,3,4],[5,6,7,8]])
>>> a
array([[1, 2, 3, 4],
       [5, 6, 7, 8]])
>>> i = np.where(a==3)
>>> i
(array([0]), array([2]))

I cannot use the numpy.delete() function since it is not supported by numba, and I cannot assign a None value to the row. All I could do is set the row to 0's:

>>> a[i[0]] = 0
>>> a
array([[0, 0, 0, 0],
       [5, 6, 7, 8]])

But I want to remove the row completely.

Any help will be appreciated.

Thank you very much.


Solution

  • This is in fact not an easy task, since numba has the following restrictions:

    • no support for np.delete
    • no support for the axis keyword in np.all and np.any
    • no support for 2D array indexing (at least not with bool masks)
    • no or hampered direct creation of bool masks with np.zeros(shape, dtype=np.bool) or similar functions

    But still there are several approaches you can take to solve your problem. I tested a few and creating a boolean mask seems to be the fastest and cleanest way.

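    import numba as nb
    import numpy as np
    
    @nb.njit
    def delete_workaround(arr, num):
        # All-True mask with one entry per row (built via an int comparison)
        mask = np.zeros(arr.shape[0], dtype=np.int64) == 0
        # Mark every row that contains `num` for removal
        mask[np.where(arr == num)[0]] = False
        # Boolean indexing along the first axis keeps the result 2D
        return arr[mask]
    
    a = np.array([[1,2,3,4],[5,6,7,8]])
    
    delete_workaround(a, 3)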
    

    This solution also has the big advantage of preserving the array's dimensionality: the result stays 2D even when only one row, or no row at all, is left. This matters for jitclasses, whose attribute types (including the number of array dimensions) are fixed when the class is declared.
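
    To illustrate the point about dimensions, here is a minimal sketch of a loop-based variant. The name `drop_rows_with` is hypothetical and not part of the approach timed in this answer; it swaps `np.where` for an explicit row scan, which is one way around the missing `axis` keyword in `np.any`. The `try`/`except` fallback is only an assumption so the snippet still runs as plain Python where numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Assumption: without numba installed, run the same code as plain Python.
    def njit(func):
        return func


@njit
def drop_rows_with(arr, num):
    """Return a copy of the 2D array `arr` without the rows that contain `num`."""
    n_rows, n_cols = arr.shape
    # All-True mask built via an integer comparison, as in the answer above.
    keep = np.zeros(n_rows, dtype=np.int64) == 0
    for r in range(n_rows):
        for c in range(n_cols):
            if arr[r, c] == num:
                keep[r] = False
                break  # one hit is enough to drop the row
    return arr[keep]


a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
one_row = drop_rows_with(a, 3)                          # one row survives
no_rows = drop_rows_with(np.array([[3, 3, 3, 3]]), 3)   # nothing survives
print(one_row.shape, no_rows.shape)  # (1, 4) (0, 4)
```

    Both results are still 2D, so they can be assigned back to a variable or attribute that expects a 2D array.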

    Since you requested it, I'll also show a solution that converts arrays to lists and back. Since reflected lists are not yet supported with all Python methods in numba, you'll have to use a wrapper for some parts of the function:

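    @nb.njit
    def delete_lrow(arr_list, num):
        # Collect the indices of rows that do not contain `num`
        idx_list = []
        for i in range(len(arr_list)):
            if (arr_list[i] != num).all():
                idx_list.append(i)
        # Build the filtered list of row arrays
        res_list = [arr_list[i] for i in idx_list]
        return res_list
    
    def wrap_list_del(arr, num):
        # Plain-Python wrapper: array -> list of rows -> jitted filter -> array
        arr_list = list(arr)
        return np.array(delete_lrow(arr_list, num))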
    
    arr = np.array([[1,2,3,4],[5,6,7,8],[10,11,5,13],[10,11,3,13],[10,11,99,13]])
    arr2 = np.random.randint(0, 256, 100000*4).reshape(-1, 4)
    
    %timeit delete_workaround(arr, 3)
    # 1.36 µs ± 128 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
    %timeit wrap_list_del(arr, 3)    
    # 69.3 µs ± 4.97 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
    %timeit delete_workaround(arr2, 3)
    # 1.9 ms ± 68.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    %timeit wrap_list_del(arr2, 3)
    # 1.05 s ± 103 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    So sticking with arrays if you already have them (or even if you don't, as long as your data has a consistent type) is about 50 times faster for small arrays and about 550 times faster for larger ones. This is something to remember: numpy arrays are built for numerical data, and numpy is heavily optimized for it. There is no point in converting numerical arrays to another "format" if the data type (dtype) is constant and nothing unusual requires it (I've barely ever encountered such a situation).
    This is especially true for numba-optimized code: numba relies heavily on numpy and on constant dtypes, shapes, etc., and even more so if you want to work with jitclasses.