I know that for classification with a neural network and cross-entropy loss, we need one-hot encoded targets, but in PyTorch CrossEntropyLoss does not accept one-hot encoded targets, and we should instead pass the class labels directly.
Now I am wondering whether the same holds for image segmentation tasks, where the loss function is Dice loss, focal loss, etc. That is, is it OK to one-hot encode the target mask for segmentation, as in TensorFlow, or can I not do that, as with classification in PyTorch? (Let's say I am going to use a standard CNN such as 3DUNet.)
Extra note: I have done classification with neural nets in PyTorch, and it did not make sense to me that the cross-entropy loss could not accept one-hot encoded targets. I expect the other losses used for segmentation to behave the same way as cross-entropy loss, but I am not sure, because that would not make sense to me either.
For classification, CrossEntropyLoss() can in fact handle one-hot targets (supported since PyTorch 1.10), as specified in the documentation:
The target that this criterion expects should contain either: class indices in the range [0, C), where C is the number of classes; or probabilities for each class.
Therefore, you have two options for your one-hot targets:
1. Pass them directly as (float) class probabilities.
2. Use torch.argmax() to convert the one-hot targets to class indices.
Both give the same loss. For example:
import torch
import torch.nn as nn
cross_entropy = nn.CrossEntropyLoss()
input = torch.randn(3, 5)
print(f"input: {input}")
# input: tensor([[-0.6498, -0.4508, 1.0618, -1.4337, -1.6479],
# [-0.9778, 0.0141, -0.5646, -1.0664, 0.9022],
# [ 0.0797, 0.7878, 0.6092, -0.2396, -0.5839]])
random_idx = torch.randint(0, 5, (3,))
print(f"random_idx: {random_idx}")
# random_idx: tensor([2, 3, 3])
# Option 1: pass the one-hot (float) targets directly as class probabilities
target_one_hot = torch.eye(5)[random_idx]
output_one_hot = cross_entropy(input, target_one_hot)
print(f"target_one_hot: {target_one_hot}")
print(f"output_one_hot: {output_one_hot}")
# target_one_hot: tensor([[0., 0., 1., 0., 0.],
# [0., 0., 0., 1., 0.],
# [0., 0., 0., 1., 0.]])
# output_one_hot: 1.7241638898849487
# Option 2: convert the one-hot targets back to class indices
target = target_one_hot.argmax(dim=1)
output = cross_entropy(input, target)
print(f"target: {target}")
print(f"output: {output}")
# target: tensor([2, 3, 3])
# output: 1.7241638898849487
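For segmentation, as explained above, CrossEntropyLoss() can also handle one-hot targets: for an input of shape (N, C, d1, ..., dK), the target can be either class indices of shape (N, d1, ..., dK) or per-class probabilities (e.g. a float one-hot mask) with the same shape as the input. For Dice loss, there is no official implementation in PyTorch, so you need to implement it yourself, and you can therefore define the targets however you like (one-hot or otherwise).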
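As a minimal sketch (assuming PyTorch 1.10 or later, and made-up tensor sizes for a 3D U-Net-style output), the snippet below passes the same voxel labels to CrossEntropyLoss once as class indices and once as a one-hot mask. Note that F.one_hot puts the class axis last and returns integers, so the mask has to be permuted to (N, C, D, H, W) and cast to float; otherwise it would not be read as class probabilities.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes for a 3D segmentation batch:
# N volumes, C classes, D x H x W voxels.
N, C, D, H, W = 2, 4, 8, 16, 16

torch.manual_seed(0)
logits = torch.randn(N, C, D, H, W)          # raw network output, no softmax
labels = torch.randint(0, C, (N, D, H, W))   # integer class index per voxel

# F.one_hot appends the class axis last: (N, D, H, W, C).
# Move it to dim 1 and cast to float so it matches the logits' shape
# and is interpreted as per-class probabilities, not class indices.
one_hot = F.one_hot(labels, num_classes=C).permute(0, 4, 1, 2, 3).float()

criterion = nn.CrossEntropyLoss()
loss_from_indices = criterion(logits, labels)   # target shape (N, D, H, W)
loss_from_one_hot = criterion(logits, one_hot)  # target shape (N, C, D, H, W)

# Hard one-hot probabilities should give the same value as the class indices.
print(loss_from_indices.item(), loss_from_one_hot.item())
```

So for cross-entropy, whether you one-hot encode the segmentation mask or not is a matter of convenience; with the default `reduction="mean"` both forms average over all voxels.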
Maybe you can check this GitHub code, which implements the classic U-Net and trains the model using both CrossEntropyLoss and DiceLoss.
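Because a Dice loss is written by hand, it can simply be defined to take one-hot targets. The function below is a hypothetical sketch (the name `soft_dice_loss` and the `eps` smoothing term are illustrative choices, not taken from any library or from the code mentioned above): it applies softmax to the logits, computes one soft Dice score per class over the batch and spatial dimensions, and returns one minus their mean.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(logits, target_one_hot, eps=1e-6):
    """Hypothetical multi-class soft Dice loss.

    logits:         (N, C, *spatial) raw network output
    target_one_hot: (N, C, *spatial) float one-hot target mask
    """
    probs = logits.softmax(dim=1)
    # Sum over the batch and all spatial dims, keeping one value per class.
    dims = (0,) + tuple(range(2, logits.dim()))
    intersection = (probs * target_one_hot).sum(dims)
    cardinality = probs.sum(dims) + target_one_hot.sum(dims)
    dice_per_class = (2 * intersection + eps) / (cardinality + eps)
    return 1 - dice_per_class.mean()


# Example with made-up sizes: convert integer labels to a one-hot mask first.
torch.manual_seed(0)
N, C, D, H, W = 2, 3, 4, 8, 8
labels = torch.randint(0, C, (N, D, H, W))
target = F.one_hot(labels, num_classes=C).permute(0, 4, 1, 2, 3).float()

random_logits = torch.randn(N, C, D, H, W)
perfect_logits = 20.0 * target  # confidently predicts the true class everywhere

print(soft_dice_loss(random_logits, target).item())
print(soft_dice_loss(perfect_logits, target).item())  # close to 0
```

This is only one common formulation; other variants square the terms in the denominator, exclude the background class, or combine Dice with cross-entropy, and any of them can be written to accept either one-hot masks or class indices.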