Tags: python, pytorch, loss-function, mse

PyTorch MSE Loss differs from direct calculation by a factor of 2


Why does the result of torch.nn.functional.mse_loss(x1,x2) differ from the direct computation of the MSE?

My test code to reproduce:

import torch
import numpy as np

# Think of x1 as predicted 2D coordinates and x2 as the ground truth
x1 = torch.rand(10,2)
x2 = torch.rand(10,2)

mse_torch = torch.nn.functional.mse_loss(x1,x2)
print(mse_torch) # 0.1557

mse_direct = torch.nn.functional.pairwise_distance(x1,x2).square().mean()
print(mse_direct) # 0.3314

mse_manual = 0
for i in range(len(x1)) :
    mse_manual += np.square(np.linalg.norm(x1[i]-x2[i])) / len(x1)
print(mse_manual) # 0.3314 

As we can see, torch's mse_loss returns 0.1557, whereas both the direct and the manual MSE computations yield 0.3314.

In fact, the direct result is the mse_loss result multiplied by the dimension of the points (here 2).

What's up with that?


Solution

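  • The difference is that torch.nn.functional.mse_loss(x1,x2), with its default reduction='mean', does not sum the squared errors over the coordinates: it averages them over all N·D elements (here 10 points × 2 coordinates = 20). torch.nn.functional.pairwise_distance and np.linalg.norm, on the other hand, do sum over the coordinates, so the direct computation averages only over the N points and comes out D times larger. You can reproduce the value computed by mse_loss in the following way: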

    import torch
    import numpy as np
    
    # x1, x2 are random, so the printed values vary between runs
    x1 = torch.rand(10,2)
    x2 = torch.rand(10,2)
    
    mse_torch = torch.nn.functional.mse_loss(x1,x2)
    print(mse_torch) # 0.1557
    
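    # Reproduce mse_loss one coordinate at a time: for a single coordinate,
    # pairwise_distance(...)**2 with eps=0.0 is just the squared difference
    mse_manual = 0
    x3 = torch.zeros(10,2)
    for i in range(len(x1)):
       x3[i,:1] += (torch.nn.functional.pairwise_distance(x1[i,:1],x2[i,:1],eps=0.0)**2)/len(x1)
       x3[i,1:] += (torch.nn.functional.pairwise_distance(x1[i,1:],x2[i,1:],eps=0.0)**2)/len(x1)
       mse_manual += x3[i]
    # mse_manual now holds the per-coordinate MSE; averaging over the
    # 2 coordinates (instead of summing) gives the mse_loss value
    print(mse_manual.mean()) # 0.1557
    
    # Same idea with NumPy: square element-wise, with no sum over coordinates
    mse_manual = 0
    for i in range(len(x1)):
       mse_manual += np.square(x1[i]-x2[i]) / len(x1)
    print(mse_manual.mean()) # 0.1557 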
    

    Conversely, if you want to reproduce the pairwise-distance result with mse_loss, pass reduction='none', sum the squared errors over the coordinate dimension, and then average over the points:

    import torch
    import numpy as np
    # Think of x1 as predicted 2D coordinates and x2 as the ground truth
    x1 = torch.rand(10,2)
    x2 = torch.rand(10,2)
    
    # reduction='none' keeps the (10, 2) per-element squared errors;
    # sum over the coordinates, then average over the 10 points
    mse_torch = torch.nn.functional.mse_loss(x1,x2, reduction='none')
    print(mse_torch.sum(-1).mean()) # 0.3314
    
    mse_direct = torch.nn.functional.pairwise_distance(x1,x2).square().mean()
    print(mse_direct) # 0.3314
    
    mse_manual = 0
    for i in range(len(x1)):
        mse_manual += np.square(np.linalg.norm(x1[i]-x2[i])) / len(x1)
    print(mse_manual) # 0.3314
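
Written out for N points in D dimensions, with e_id = x1[i, d] - x2[i, d], the two quantities differ only in their normalization. This assumes mse_loss's default reduction='mean' and the exact Euclidean norm (pairwise_distance adds a small eps, 1e-6 by default, so it matches only approximately):

```latex
\text{mse\_loss}(x_1, x_2) = \frac{1}{ND}\sum_{i=1}^{N}\sum_{d=1}^{D} e_{id}^2,
\qquad
\text{direct} = \frac{1}{N}\sum_{i=1}^{N}\lVert x_{1,i} - x_{2,i}\rVert_2^2
= \frac{1}{N}\sum_{i=1}^{N}\sum_{d=1}^{D} e_{id}^2
= D \cdot \text{mse\_loss}(x_1, x_2)
```

With N = 10 and D = 2, mse_loss divides the total squared error by 20, while the direct computation divides it by 10.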
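
A minimal self-contained check of the same identities, as a sketch. The fixed seed and the use of 3D points are assumptions made only for illustration (so the run is reproducible and the factor visibly equals D rather than always 2). It also shows reduction='sum' divided by the number of points as another way to get the per-point value:

```python
import torch
import torch.nn.functional as F

# Fixed seed for reproducibility (illustrative assumption; the original uses unseeded data)
torch.manual_seed(0)

N, D = 10, 3  # 10 points in 3D, so the factor should be 3 here
pred = torch.rand(N, D)
target = torch.rand(N, D)

# Default reduction='mean': averages the squared error over all N * D elements
mse_elementwise = F.mse_loss(pred, target)

# Squared Euclidean distance per point, averaged over the N points
mse_per_point = ((pred - target) ** 2).sum(dim=1).mean()

# Same quantity via mse_loss: keep per-element errors, sum over coordinates, average over points
mse_via_none = F.mse_loss(pred, target, reduction='none').sum(dim=-1).mean()

# Or: total squared error divided by the number of points only
mse_via_sum = F.mse_loss(pred, target, reduction='sum') / N

print(mse_elementwise.item(), mse_per_point.item())
print(torch.allclose(mse_per_point, D * mse_elementwise))  # True
print(torch.allclose(mse_per_point, mse_via_none))  # True
print(torch.allclose(mse_per_point, mse_via_sum))  # True
```

Which normalization to use is a modeling choice: the per-element mean keeps the loss scale independent of the coordinate dimension, while the per-point sum matches the usual "mean squared Euclidean distance" definition.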