These are some lines from my loss function. output
is the output of a multiclass classification network.
bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)
I want dr_output.sum()
to be part of my loss function, but my implementation has several problems: some of these functions are non-differentiable in PyTorch, and dr_output.sum()
may also be zero, which is not allowed if I use it
as my only loss term. Can anyone suggest a way around these problems?
If I got it correctly:
bin_count=torch.bincount(torch.where(output>.1)[0], minlength=output.shape[0])
computes, for each row, how many elements are greater than .1.
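For instance, here is a small sketch of what that line computes (the tensor values are made up for illustration):

```python
import torch

output = torch.tensor([[0.20, 0.05, 0.30],
                       [0.05, 0.08, 0.02]])
# row indices of the entries greater than .1: both come from row 0
rows = torch.where(output > 0.1)[0]
# count occurrences per row; minlength pads rows with no such entries
bin_count = torch.bincount(rows, minlength=output.shape[0])
print(bin_count)  # tensor([2, 0])
```

Note that torch.where with a single argument returns the indices of the nonzero (true) entries, so index [0] gives the row of each entry above the threshold.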
Then:
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1)==labels)
is true if there is exactly one element greater than .1
in the corresponding row and the prediction is correct.
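Continuing with made-up values, the combined condition behaves like this:

```python
import torch

output = torch.tensor([[0.20, 0.05, 0.30],
                       [0.05, 0.60, 0.02]])
labels = torch.tensor([2, 1])

bin_count = torch.bincount(torch.where(output > 0.1)[0],
                           minlength=output.shape[0])
# row 0 has two values > .1, so it fails the first condition;
# row 1 has exactly one value > .1 and its argmax matches the label
dr_output = (bin_count == 1) & (torch.argmax(output, dim=1) == labels)
print(dr_output)  # tensor([False,  True])
```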
dr_output.sum()
then counts how many rows satisfy this condition. Since you want that count to be high, using it directly as a loss to minimize would enforce incorrect predictions or distributions with more than one value greater than .1.
Given these considerations, you can approximate your loss with the following:
import torch
import torch.nn.functional as F

# x are the network outputs (logits), y the labels
mask = x > 0.1                                  # non-differentiable, but used only as a gate
p = F.softmax(x, dim=1)                         # differentiable surrogate for the hard counts
out = p * (mask.sum(dim=1, keepdim=True) == 1)  # keep rows with exactly one value > .1
# you want to maximize the correct-class probability on those rows,
# so negate the sum to get a loss to minimize
loss = -out[torch.arange(x.shape[0]), y].sum()
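As a quick sanity check that gradients actually flow through this construction (toy tensors; the loss is negated here so that minimizing it maximizes the correct-class probability):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(4, 3, requires_grad=True)  # toy logits
y = torch.tensor([0, 1, 2, 1])             # toy labels

mask = x > 0.1
p = F.softmax(x, dim=1)
# the boolean gate carries no gradient, but the softmax term does
out = p * (mask.sum(dim=1, keepdim=True) == 1)
loss = -out[torch.arange(x.shape[0]), y].sum()
loss.backward()  # populates x.grad via the softmax path
```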
You can devise similar variants that are more fit for your problem.