Numpy comparing nested arrays


I am working with arrays of shape (n, 2) for arbitrary n. I want to check whether any of the two-element sub-arrays (rows) match as a whole, rather than comparing individual elements. As an example:

import numpy as np

a = np.array([[1, 0], [2, 0]])
b = np.array([[1, 0], [2, 0]])
c = np.array([[3, 0], [4, 0]])
d = np.array([[1, 0], [5, 0]])

# Elementwise comparison: .any() is True if ANY single element matches
if (b == a).any():
    print('y')

if (c == a).any():
    print('y1')

if (d == a).any():
    print('y2')

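In this code I would like the first and third conditions to evaluate to True and the second to evaluate to False. As written, however, all three statements print, because == compares the arrays elementwise and .any() returns True as soon as a single element matches (here the 0 in the second column always matches).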
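To see why the second condition fires, it can help to look at the intermediate boolean array. A minimal sketch using the arrays a and c from the question:

```python
import numpy as np

a = np.array([[1, 0], [2, 0]])
c = np.array([[3, 0], [4, 0]])

# Elementwise comparison: one boolean per element, same shape as the inputs
mask = c == a
print(mask)
# [[False  True]
#  [False  True]]

# .any() collapses the whole mask, so a single matching element
# (the shared zeros in column 1) is enough to make the result True
print(mask.any())  # True
```

The rows themselves never match, but the flattened .any() cannot tell that apart from a full-row match.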

Is there a simple way to do this without looping over the rows of each array and comparing them directly?


Solution

  • (c == a).all(axis=1).any() first checks along axis 1 whether all elements in each row are equal, which gives one boolean per row, and then checks whether this is True for any of the rows.

    This only works if you want to compare sub-arrays with the same row index, i.e. a[0] with b[0] or a[1] with b[1], but not a[0] with b[1]. It also assumes both arrays have the same shape.
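A minimal sketch applying the row-wise check to the arrays from the question. The helper names same_index_row_match and any_row_match are hypothetical, made up for illustration. The second helper is not part of the answer above: it is one possible broadcasting-based variant for the case the answer excludes (matching a[0] against b[1]), assuming both inputs are 2-D with the same number of columns. It builds an (n, m, 2) intermediate array, so memory grows with n * m.

```python
import numpy as np

a = np.array([[1, 0], [2, 0]])
b = np.array([[1, 0], [2, 0]])
c = np.array([[3, 0], [4, 0]])
d = np.array([[1, 0], [5, 0]])


def same_index_row_match(x, y):
    """Hypothetical helper: True if x[i] equals y[i] for at least one i."""
    # all(axis=1): one boolean per row, True only if every element matches
    # any(): True if at least one row matched
    return bool((x == y).all(axis=1).any())


def any_row_match(x, y):
    """Hypothetical helper: True if some row of x equals some row of y."""
    # x[:, None, :] has shape (n, 1, 2) and y[None, :, :] has shape (1, m, 2);
    # broadcasting compares every row of x with every row of y -> (n, m, 2)
    pairwise = x[:, None, :] == y[None, :, :]
    # all(axis=2): does row i of x fully equal row j of y? -> shape (n, m)
    return bool(pairwise.all(axis=2).any())


print(same_index_row_match(b, a))  # True
print(same_index_row_match(c, a))  # False
print(same_index_row_match(d, a))  # True

# [2, 0] appears as a[1] but as e[0], i.e. at a different row index
e = np.array([[2, 0], [9, 9]])
print(same_index_row_match(e, a))  # False
print(any_row_match(e, a))         # True
```

For the question's own arrays both helpers give the same answers; they only differ when matching rows sit at different indices, as with e above.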