The goal is to minimize the angle between the actual and predicted vectors in a neural network setting. Can someone please check whether the following implementation is correct?
import torch
import torch.nn as nn

# actual_vectors, predicted_vectors: (batch, dim) tensors; predicted_vectors carries gradients.
criterion = nn.CosineSimilarity(dim=1)
loss = torch.mean(torch.abs(criterion(actual_vectors, predicted_vectors)))
# Back-propagation on the above *loss* will try to drive cos(angle) to 0,
# but I want the angle between the vectors to be 0, i.e. cos(angle) = 1.
loss = 1 - loss
# To me, the above does not seem right. Isn't back-propagation on this 'loss'
# the same as minimizing the negative of the cosine-similarity loss above?
# Does the '1' play no role at all when back-propagation is applied?
loss.backward()
Theoretically that makes sense. The goal of back-propagation is to minimize the loss. If the loss is 1 - cos(A) (where A is the angle between the two vectors), then that is equivalent to saying the goal is to maximize cos(A), which in turn is equivalent to minimizing the angle between the two vectors.
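To see that concretely, here is a quick sketch with made-up 2-D unit vectors (the printed values are shown as comments): as the angle shrinks, so does 1 - cos(angle).

import math
import torch
import torch.nn.functional as F

# Rotate a unit vector towards the target and watch the loss drop.
target = torch.tensor([[1.0, 0.0]])
for degrees in (90, 45, 10, 0):
    theta = math.radians(degrees)
    pred = torch.tensor([[math.cos(theta), math.sin(theta)]])
    loss = 1 - F.cosine_similarity(pred, target, dim=1).mean()
    print(f"angle = {degrees:2d} deg -> loss = {loss.item():.4f}")
# angle = 90 deg -> loss = 1.0000
# angle = 45 deg -> loss = 0.2929
# angle = 10 deg -> loss = 0.0152
# angle =  0 deg -> loss = 0.0000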
A simple example would be the goal of minimizing X^2 + 4. The answer to that optimization problem (X = 0) is the same as the answer to the goal of maximizing -(X^2 + 4). Sticking a minus on the whole expression and swapping min with max makes the two statements equivalent. So if you have a function you want to MAXIMIZE and your optimizer can only MINIMIZE, then just slap a minus sign on your function and call it a day.
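As a minimal sketch of that symmetry (a toy example, nothing here comes from your code): gradient descent on X^2 + 4 and gradient ascent on -(X^2 + 4) take exactly the same steps.

import torch

x_min = torch.tensor(3.0, requires_grad=True)  # minimize x**2 + 4
x_max = torch.tensor(3.0, requires_grad=True)  # maximize -(x**2 + 4)
lr = 0.1
for _ in range(50):
    f = x_min ** 2 + 4
    g = -(x_max ** 2 + 4)
    f.backward()
    g.backward()
    with torch.no_grad():
        x_min -= lr * x_min.grad   # descent step on f
        x_max += lr * x_max.grad   # ascent step on g (note the +)
        x_min.grad.zero_()
        x_max.grad.zero_()
print(x_min.item(), x_max.item())  # both converge towards 0, the shared optimum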
Another question you might ask is "what is significant about the 1? Could we have just said loss = -loss?" The answer is: for the optimization itself, it makes no difference. The 1 plays no role in back-propagation, because a constant disappears when you take the derivative, so the gradients, and therefore the parameter updates, are identical. Where the constant can matter is anywhere the loss value itself is read, e.g. logging, early stopping, or a learning-rate scheduler that monitors the loss; a quick check is sketched after this paragraph.
Another reason to have the 1 is that it makes the loss non-negative with 0 as the target value: 1 - cos(A) lies between 0 and 2 (and with the abs() in your snippet, 1 - |cos(A)| lies between 0 and 1), which is a nice property for a loss to have.
So yes, minimizing the loss 1 - cos(A) through back-propagation is equivalent to minimizing the angle between the vectors.
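If it helps, PyTorch also ships this objective directly: with a target of +1 per pair, nn.CosineEmbeddingLoss computes 1 - cos(x1, x2), so it should match the hand-rolled version (batch size and dimensions below are made up for illustration).

import torch
import torch.nn as nn
import torch.nn.functional as F

predicted_vectors = torch.randn(16, 32, requires_grad=True)
actual_vectors = torch.randn(16, 32)

manual = (1 - F.cosine_similarity(predicted_vectors, actual_vectors, dim=1)).mean()
builtin = nn.CosineEmbeddingLoss()(predicted_vectors, actual_vectors, torch.ones(16))
print(torch.allclose(manual, builtin))  # True
builtin.backward()  # gradients flow into predicted_vectors as usual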