PyTorch nn.MSELoss without specifying target

My loss kept getting stuck: it would always decrease to a certain value and then stop decreasing. The loss code was:

criterion = nn.MSELoss()
loss = criterion(y_pred, y_batch.unsqueeze(1))

When I changed it to:

criterion = nn.MSELoss()
loss = criterion(y_pred, target=y_batch)

the issue was fixed.

What was happening before, when the target was not passed by keyword? Does the target need to be passed explicitly as target= for every PyTorch loss function? I found nothing in the documentation about this.


  • It looks like target is just the name of the second positional argument, so passing it by keyword changes nothing. The only real difference between the two versions is the unsqueeze(1) applied to y_batch in the first one.
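A minimal sketch of what the comment describes. It assumes (the question does not show this) that the model's output y_pred has shape (N,), the same as y_batch; random tensors stand in for the real data. Passing the target positionally or as target= gives the same loss. Unsqueezing only the target to (N, 1) makes nn.MSELoss broadcast both tensors to (N, N), so every prediction is compared with every target. A loss like that can plateau at a value the model cannot push lower. PyTorch emits a UserWarning about the size mismatch in this case.

```python
import warnings

import torch
import torch.nn as nn

torch.manual_seed(0)

N = 4
# Assumption: the model outputs a 1-D tensor of shape (N,), like y_batch.
y_pred = torch.randn(N)
y_batch = torch.randn(N)

criterion = nn.MSELoss()

# `target` is simply the name of the second parameter of MSELoss.forward,
# so positional and keyword calls are identical.
loss_positional = criterion(y_pred, y_batch)
loss_keyword = criterion(y_pred, target=y_batch)

# Unsqueezing only the target gives shapes (N,) vs (N, 1), which broadcast
# to (N, N). Record the size-mismatch warning PyTorch raises here.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    loss_broadcast = criterion(y_pred, y_batch.unsqueeze(1))

# The broadcast loss is the mean of all N*N pairwise squared errors:
# element (i, j) is (y_pred[j] - y_batch[i]) ** 2.
pairwise = (y_pred.unsqueeze(0) - y_batch.unsqueeze(1)) ** 2

print("positional:", loss_positional.item())
print("keyword:   ", loss_keyword.item())
print("broadcast: ", loss_broadcast.item())
print("warnings:  ", [str(w.message) for w in caught])
```

If y_pred instead had shape (N, 1), the situation is reversed: the unsqueezed version is the correct one and the plain y_batch version broadcasts. Either way, the fix is to make the two shapes match explicitly (for example with y_pred.squeeze(1) or y_batch.unsqueeze(1), as appropriate) rather than to rely on the keyword.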