Tags: pytorch, gradient, backpropagation

Is there any way to include a counter (a variable that counts something) in a loss function in PyTorch?


These are some lines from my loss function. output is the output of a multiclass classification network.

bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])

dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)

I want dr_output.sum() to be part of my loss function, but my implementation has several limitations: some of these functions are non-differentiable in PyTorch, and dr_output may also be zero, which is not allowed if dr_output is my only loss term. Can anyone suggest a way around these problems?


Solution

  • If I understood correctly:

    bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])
    

    computes how many elements are greater than .1, for each row.

    In turn:

    dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)
    

    is true if there is exactly one element greater than .1 in the corresponding row and the prediction is correct.

    dr_output.sum() then counts how many rows satisfy this condition, so minimizing this quantity directly may push the network toward incorrect predictions or toward distributions with more than one value greater than .1.
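
    As a quick sanity check, here is a toy run of those two lines (the numbers are made up: 3 samples over 4 classes):

    import torch
    
    output = torch.tensor([[0.90, 0.05, 0.03, 0.02],   # one entry > .1, prediction correct
                           [0.40, 0.30, 0.20, 0.10],   # three entries > .1
                           [0.05, 0.80, 0.10, 0.05]])  # one entry > .1, prediction wrong
    labels = torch.tensor([0, 0, 2])
    
    bin_count = torch.bincount(torch.where(output > .1)[0], minlength=output.shape[0])
    print(bin_count)        # tensor([1, 3, 1])
    
    dr_output = (bin_count == 1) & (torch.argmax(output, dim=1) == labels)
    print(dr_output)        # tensor([ True, False, False])
    print(dr_output.sum())  # tensor(1)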

    Given these considerations, you can approximate your loss with the following:

    import torch
    import torch.nn.functional as F
    
    # x are the inputs to the loss (the network outputs), y the labels
    
    mask = x > 0.1                                   # which entries exceed the threshold
    p = F.softmax(x, dim=1)                          # differentiable probabilities
    out = p * (mask.sum(dim=1, keepdim=True) == 1)   # keep only rows with exactly one entry above .1
    
    loss = out[torch.arange(x.shape[0]), y].sum()    # probability of the correct class, summed over the batch
    

    You can devise similar variants that are better suited to your problem.
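
    For example, one possible variant (just a sketch; alpha is a made-up weighting factor you would tune) subtracts this differentiable term from a standard cross-entropy loss, so the total loss stays well-defined even when no row satisfies the condition:

    import torch
    import torch.nn.functional as F
    
    def loss_fn(x, y, alpha=0.1):
        # standard cross-entropy keeps the loss meaningful for every batch
        ce = F.cross_entropy(x, y)
    
        # differentiable surrogate of dr_output.sum(): probability of the correct
        # class, kept only for rows with exactly one entry above .1
        mask = x > 0.1
        p = F.softmax(x, dim=1)
        out = p * (mask.sum(dim=1, keepdim=True) == 1)
        surrogate = out[torch.arange(x.shape[0]), y].sum()
    
        # subtracting rewards rows that satisfy the condition instead of penalizing them;
        # alpha controls how strongly this term influences training
        return ce - alpha * surrogate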