Tags: python, python-3.x, numpy, isinstance

np.isin - testing whether a NumPy array contains a given row, taking the order of entries into account


I am using the following line to find whether the rows of b are in a:

 a[np.all(np.isin(a[:, 0:3], b[:, 0:3]), axis=1), 3]

The arrays have more entries along axis=1; I compare only the first 3 entries and return the fourth entry (idx=3) of a.

The problem I noticed is that the order of the entries within each row is not considered. As a result, the following example for a and b:

a = np.array([[...],
              [1, 2, 3, 1000],
              [2, 1, 3, 2000],
              [...]])

b = np.array([[1, 2, 3]])

would return [1000, 2000] instead of only [1000].
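
The behaviour described above can be reproduced with a small self-contained sketch. The third row of a is made-up filler standing in for the elided rows; it is not from the original data.

```python
import numpy as np

a = np.array([[1, 2, 3, 1000],
              [2, 1, 3, 2000],
              [4, 5, 6, 3000]])  # filler row, assumed for illustration
b = np.array([[1, 2, 3]])

# np.isin tests each element of a[:, :3] against the flattened values of b,
# so any row built only from the values {1, 2, 3} passes, whatever the order.
elementwise = np.isin(a[:, :3], b[:, :3])
mask = elementwise.all(axis=1)

result = a[mask, 3]
print(result.tolist())  # [1000, 2000]
```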

How can I take the order of the entries within each row into account as well?


Solution

  • For small b (fewer than about 100 rows), try this instead:

    a[(a[:, :3] == b[:, None]).all(axis=-1).any(axis=0), 3]
    

    Example:

    a = np.array([[1, 0, 5, 0],
                  [1, 2, 3, 1000],
                  [2, 1, 3, 2000],
                  [0, 0, 1, 1]])
    
    b = np.array([[1, 2, 3], [0, 0, 1]])
    
    >>> a[(a[:, :3] == b[:, None]).all(axis=-1).any(axis=0), 3]
    array([1000,    1])
    

    Explanation:

    The key is to "distribute" equality tests for all rows of a (the first 3 columns) to all rows of b:

    # on the example above
    
    >>> a[:, :3] == b[:, None]
    array([[[ True, False, False],
            [ True,  True,  True],  # <-- a[1,:3] matches b[0]
            [False, False,  True],
            [False, False, False]],
    
           [[False,  True, False],
            [False, False, False],
            [False, False, False],
            [ True,  True,  True]]])  # <-- a[3, :3] matches b[1]
    

    Be warned that this can be large: the shape is (len(b), len(a), 3).
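
To get a feel for that size, here is a rough sketch that checks the shape on the example and estimates the footprint of the boolean intermediate for larger inputs. The helper name estimated_bytes and the larger sizes are assumptions for illustration; the estimate counts one byte per boolean and ignores any other temporaries.

```python
import numpy as np

a = np.array([[1, 0, 5, 0],
              [1, 2, 3, 1000],
              [2, 1, 3, 2000],
              [0, 0, 1, 1]])
b = np.array([[1, 2, 3], [0, 0, 1]])

eq = a[:, :3] == b[:, None]
print(eq.shape)   # (2, 4, 3), i.e. (len(b), len(a), 3)
print(eq.nbytes)  # 24, one byte per boolean


def estimated_bytes(len_a, len_b, ncols=3):
    """Hypothetical helper: size in bytes of the boolean intermediate."""
    return len_a * len_b * ncols


# One million rows in a against one thousand rows in b: about 3 GB.
print(estimated_bytes(1_000_000, 1_000))  # 3000000000
```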

    Then the first .all(axis=-1) requires every entry of a row to match, i.e. the entire row must be equal:

    >>> (a[:, :3] == b[:, None]).all(axis=-1)
    array([[False,  True, False, False],
           [False, False, False,  True]])
    

    The final bit .any(axis=0) means: "match any row in b":

    >>> (a[:, :3] == b[:, None]).all(axis=-1).any(axis=0)
    array([False,  True, False,  True])
    

    I.e.: "a[1, :3] matches some row(s) of b, and so does a[3, :3]".

    Finally, use this as a mask on a and take column 3.
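
Putting the steps together, a minimal sketch of a reusable helper might look like this. The function name select_matching and its parameters ncols and out_col are hypothetical, chosen here for illustration only.

```python
import numpy as np


def select_matching(a, b, ncols=3, out_col=3):
    """Return a[:, out_col] for the rows of a whose first ncols entries
    equal, in order, the first ncols entries of some row of b."""
    # Boolean array of shape (len(b), len(a), ncols).
    eq = a[:, :ncols] == b[:, None, :ncols]
    # all(axis=-1): whole row equal; any(axis=0): equal to at least one row of b.
    mask = eq.all(axis=-1).any(axis=0)
    return a[mask, out_col]


a = np.array([[1, 0, 5, 0],
              [1, 2, 3, 1000],
              [2, 1, 3, 2000],
              [0, 0, 1, 1]])
b = np.array([[1, 2, 3], [0, 0, 1]])

print(select_matching(a, b).tolist())  # [1000, 1]
```

Slicing b with :ncols as well mirrors the b[:, 0:3] in the question, so b may carry extra columns.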

    Note on performance

    The technique above evaluates equality for every pair (row of a, row of b), so both time and memory grow with len(a) * len(b). This can be slow and use a large amount of memory if both a and b have many rows.
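
If memory rather than speed is the main concern, one possible mitigation (not part of the original answer) is to process b in chunks so that the boolean intermediate stays bounded. This is a sketch under that assumption; the name rows_in_chunked and the default chunk size are hypothetical.

```python
import numpy as np


def rows_in_chunked(a, b, chunk=256):
    """Boolean mask over the rows of a: True where the row equals some row of b.

    b is processed `chunk` rows at a time, so the intermediate has
    shape at most (chunk, len(a), a.shape[1]).
    """
    mask = np.zeros(len(a), dtype=bool)
    for start in range(0, len(b), chunk):
        block = b[start:start + chunk]
        # Same broadcasted comparison as before, restricted to one block of b.
        mask |= (a == block[:, None]).all(axis=-1).any(axis=0)
    return mask


a = np.array([[1, 0, 5, 0],
              [1, 2, 3, 1000],
              [2, 1, 3, 2000],
              [0, 0, 1, 1]])
b = np.array([[1, 2, 3], [0, 0, 1]])

mask = rows_in_chunked(a[:, :3], b, chunk=1)
print(mask.tolist())  # [False, True, False, True]
```

The total amount of comparison work is unchanged; only the peak memory is reduced.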

    As an alternative, you may use set membership in pure Python (without column subsetting, which can be done by the caller):

    def py_rows_in(a, b):
        z = set(map(tuple, b))
        return [row in z for row in map(tuple, a)]
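
A short usage sketch on the earlier example, with the caller doing the column subsetting. Wrapping the returned list in np.array is a choice made here so that it can be used explicitly as a boolean mask.

```python
import numpy as np


def py_rows_in(a, b):
    z = set(map(tuple, b))
    return [row in z for row in map(tuple, a)]


a = np.array([[1, 0, 5, 0],
              [1, 2, 3, 1000],
              [2, 1, 3, 2000],
              [0, 0, 1, 1]])
b = np.array([[1, 2, 3], [0, 0, 1]])

# Compare only the first 3 columns of a, then select column 3 with the mask.
mask = np.array(py_rows_in(a[:, :3], b))
print(mask.tolist())        # [False, True, False, True]
print(a[mask, 3].tolist())  # [1000, 1]
```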
    

    When b has more than roughly 50 to 100 rows, this may be faster than the np version above, written here as a function:

    def np_rows_in(a, b):
        return (a == b[:, None]).all(axis=-1).any(axis=0)
    
    import matplotlib.pyplot as plt
    import numpy as np
    import perfplot
    
    fig, axes = plt.subplots(ncols=2, figsize=(16, 5))
    plt.subplots_adjust(wspace=.5)
    for ax, alen in zip(axes, [100, 10_000]):
        a = np.random.randint(0, 20, (alen, 4))
        plt.sca(ax)
        ax.set_title(f'a: {a.shape[0]:_} rows')
        perfplot.show(
            setup=lambda n: np.random.randint(0, 20, (n, 3)),
            kernels=[
                lambda b: np_rows_in(a[:, :3], b),
                lambda b: py_rows_in(a[:, :3], b),
            ],
            labels=['np_rows_in', 'py_rows_in'],
            n_range=[2 ** k for k in range(10)],
            xlabel='len(b)',
        )
    plt.show()
    

    [Figure: comparative performance]