
How to check whether NumPy arrays are equal


I was doing some NumPy exercises, in particular on broadcasting, but I'm stuck.
Can someone please explain how assert should be used to check that these arrays are equal?

import numpy as np

def fill_0(n):
    return np.zeros(n) - 1

def fill_1(n):
    return np.zeros(n) * (-1)

def fill_2(n):
    return -np.ones(n)

def fill_3(n):
    return -np.ones(n) - 2

assert fill_0(4) == fill_1(4) == fill_2(4) == fill_3(4)

Solution

I'd do it this way:

    import numpy as np

    # One check per pair, so a failure names the exact pair that differs
    np.testing.assert_array_equal(fill_0(4), fill_1(4))
    np.testing.assert_array_equal(fill_0(4), fill_2(4))
    np.testing.assert_array_equal(fill_0(4), fill_3(4))

    This makes it much clearer where the failure is (each pair is checked on a separate line), and it also works when the data contain NaNs: assert_array_equal treats NaNs in the same positions as equal, whereas a regular equality comparison would report them as different (because NaN == NaN is False).
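A minimal, self-contained sketch (an illustration, not part of the original answer) that re-declares the question's four helpers and shows three things. First, why the chained assert in the question raises ValueError rather than failing cleanly: a == b == c means (a == b) and (b == c), and "and" has to call bool() on an elementwise boolean array, which is ambiguous for more than one element. Second, how pairwise np.testing.assert_array_equal checks pinpoint the mismatches; with the helpers exactly as written, fill_1 returns -0.0 and fill_3 returns -3.0, so only fill_0 and fill_2 agree. Third, how np.array_equal, which returns a plain bool, fits an ordinary assert. The helper names chained_assert_error and mismatched_pairs are hypothetical, and the equal_nan argument assumes NumPy 1.19 or later.

```python
import numpy as np

# The four helpers from the question, unchanged in behavior.
def fill_0(n):
    return np.zeros(n) - 1      # all -1.0

def fill_1(n):
    return np.zeros(n) * (-1)   # all -0.0, not -1.0

def fill_2(n):
    return -np.ones(n)          # all -1.0

def fill_3(n):
    return -np.ones(n) - 2      # all -3.0, not -1.0


# Hypothetical helper: run the question's chained assert and report what it raises.
# `a == b == c` is evaluated as `(a == b) and (b == c)`; the `and` calls bool()
# on an elementwise boolean array, which raises ValueError for size > 1.
def chained_assert_error(n):
    try:
        assert fill_0(n) == fill_1(n) == fill_2(n) == fill_3(n)
    except (ValueError, AssertionError) as exc:
        return type(exc)
    return None


# Hypothetical helper: check every helper against fill_0 and collect the
# names of those that differ, instead of stopping at the first failure.
def mismatched_pairs(n):
    reference = fill_0(n)
    failures = []
    for name, func in [("fill_1", fill_1), ("fill_2", fill_2), ("fill_3", fill_3)]:
        try:
            np.testing.assert_array_equal(reference, func(n))
        except AssertionError:
            failures.append(name)
    return failures


print(chained_assert_error(4))  # the ValueError class
print(mismatched_pairs(4))      # ['fill_1', 'fill_3']

# np.array_equal returns True/False, so it works with a plain assert.
assert np.array_equal(fill_0(4), fill_2(4))

# NaNs in the same positions pass assert_array_equal, but np.array_equal
# only treats them as equal when equal_nan=True is passed.
x = np.array([1.0, np.nan])
y = np.array([1.0, np.nan])
np.testing.assert_array_equal(x, y)           # no error raised
print(np.array_equal(x, y))                   # False
print(np.array_equal(x, y, equal_nan=True))   # True
```

If the arrays come from different floating-point computations, np.testing.assert_allclose is usually a safer check than exact equality, since it compares within a tolerance.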