python numpy mean-square-error

How can I find the value with the minimum MSE with a numpy array?


My possible values are:

0: [0 0 0 0]
1: [1 0 0 0]
2: [1 1 0 0]
3: [1 1 1 0]
4: [1 1 1 1]

I have some values:

[[0.9539342  0.84090066 0.46451256 0.09715253],
 [0.9923432  0.01231235 0.19491441 0.09715253]
 ....

For each row of my new values, I want to figure out which of the possible values it is closest to. Ideally I want to avoid a for loop, and I wonder whether there is a vectorized way to search for the minimum mean squared error.

I want it to return an array that looks like: [2, 1 ....


Solution

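  • You can use np.argmin to get the index of the row with the smallest error. np.linalg.norm gives the Euclidean distance to each row, which is the square root of n times the MSE (n being the row length), so the row that minimizes it also minimizes the MSE and the RMSE.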

    import numpy as np
    a = np.array([[0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]])
    b = np.array([0.9539342, 0.84090066, 0.46451256, 0.09715253])
    np.argmin(np.linalg.norm(a - b, axis=1))
    # outputs 2, which corresponds to the value [1, 1, 0, 0]
    

    As mentioned in the edit, b can have multiple rows, in which case it is a 2-D array. The OP wants to avoid a for loop, but I have not found a way to do that here. A list comprehension works, though there may be a better approach:

    [np.argmin(np.linalg.norm(a-i, axis=1)) for i in b] 
    #Outputs [2, 1]
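
One candidate for a loop-free version is NumPy broadcasting: insert a length-1 axis into each array so that subtracting them produces every row-to-candidate difference at once. The sketch below is only an illustration under the assumption that b holds the two rows shown in the question; the names diff, mse and closest are made up for the example.

```python
import numpy as np

# Candidate rows from the question.
a = np.array([[0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 0, 0], [1, 1, 1, 0], [1, 1, 1, 1]])

# Assumption: b contains just the two rows shown in the question.
b = np.array([[0.9539342, 0.84090066, 0.46451256, 0.09715253],
              [0.9923432, 0.01231235, 0.19491441, 0.09715253]])

# b[:, None, :] has shape (2, 1, 4) and a[None, :, :] has shape (1, 5, 4),
# so broadcasting yields all pairwise differences with shape (2, 5, 4).
diff = b[:, None, :] - a[None, :, :]

# Mean squared error of every row of b against every candidate: shape (2, 5).
mse = np.mean(diff ** 2, axis=2)

# Index of the closest candidate for each row of b.
closest = np.argmin(mse, axis=1)
print(closest)  # [2 1]
```

The intermediate diff array has one entry per (row of b, candidate, column) triple, so memory grows with the number of rows in b. For very large inputs, processing b in chunks may be preferable.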