
Comparing two predicted segmentation maps


I am applying a consistency loss between two predicted segmentation maps on unlabeled data. For labeled data, I'm using nn.BCEWithLogitsLoss and a Dice loss.

I'm working with videos, which is why the output has 5 dimensions: (batch_size, channels, frames, height, width).

I want to know how to compare two predicted segmentation maps.


import torch
import torch.nn as nn

# gt_seg - Ground truth segmentation map (float 0/1 values) - (8, 1, 8, 112, 112)
# aug_gt_seg - Augmented ground truth segmentation map - (8, 1, 8, 112, 112)

predicted_seg_1 = model(data, targets)                      # (8, 1, 8, 112, 112)
predicted_seg_2 = model(augmented_data, augmented_targets)  # (8, 1, 8, 112, 112)

# define criteria
seg_criterion_1 = nn.BCEWithLogitsLoss(reduction="mean")  # size_average is deprecated
seg_criterion_2 = DiceLoss()  # torch.nn has no DiceLoss; a custom or third-party one is assumed

# labeled losses
supervised_loss_1 = seg_criterion_1(predicted_seg_1, gt_seg)
supervised_loss_2 = seg_criterion_2(predicted_seg_1, gt_seg)

# consistency loss between the two predictions
if consistency_loss == "l2":
    consistency_criterion = nn.MSELoss()
    cons_loss = consistency_criterion(predicted_seg_1, predicted_seg_2)

elif consistency_loss == "l1":
    consistency_criterion = nn.L1Loss()
    cons_loss = consistency_criterion(predicted_seg_1, predicted_seg_2)

total_supervised_loss = supervised_loss_1 + supervised_loss_2
total_consistency_loss = cons_loss
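For reference, a self-contained version of the snippet above could look like the minimal sketch below. It is not the asker's setup: the model is a hypothetical single `nn.Conv3d` stand-in, `soft_dice_loss` is a hypothetical soft Dice implementation (torch.nn provides none), the augmentation is assumed to be intensity-only (added noise) so the two maps stay spatially aligned, and the tensors are smaller than (8, 1, 8, 112, 112) to keep it fast. Comparing sigmoid probabilities instead of raw logits in the consistency term is a design choice, not a requirement.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the real video model: one 3D conv mapping
# (batch, 3, frames, H, W) clips to (batch, 1, frames, H, W) logits.
model = nn.Conv3d(in_channels=3, out_channels=1, kernel_size=3, padding=1)

def soft_dice_loss(logits, target, eps=1e-6):
    """Hypothetical soft Dice loss on sigmoid probabilities, averaged over the batch."""
    probs = torch.sigmoid(logits)
    dims = tuple(range(1, probs.dim()))  # reduce over channels, frames, H, W
    intersection = (probs * target).sum(dims)
    denom = probs.sum(dims) + target.sum(dims)
    return (1 - (2 * intersection + eps) / (denom + eps)).mean()

# Random tensors with the question's 5D layout, smaller spatial size.
data = torch.randn(2, 3, 4, 32, 32)
augmented_data = data + 0.1 * torch.randn_like(data)  # assumed intensity-only augmentation
gt_seg = (torch.rand(2, 1, 4, 32, 32) > 0.5).float()  # BCEWithLogitsLoss needs float targets

predicted_seg_1 = model(data)
predicted_seg_2 = model(augmented_data)

bce = nn.BCEWithLogitsLoss(reduction="mean")
supervised_loss = bce(predicted_seg_1, gt_seg) + soft_dice_loss(predicted_seg_1, gt_seg)

consistency_loss = "l2"
consistency_criterion = nn.MSELoss() if consistency_loss == "l2" else nn.L1Loss()
# Compare probabilities rather than raw logits (a design choice).
cons_loss = consistency_criterion(torch.sigmoid(predicted_seg_1),
                                  torch.sigmoid(predicted_seg_2))

print(predicted_seg_1.shape, supervised_loss.item(), cons_loss.item())
```

Here both terms use the same batch only for brevity; in the question, the supervised terms use labeled clips and the consistency term uses unlabeled ones.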

Is this the right way to apply consistency between two predicted segmentation maps?

I'm mainly confused by the definition on the PyTorch website, which describes the loss as a comparison between an input x and a target y. I thought my approach looked correct, since I want both predicted segmentation maps to be similar. But the second segmentation map is not a target, which is why I'm confused: if this were valid, then almost any loss function could be applied this way, and that doesn't seem right to me. If it is the correct way to compare them, can it be extended to other segmentation losses such as Dice loss, IoU loss, etc.?

One more question about computing the loss on labeled data:

# gt_seg - Ground truth segmentation map
# aug_gt_seg - Augmented ground truth segmentation map

predicted_seg_1 = model(data, targets)
predicted_seg_2 = model(augmented_data, augmented_targets)

# define criterion
seg_criterion_1 = nn.BCEWithLogitsLoss(reduction="mean")  # size_average is deprecated
seg_criterion_2 = DiceLoss()  # torch.nn has no DiceLoss; a custom or third-party one is assumed

# labeled losses
supervised_loss_1 = seg_criterion_1(predicted_seg_1, gt_seg)
supervised_loss_2 = seg_criterion_2(predicted_seg_1, gt_seg)

# augmented labeled losses
aug_supervised_loss_1 = seg_criterion_1(predicted_seg_2, aug_gt_seg)
aug_supervised_loss_2 = seg_criterion_2(predicted_seg_2, aug_gt_seg)

total_supervised_loss = supervised_loss_1 + supervised_loss_2 + aug_supervised_loss_1 + aug_supervised_loss_2

Is this calculation of total_supervised_loss correct? Can I call total_supervised_loss.backward() on it?
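Summing scalar loss tensors produces another scalar tensor in the same autograd graph, so a single backward() call propagates gradients from every term into the model parameters. A minimal sketch, assuming a hypothetical `nn.Conv3d` stand-in model, a hypothetical `soft_dice_loss` helper, and a horizontal flip as the augmentation (applied to both the clip and its mask so labels stay aligned):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in model; replace with the real video segmentation network.
model = nn.Conv3d(3, 1, kernel_size=3, padding=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

def soft_dice_loss(logits, target, eps=1e-6):
    """Hypothetical soft Dice loss on sigmoid probabilities."""
    probs = torch.sigmoid(logits)
    dims = tuple(range(1, probs.dim()))
    inter = (probs * target).sum(dims)
    return (1 - (2 * inter + eps) / (probs.sum(dims) + target.sum(dims) + eps)).mean()

data = torch.randn(2, 3, 4, 32, 32)
gt_seg = (torch.rand(2, 1, 4, 32, 32) > 0.5).float()
# Assumed geometric augmentation: flip the width axis of both clip and mask.
augmented_data = torch.flip(data, dims=[-1])
aug_gt_seg = torch.flip(gt_seg, dims=[-1])

bce = nn.BCEWithLogitsLoss()
predicted_seg_1 = model(data)
predicted_seg_2 = model(augmented_data)

total_supervised_loss = (
    bce(predicted_seg_1, gt_seg) + soft_dice_loss(predicted_seg_1, gt_seg)
    + bce(predicted_seg_2, aug_gt_seg) + soft_dice_loss(predicted_seg_2, aug_gt_seg)
)

optimizer.zero_grad()
total_supervised_loss.backward()  # gradients from all four terms accumulate in the parameters
optimizer.step()

print(total_supervised_loss.item(), model.weight.grad.abs().sum().item())
```

If one term should count more than another, scaling it by a constant weight before summing works the same way.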


Solution

  • Yes, this is a valid way to implement a consistency loss. The PyTorch documentation names one argument the input (the prediction) and the other the target, but L1, L2 (MSE), Dice, and IoU losses are all symmetric, that is, Loss(a, b) = Loss(b, a). So any of these functions implements a form of consistency loss regardless of whether one argument is actually a ground truth or "target". Not every loss has this property: binary cross-entropy, for example, is not symmetric, so there it matters which map is passed as the target.
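The symmetry claim can be checked numerically. The sketch below compares two random probability maps with the question's 5D layout; `soft_dice_loss` and `soft_iou_loss` are hypothetical soft (differentiable) implementations written for illustration, not library functions. Each is built only from elementwise products and sums, which is why swapping the arguments does not change the result, while `nn.BCELoss` treats its two arguments differently.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def soft_dice_loss(p, q, eps=1e-6):
    """Hypothetical soft Dice loss between two probability maps; symmetric in p and q."""
    dims = tuple(range(1, p.dim()))
    inter = (p * q).sum(dims)
    return (1 - (2 * inter + eps) / (p.sum(dims) + q.sum(dims) + eps)).mean()

def soft_iou_loss(p, q, eps=1e-6):
    """Hypothetical soft IoU (Jaccard) loss; union = |p| + |q| - |p*q|, also symmetric."""
    dims = tuple(range(1, p.dim()))
    inter = (p * q).sum(dims)
    union = p.sum(dims) + q.sum(dims) - inter
    return (1 - (inter + eps) / (union + eps)).mean()

# Two predicted probability maps: (batch, channels, frames, height, width).
a = torch.sigmoid(torch.randn(2, 1, 4, 16, 16))
b = torch.sigmoid(torch.randn(2, 1, 4, 16, 16))

losses = {
    "l1": nn.L1Loss(),
    "l2": nn.MSELoss(),
    "dice": soft_dice_loss,
    "iou": soft_iou_loss,
}
for name, fn in losses.items():
    print(name, "symmetric:", torch.allclose(fn(a, b), fn(b, a)))

# Binary cross-entropy weights log(input) by target, so swapping the arguments changes it.
bce = nn.BCELoss()
print("bce symmetric:", torch.allclose(bce(a, b), bce(b, a)))
```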