Tags: numpy, scikit-learn, nan, idioms, mean-square-error

Clean a NumPy array of NaNs while deleting the corresponding entries in another array


I have two numpy arrays, one of which contains about 1% NaNs.

import numpy as np

a = np.array([-2, 5, np.nan, 6])
b = np.array([2, 3, 1, 0])

I'd like to compute the mean squared error of a and b using sklearn's mean_squared_error.

So my question is: what's the Pythonic way to remove all NaNs from a while deleting the corresponding entries from b, as efficiently as possible?


Solution

  • You can simply use vanilla NumPy's np.nanmean for this purpose:

    In [136]: np.nanmean((a-b)**2)
    Out[136]: 18.666666666666668
    

    If this didn't exist, or you really wanted to use the sklearn method, you could build a boolean mask that filters out the NaNs and apply it to both arrays:

    In [148]: mask = ~np.isnan(a)
    
    In [149]: mean_squared_error(a[mask], b[mask])
    Out[149]: 18.666666666666668
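    For reference, here is a minimal self-contained sketch that puts both approaches side by side, assuming NumPy and scikit-learn are installed and using the arrays from the question:

        import numpy as np
        from sklearn.metrics import mean_squared_error

        a = np.array([-2, 5, np.nan, 6], dtype=float)
        b = np.array([2, 3, 1, 0], dtype=float)

        # Approach 1: plain NumPy, ignoring NaNs in the squared differences
        mse_numpy = np.nanmean((a - b) ** 2)

        # Approach 2: drop the NaN positions from both arrays, then use sklearn
        mask = ~np.isnan(a)  # True where a holds a valid (non-NaN) value
        mse_sklearn = mean_squared_error(a[mask], b[mask])

        print(mse_numpy, mse_sklearn)  # both print 18.666666666666668

    The second approach is what the question literally asks for: the same boolean mask indexes both arrays, so the corresponding entries of b are dropped along with the NaNs in a.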