
Remove rows that contain duplicated values in NumPy

I'm trying to find an efficient way to remove the rows of a NumPy array that contain duplicated elements. For example, the array below:

[[1,2,3], [1,2,2], [2,2,2]]

should keep only [[1,2,3]], since the other two rows contain repeated values.

I know pandas apply can work row-wise, but it's too slow. What is a faster alternative?



  • Using pandas nunique (not fast!):

    out = a[pd.DataFrame(a).nunique(axis=1).eq(a.shape[1])]
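A self-contained version of the pandas one-liner, run on the example array from the question. nunique(axis=1) counts the distinct values in each row, and a row is kept only when that count equals the number of columns. The intermediate name counts is only for illustration; the .to_numpy() call simply turns the boolean Series into a plain NumPy mask.

```python
import numpy as np
import pandas as pd

a = np.array([[1, 2, 3], [1, 2, 2], [2, 2, 2]])

# Number of distinct values in each row: [3, 2, 1]
counts = pd.DataFrame(a).nunique(axis=1)

# Keep rows whose distinct count equals the row length (i.e. no duplicates)
out = a[counts.eq(a.shape[1]).to_numpy()]
print(out)  # [[1 2 3]]
```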

    Or with NumPy's sort and diff: once each row is sorted, equal values become adjacent, so a row has no duplicates exactly when all consecutive differences are non-zero (quite efficient if the number of columns is reasonable):

    out = a[(np.diff(np.sort(a, axis=1))!=0).all(axis=1)]
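The sort-and-diff one-liner, split into its intermediate steps on the example array so each stage can be inspected. The names s, d and mask are illustrative only; the logic is the same as the one-liner.

```python
import numpy as np

a = np.array([[1, 2, 3], [1, 2, 2], [2, 2, 2]])

# Sort each row so that equal values end up next to each other
s = np.sort(a, axis=1)

# Differences between neighbours along each row; a zero means a repeated value
d = np.diff(s, axis=1)  # [[1, 1], [1, 0], [0, 0]]

# A row is duplicate-free only if none of its neighbour differences is zero
mask = (d != 0).all(axis=1)  # [True, False, False]

out = a[mask]
print(out)  # [[1 2 3]]
```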

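    Or with broadcasting, which compares every pair of elements within each row (memory-expensive with many columns, since it builds an n × m × m boolean array for an n × m input):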

    out = a[(a[:,:,None] == a[:,None]).sum(axis=(1,2))==a.shape[1]]
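Why the broadcasting check works, shown step by step on the example array: every element equals itself, so a row of length m always has exactly m matches on the diagonal of its m × m comparison matrix, and any count above m means some value appears twice. The names eq and matches are illustrative only.

```python
import numpy as np

a = np.array([[1, 2, 3], [1, 2, 2], [2, 2, 2]])
n, m = a.shape

# eq[r, i, j] is True when a[r, i] == a[r, j]; shape (n, m, m)
eq = a[:, :, None] == a[:, None, :]

# Matching pairs per row: m from the diagonal, plus 2 for each duplicated pair
matches = eq.sum(axis=(1, 2))  # [3, 5, 9]

out = a[matches == m]
print(out)  # [[1 2 3]]
```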

    Output: array([[1, 2, 3]])

    [Figure: comparison of approaches]
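To compare the two NumPy approaches on your own data, a rough timing sketch with the standard timeit module might look like the following. The helper names by_sort_diff and by_broadcast, the array shape (10,000 × 5) and the value range are arbitrary assumptions; actual timings depend on hardware and especially on the number of columns.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
# Hypothetical test data: small value range so duplicated rows are common
a = rng.integers(0, 10, size=(10_000, 5))

def by_sort_diff(a):
    # Sort each row, then require all neighbour differences to be non-zero
    return a[(np.diff(np.sort(a, axis=1), axis=1) != 0).all(axis=1)]

def by_broadcast(a):
    # Count pairwise matches per row; exactly m matches means no duplicates
    return a[(a[:, :, None] == a[:, None, :]).sum(axis=(1, 2)) == a.shape[1]]

# Both methods should select exactly the same rows
assert np.array_equal(by_sort_diff(a), by_broadcast(a))

for f in (by_sort_diff, by_broadcast):
    t = timeit.timeit(lambda: f(a), number=20)
    print(f"{f.__name__}: {t / 20 * 1e3:.2f} ms per call")
```

Checking that both functions return identical rows before timing them guards against comparing the speed of two methods that do not actually agree.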