
Custom distance loss function in PyTorch?


I want to implement the following distance loss function in PyTorch. I was following this thread from the PyTorch forum: https://discuss.pytorch.org/t/custom-loss-functions/29387/4

np.linalg.norm(output - target)
# where output.shape = [1, 2] and target.shape = [1, 2]

So I implemented the loss function like this:

import numpy as np
import torch
import torch.nn as nn

def my_loss(output, target):
    loss = torch.tensor(np.linalg.norm(output.detach().numpy() - target.detach().numpy()))
    return loss

With this loss function, calling loss.backward() raises a runtime error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

My entire code looks like this:

model = nn.Linear(2, 2)

x = torch.randn(1, 2)
target = torch.randn(1, 2)
output = model(x)

loss = my_loss(output, target)
loss.backward()  # <-- error raised here

print(model.weight.grad)

PS: I am aware of PyTorch's pairwise loss, but because of some of its limitations I have to implement this myself.

Following the PyTorch source code, I also tried the following:

class my_function(torch.nn.Module): # forgot to define backward()
    def forward(self, output, target):

        loss = torch.tensor(np.linalg.norm(output.detach().numpy() - target.detach().numpy()))
        return loss

model = nn.Linear(2, 2)
x = torch.randn(1, 2)
target = torch.randn(1, 2)
output = model(x)

criterion = my_function()

loss = criterion(output, target)


loss.backward()
print(model.weight.grad)

And I get the same runtime error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

How can I implement the loss function correctly?


Solution

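  • This happens because the loss function detaches its inputs from the autograd graph. You had to call .detach().numpy() in order to use np.linalg.norm, but that discards the computation history, and torch.tensor(...) then creates a brand-new tensor with requires_grad=False and no grad_fn. As a result, backward() has nothing to differentiate through, which is exactly what the error message reports. The same applies to the nn.Module version, since its forward() does the same thing.

    Replace

    loss = torch.tensor(np.linalg.norm(output.detach().numpy() - target.detach().numpy()))

    with the equivalent torch operation:

    loss = torch.norm(output - target)

    Because the computation now stays entirely within torch operations, the graph is preserved and loss.backward() populates model.weight.grad. (In recent PyTorch versions, torch.norm is deprecated in favor of torch.linalg.norm, which computes the same value here.)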
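For reference, here is a minimal sketch that puts the two versions side by side, so the difference in autograd tracking is visible. The function names numpy_loss and torch_loss are illustrative and not from the original post, and torch.linalg.norm is used on the assumption that a reasonably recent PyTorch version is installed (torch.norm would behave the same way here).

```python
import numpy as np
import torch
import torch.nn as nn


def numpy_loss(output, target):
    # The approach from the question: detaching and going through NumPy
    # leaves the autograd graph, so the result carries no grad_fn.
    diff = output.detach().numpy() - target.detach().numpy()
    return torch.tensor(np.linalg.norm(diff))


def torch_loss(output, target):
    # Same Euclidean distance, computed with torch operations only,
    # so autograd can track it back to the model parameters.
    return torch.linalg.norm(output - target)


torch.manual_seed(0)  # arbitrary seed, only to make runs repeatable
model = nn.Linear(2, 2)
x = torch.randn(1, 2)
target = torch.randn(1, 2)
output = model(x)

broken = numpy_loss(output, target)
fixed = torch_loss(output, target)

# The NumPy-based loss is cut off from the graph; the torch-based one is not.
print(broken.requires_grad, broken.grad_fn)
print(fixed.requires_grad, fixed.grad_fn is not None)

# backward() succeeds on the torch-based loss and fills in the gradients.
fixed.backward()
print(model.weight.grad)
```

Both functions return the same numerical value; only the torch-based one can be backpropagated. If the loss needs to live in an nn.Module, the same one-line torch expression can go in forward(), and no custom backward() is required because autograd derives it from the torch operations.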