
Why does the result of torch.nn.functional.mse_loss(x1,x2) differ from the direct computation of the MSE?

My test code to reproduce:

import torch
import numpy as np

# Think of x1 as predicted 2D coordinates and x2 as the ground truth
x1 = torch.rand(10,2)
x2 = torch.rand(10,2)

mse_torch = torch.nn.functional.mse_loss(x1,x2)
print(mse_torch) # 0.1557

mse_direct = torch.nn.functional.pairwise_distance(x1,x2).square().mean()
print(mse_direct) # 0.3314

mse_manual = 0
for i in range(len(x1)):
    mse_manual += np.square(np.linalg.norm(x1[i]-x2[i])) / len(x1)
print(mse_manual) # 0.3314 

As we can see, torch's mse_loss returns 0.1557, while the manual MSE computation yields 0.3314.

In fact, the direct result is precisely the mse_loss result multiplied by the dimension of the points (here 2).

What's up with that?

csstudent1418

1 Answer


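The difference is that torch.nn.functional.mse_loss(x1,x2) does not sum over the coordinates when computing the squared error: with the default reduction='mean' it squares every element of x1-x2 and averages over all N×D elements (here 10×2 = 20). torch.nn.functional.pairwise_distance and np.linalg.norm, on the other hand, sum the squared differences over the coordinates of each point, so squaring and averaging their output divides the same total only by N. You can reproduce the value returned by mse_loss in the following way: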

import torch
import numpy as np

x1 = torch.rand(10,2)
x2 = torch.rand(10,2)

mse_torch = torch.nn.functional.mse_loss(x1,x2)
print(mse_torch) # 0.1557

# Squared error of each coordinate separately (no sum over coordinates),
# divided by the number of points, then averaged over the 2 coordinates
mse_manual = 0
x3 = torch.zeros(10,2)
for i in range(len(x1)):
    # first coordinate only
    x3[i,:1] += (torch.nn.functional.pairwise_distance(x1[i,:1],x2[i,:1],eps=0.0)**2)/len(x1)
    # second coordinate only
    x3[i,1:] += (torch.nn.functional.pairwise_distance(x1[i,1:],x2[i,1:],eps=0.0)**2)/len(x1)
    mse_manual += x3[i]
print(mse_manual.mean()) # 0.1557

# Same idea with element-wise squaring instead of a norm
mse_manual = 0
for i in range(len(x1)):
    mse_manual += np.square(x1[i]-x2[i]) / len(x1)
print(mse_manual.mean()) # 0.1557
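Written out for N points with D coordinates each (N = 10, D = 2 here), and ignoring the small eps that pairwise_distance adds to the difference by default, the two quantities share the same double sum and differ only in the normalization:

```latex
\text{mse\_loss}(x_1, x_2) = \frac{1}{ND}\sum_{i=1}^{N}\sum_{d=1}^{D}\left(x_{1,id} - x_{2,id}\right)^2

\frac{1}{N}\sum_{i=1}^{N}\lVert x_{1,i} - x_{2,i}\rVert_2^2
  = \frac{1}{N}\sum_{i=1}^{N}\sum_{d=1}^{D}\left(x_{1,id} - x_{2,id}\right)^2
  = D \cdot \text{mse\_loss}(x_1, x_2)
```

With D = 2 this is exactly the factor of 2 observed in the question.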

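Alternatively, if you want to reproduce the pairwise-distance result with a modified mse_loss call, keep the per-element errors with reduction='none', sum them over the coordinates, and average over the points: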

import torch
import numpy as np
# Think of x1 as predicted 2D coordinates and x2 as the ground truth
x1 = torch.rand(10,2)
x2 = torch.rand(10,2)

# Keep the per-element squared errors, sum over the coordinates (last axis),
# then average over the points
mse_torch = torch.nn.functional.mse_loss(x1,x2, reduction='none')
print(mse_torch.sum(-1).mean()) # 0.3314

mse_direct = torch.nn.functional.pairwise_distance(x1,x2).square().mean()
print(mse_direct) # 0.3314

mse_manual = 0
for i in range(len(x1)):
    mse_manual += np.square(np.linalg.norm(x1[i]-x2[i])) / len(x1)
print(mse_manual) # 0.3314
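If this quantity is needed as a training loss, the reduction='none' variant can be wrapped in a small function. This is a minimal sketch: the name mean_squared_l2_error is hypothetical (not a PyTorch API), and it assumes the last dimension holds the coordinates of each point. eps=0.0 is passed to pairwise_distance only so that its default eps does not perturb the comparison.

```python
import torch
import torch.nn.functional as F


def mean_squared_l2_error(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean over points of the squared Euclidean distance.

    Assumes the last dimension of pred/target holds the coordinates of each point.
    """
    # Per-element squared errors, same shape as the inputs
    sq = F.mse_loss(pred, target, reduction="none")
    # Sum over coordinates -> squared L2 distance per point, then average over points
    return sq.sum(dim=-1).mean()


torch.manual_seed(0)
x1 = torch.rand(10, 2)
x2 = torch.rand(10, 2)

loss = mean_squared_l2_error(x1, x2)
direct = F.pairwise_distance(x1, x2, eps=0.0).square().mean()
print(torch.allclose(loss, direct))                  # True
print(torch.allclose(loss, 2 * F.mse_loss(x1, x2)))  # True (D = 2)
```

Because it is built from mse_loss and tensor reductions only, the helper stays differentiable and works for any number of coordinates per point.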
A.Mounir
  • However, if you use `torch.nn.functional.mse_loss(x1,x2,reduction='sum')` then the sum operator *is* applied yet the result differs even more wildly (it is way too big). Unless I am missing something. – csstudent1418 Aug 25 '23 at 21:02
  • Using torch.nn.functional.mse_loss(x1,x2,reduction='sum') means that the output is reduced with a sum instead of a mean. Specifically, the squared errors (x1-x2)**2 are not divided by len(x1), and mse_manual.sum() is used instead of mse_manual.mean(). – A.Mounir Aug 25 '23 at 21:18
  • By the way, torch.nn.functional.pairwise_distance(x1,x2).square().sum() gives the same result as torch.nn.functional.mse_loss(x1,x2,reduction='sum'), because both reduce with a sum. – A.Mounir Aug 25 '23 at 21:27
  • But the result seems incorrect: All of your computations refer to the *element-wise* mean squared error (i.e. each axis/dimension error is considered individually), but since we are talking about actual coordinates I am interested in the mean squared L2-error of the points. I have found that `torch.nn.functional.pairwise_distance(x1,x2).square().mean()` gives the intended result. – csstudent1418 Aug 25 '23 at 21:32
  • If torch.nn.functional.pairwise_distance(x1,x2).square().mean() gives the intended result, then you can use mse_torch = torch.nn.functional.mse_loss(x1,x2, reduction='none') followed by mse_torch.sum(-1).mean(). – A.Mounir Aug 25 '23 at 21:59
  • That works too, nice! Thanks, now I have multiple choices. – csstudent1418 Aug 26 '23 at 07:53
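The reduction='sum' confusion from the comments comes down to how each reduction normalizes the same total of squared errors. A minimal check, assuming eps=0.0 is passed to pairwise_distance so its default eps does not perturb the comparison, shows that reduction='sum' is N times the mean squared L2 error (here 10 times), which is why it looked "way too big":

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x1 = torch.rand(10, 2)  # N = 10 points, D = 2 coordinates
x2 = torch.rand(10, 2)
n, d = x1.shape

total = F.mse_loss(x1, x2, reduction="sum")   # sum of all N*D squared errors
mean = F.mse_loss(x1, x2, reduction="mean")   # total / (N*D)
per_point = F.pairwise_distance(x1, x2, eps=0.0).square()  # squared L2 distance per point

# reduction='sum' matches the summed squared distances ...
print(torch.allclose(total, per_point.sum()))      # True
# ... so it is N times the mean squared L2 error
print(torch.allclose(total, n * per_point.mean()))  # True
# ... and N*D times the default reduction='mean'
print(torch.allclose(total, n * d * mean))          # True
```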