python neural-network pytorch gradient-descent loss-function

My custom loss function in PyTorch does not train


My custom loss function in PyTorch does not update during training: the loss stays exactly the same at every step. I am trying to write a custom loss based on the false positive and false negative rates. Below is a simplified version of the code. Any idea what could be happening? Do the gradients become 0 during backpropagation? Is this not the correct way to define a custom loss function?

I have already checked that requires_grad stays True on the loss during backpropagation (using an assert on requires_grad). I have also tried turning the function false_pos_neg_rate into a class (a subclass of torch.nn.Module), but that did not work: the assert on requires_grad failed, so I dropped that approach. There is no error, and training keeps running.

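def false_pos_neg_rate(outputs, truths):
    y = truths
    y_predicted = outputs
    cut_off = torch.tensor(0.5, requires_grad=True)
    # Hard threshold: every prediction becomes exactly 0 or 1
    # (zeros and ones are defined elsewhere, not shown here)
    y_predicted = torch.where(y_predicted <= cut_off, zeros, ones)
    # confusion_matrix is defined elsewhere, not shown here
    tp, fp, tn, fn = confusion_matrix(y_predicted, y)
    fp_rate = fp / (fp + tn).float()
    fn_rate = fn / (fn + tp).float()
    loss = fn_rate + fp_rate
    return loss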

for i, (samples, truths) in enumerate(train_loader):
    samples = Variable(samples)
    truths = Variable(truths)
    outputs = model(samples)
    loss = false_pos_neg_rate(outputs, truths)
    loss.backward()
    optimizer.step()

I expect the loss to decrease at every training step as the model is updated. Instead, the loss stays exactly the same and nothing happens.

What is going on? Why does the model not train?


Solution

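  • As pointed out by Umang Gupta, your loss function is not differentiable. The torch.where threshold maps every output to exactly 0 or 1, so the confusion-matrix counts, and the rates computed from them, only change when an output crosses the cut-off. If you write out mathematically what you are trying to do, you'll see that the loss behaves like a "step function": it is piecewise constant and has zero gradient almost everywhere.
    To train a model with gradient-descent methods, the loss must have meaningful (non-zero) gradients with respect to the model outputs. With a zero gradient, optimizer.step() leaves the weights unchanged, which is why the loss never moves.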
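A minimal sketch of the difference, assuming the model outputs are logits that are passed through a sigmoid. The names `hard_rate_loss` and `soft_rate_loss` are hypothetical, and the "soft" version is one common workaround rather than the only possible fix: it uses the predicted probabilities directly as soft counts instead of thresholding them. Here `torch.round` stands in for the hard threshold because autograd defines its gradient as zero.

```python
import torch

def hard_rate_loss(probs, truths):
    # Thresholded version: torch.round has zero gradient wherever it is defined,
    # so nothing useful flows back to the model.
    preds = torch.round(probs)
    fp = (preds * (1 - truths)).sum()
    fn = ((1 - preds) * truths).sum()
    return fp / (1 - truths).sum() + fn / truths.sum()

def soft_rate_loss(probs, truths, eps=1e-8):
    # Soft surrogate (an assumption, not the asker's code): probabilities act
    # as fractional counts, so the loss varies smoothly with the outputs.
    fp = (probs * (1 - truths)).sum()
    fn = ((1 - probs) * truths).sum()
    return fp / ((1 - truths).sum() + eps) + fn / (truths.sum() + eps)

truths = torch.tensor([1.0, 0.0, 0.0, 1.0])

# Hard loss: backward() runs without error, but the gradient is all zeros.
logits = torch.tensor([2.0, -1.0, 0.5, -0.3], requires_grad=True)
hard_rate_loss(torch.sigmoid(logits), truths).backward()
hard_grad = logits.grad.clone()
print("hard gradient:", hard_grad)

# Soft loss: a few plain SGD steps on the logits reduce the loss.
logits = torch.tensor([2.0, -1.0, 0.5, -0.3], requires_grad=True)
optimizer = torch.optim.SGD([logits], lr=1.0)
soft_losses = []
for _ in range(20):
    optimizer.zero_grad()
    loss = soft_rate_loss(torch.sigmoid(logits), truths)
    loss.backward()
    optimizer.step()
    soft_losses.append(loss.item())
print("soft loss first/last:", soft_losses[0], soft_losses[-1])
```

The hard version reproduces the symptoms in the question: `requires_grad` is True and there is no error, yet the parameters never move. You can still report the thresholded false positive and false negative rates as evaluation metrics while training on a differentiable surrogate.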