
Here's a screenshot of a YouTube video implementing the loss function from the original YOLOv1 research paper. [screenshot: the YOLOv1 loss implementation in PyTorch]

What I don't understand is the need for torch.flatten() when passing the input to self.mse, which is in fact nn.MSELoss().

The video just gives the reason that nn.MSELoss() expects its input in the shape (a, b), but I don't understand how or why that is.

Here's the video link, just in case. [For reference, N is the batch size and S is the grid size (split size).]


2 Answers


It helps to go back to the definitions. What is MSE? What is it computing?

MSE = mean squared error.

Here is rough Pythonic pseudocode to illustrate:

def mse(data, labels):
    total = 0
    for x, y in zip(data, labels):  # pair up predictions and targets
        total += (x - y) ** 2       # squared difference for this pair
    return total / len(labels)      # the average squared difference
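For instance, with data = [1.0, 2.0] and labels = [1.5, 2.5], each pair differs by 0.5, so mse returns (0.25 + 0.25) / 2 = 0.25.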

For each pair of entries it subtracts one number from the other, squares the result, and returns the average (or mean) once all pairs have been processed.
To rephrase the question: how would you interpret MSE without flattening? MSE, as described and implemented above, operates on pairs of scalars, so it isn't defined for higher-dimensional inputs. If you want to work with the outputs as matrices, you can use other loss functions, such as norms of the output matrices.
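As a minimal sketch of the idea (the shapes here are made up to mimic the (N, S, S, ...) prediction grid from the video, not taken from its actual code), this is the kind of flatten that turns the grid into the 2D (a, b) layout the scalar-pair picture above describes:

import torch

# Hypothetical shapes: batch N=2, an S=7 grid, 4 box values per cell.
pred = torch.randn(2, 7, 7, 4)
target = torch.randn(2, 7, 7, 4)

# Flatten everything except the last dim: (2, 7, 7, 4) -> (98, 4).
# Each row is now one grid cell, i.e. the (a, b) shape from the video.
mse = torch.nn.MSELoss(reduction="sum")
loss = mse(torch.flatten(pred, end_dim=-2),
           torch.flatten(target, end_dim=-2))
print(loss)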

Anyway, I hope that answers your question as to why the flattening is needed.


I had the same question, so I tried different end_dim values, like this:

import torch

data = torch.randn((1, 7, 7, 4))
target = torch.randn((1, 7, 7, 4))

loss = torch.nn.MSELoss(reduction="sum")

# end_dim=-2 flattens (1, 7, 7, 4) down to (49, 4): a true 2D (a, b) shape.
object_loss = loss(
    torch.flatten(data, end_dim=-2),
    torch.flatten(target, end_dim=-2),
)
# end_dim=-3 only flattens the first two dims, giving (7, 7, 4): still 3D.
object_loss1 = loss(
    torch.flatten(data, end_dim=-3),
    torch.flatten(target, end_dim=-3),
)
print(object_loss)
print(object_loss1)

I got the same result both times. So I think the flattening just helps to interpret the MSE; the value itself doesn't depend on the shape.
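To see why the shape doesn't matter here, a quick sketch (random tensors, same shapes as above): nn.MSELoss with reduction="sum" just adds up squared differences over every element, so even the raw 4D tensors give the same number:

import torch

x = torch.randn(1, 7, 7, 4)
y = torch.randn(1, 7, 7, 4)

loss = torch.nn.MSELoss(reduction="sum")

# No flatten at all: the reduction runs over every element anyway.
print(loss(x, y))
# Flattened to 2D: same set of elements, same sum.
print(loss(torch.flatten(x, end_dim=-2), torch.flatten(y, end_dim=-2)))
# The same value computed by hand.
print(((x - y) ** 2).sum())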