Why does the result of torch.nn.functional.mse_loss(x1,x2) differ from the direct computation of the MSE?
My test code to reproduce:
import torch
import numpy as np
# Think of x1 as predicted 2D coordinates and x2 of ground truth
x1 = torch.rand(10,2)
x2 = torch.rand(10,2)
mse_torch = torch.nn.functional.mse_loss(x1,x2)
print(mse_torch) # 0.1557
mse_direct = torch.nn.functional.pairwise_distance(x1,x2).square().mean()
print(mse_direct) # 0.3314
mse_manual = 0
for i in range(len(x1)):
    mse_manual += np.square(np.linalg.norm(x1[i]-x2[i])) / len(x1)
print(mse_manual) # 0.3314
As we can see, the result from torch's mse_loss is 0.1557, differing from the manual MSE computation, which yields 0.3314.
In fact, the direct result seems to be as big as the result from mse_loss times the dimension of the points (here 2).
What's up with that?
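The difference is that torch.nn.functional.mse_loss(x1,x2) does not sum over the coordinates when calculating the squared error: with its default reduction='mean', it averages the squared differences over all elements (here 10*2 = 20 values). torch.nn.functional.pairwise_distance and np.linalg.norm, on the other hand, sum over the coordinates of each point, so squaring their result and averaging over the 10 points gives a value that is larger by a factor equal to the number of coordinates (here 2). You can reproduce the values of the calculated mse in the following way: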
import torch
import numpy as np
x1 = torch.rand(10,2)
x2 = torch.rand(10,2)
mse_torch = torch.nn.functional.mse_loss(x1,x2)
print(mse_torch) # 0.1557
mse_manual = 0
x3 = torch.zeros(10,2)
for i in range(len(x1)):
    x3[i,:1] += (torch.nn.functional.pairwise_distance(x1[i,:1],x2[i,:1],eps=0.0)**2)/len(x1)
    x3[i,1:] += (torch.nn.functional.pairwise_distance(x1[i,1:],x2[i,1:],eps=0.0)**2)/len(x1)
    mse_manual += x3[i]
print(mse_manual.mean()) # 0.1557
mse_manual = 0
for i in range(len(x1)):
    mse_manual += np.square(x1[i]-x2[i]) / len(x1)
print(mse_manual.mean()) # 0.1557
Or, if you want to reproduce the pairwise distance result using mse_loss, you can pass reduction='none' and sum over the coordinates before taking the mean:
import torch
import numpy as np
# Think of x1 as predicted 2D coordinates and x2 of ground truth
x1 = torch.rand(10,2)
x2 = torch.rand(10,2)
mse_torch = torch.nn.functional.mse_loss(x1,x2, reduction='none')
print(mse_torch.sum(-1).mean()) # 0.3314
mse_direct = torch.nn.functional.pairwise_distance(x1,x2).square().mean()
print(mse_direct) # 0.3314
mse_manual = 0
for i in range(len(x1)):
    mse_manual += np.square(np.linalg.norm(x1[i]-x2[i])) / len(x1)
print(mse_manual) # 0.3314
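For a more compact check, the same relationship can be sketched without loops. This is only an illustration: the fixed seed and the names sq_err, mse_elementwise, mse_per_point and mse_pairwise are my own choices and not part of the code above, and eps=0.0 is passed to pairwise_distance so that its default epsilon does not slightly shift the result.

```python
import torch

torch.manual_seed(0)  # assumption: fixed seed, only to make the run repeatable
x1 = torch.rand(10, 2)
x2 = torch.rand(10, 2)
n_points, n_dims = x1.shape

# Element-wise squared error, shape (10, 2)
sq_err = (x1 - x2) ** 2

# What mse_loss does by default: mean over all n_points * n_dims elements
mse_elementwise = sq_err.mean()

# Squared Euclidean distance per point (sum over coordinates), then mean over points
mse_per_point = sq_err.sum(dim=-1).mean()

# Same quantity via pairwise_distance, with eps disabled
mse_pairwise = torch.nn.functional.pairwise_distance(x1, x2, eps=0.0).square().mean()

assert torch.allclose(mse_elementwise, torch.nn.functional.mse_loss(x1, x2))
assert torch.allclose(mse_per_point, mse_pairwise)
# The two definitions differ exactly by the number of coordinates
assert torch.allclose(mse_per_point, n_dims * mse_elementwise)
```

So which value is "the" MSE depends on whether you count each coordinate or each point as one sample; for training, the two differ only by a constant factor in the loss.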