Tags: pytorch, image-segmentation, loss-function, dice, semantic-segmentation

I want to confirm which of these methods to calculate Dice Loss is correct


I have 4 methods to calculate the dice loss, and 3 of them return the same result, so I conclude that the remaining one is calculating it wrong, but I would like to confirm this with you guys:

import torch
torch.manual_seed(0)

inputs = torch.rand((3,1,224,224))
target = torch.rand((3,1,224,224))

Method 1: flatten tensors

def method1(inputs, target):
    inputs = inputs.reshape(-1)
    target = target.reshape(-1)

    intersection = (inputs * target).sum()
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)
    dice = dice.sum()

    print("method1", dice)

Method 2: flatten tensors except for batch size, sum all dims

def method2(inputs, target):
    num = target.shape[0]
    inputs = inputs.reshape(num, -1)

    target = target.reshape(num, -1)

    intersection = (inputs * target).sum()
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)
    dice = dice.sum()/num

    print("method2", dice)

Method 3: flatten tensors except for batch size, sum dim 1

def method3(inputs, target):
    num = target.shape[0]
    inputs = inputs.reshape(num, -1)

    target = target.reshape(num, -1)

    intersection = (inputs * target).sum(1)
    union = inputs.sum(1) + target.sum(1)
    dice = (2. * intersection) / (union + 1e-8)
    dice = dice.sum()/num

    print("method3", dice)

Method 4: don't flatten tensors

def method4(inputs, target):

    intersection = (inputs * target).sum()
    union = inputs.sum() + target.sum()
    dice = (2. * intersection) / (union + 1e-8)


    print("method4", dice)

method1(inputs, target)
method2(inputs, target)
method3(inputs, target)
method4(inputs, target)
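For reference, the two quantities being compared can be written as below. The notation is introduced here, not taken from the code: p = inputs, g = target, N = batch size (num), i runs over the C·H·W pixels of one sample, and ε = 1e-8. Methods 1 and 4 compute D_global, method 3 computes D_mean, and method 2 computes D_global / N.

$$
D_{\text{global}} = \frac{2\sum_{n=1}^{N}\sum_{i} p_{n,i}\,g_{n,i}}{\sum_{n=1}^{N}\sum_{i} p_{n,i} + \sum_{n=1}^{N}\sum_{i} g_{n,i} + \epsilon}
\qquad
D_{\text{mean}} = \frac{1}{N}\sum_{n=1}^{N} \frac{2\sum_{i} p_{n,i}\,g_{n,i}}{\sum_{i} p_{n,i} + \sum_{i} g_{n,i} + \epsilon}
$$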

Methods 1, 3 and 4 print 0.5006; method 2 prints 0.1669.

This makes sense: in method 2 I flatten the inputs and targets over 3 dimensions, leaving out the batch size, but then I sum over both resulting dimensions instead of just dim 1. That yields the same global dice as method 1, which I then divide by the batch size: 0.5006 / 3 ≈ 0.1669.
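A quick check of that explanation, reusing the question's seed and shapes. The variable names here are only for illustration: `.sum()` with no `dim` on the `(num, -1)` view reduces every element, so it reproduces method 1's global score, and the extra `/num` just scales that single scalar down.

```python
import torch

torch.manual_seed(0)
inputs = torch.rand((3, 1, 224, 224))
target = torch.rand((3, 1, 224, 224))
num = inputs.shape[0]

# Method 1: one dice score over every pixel in the batch
inter = (inputs * target).sum()
global_dice = (2.0 * inter) / (inputs.sum() + target.sum() + 1e-8)

# Method 2: .sum() without a dim reduces over BOTH dims of the (num, -1) view,
# so this is the same global score as method 1 ...
flat_in = inputs.reshape(num, -1)
flat_tg = target.reshape(num, -1)
inter2 = (flat_in * flat_tg).sum()
dice2 = (2.0 * inter2) / (flat_in.sum() + flat_tg.sum() + 1e-8)

# ... and dividing that single scalar by num only scales it down by the batch size
method2 = dice2.sum() / num

print(global_dice.item(), method2.item(), (global_dice / num).item())
```

The last two printed values match, and their ratio to the first is exactly the batch size.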

Method 4 seems to be the most efficient one, since it skips the reshape step.
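Whether skipping the reshape actually saves anything can be checked directly. This is a sketch assuming contiguous inputs, such as those `torch.rand` returns. In that case `reshape(-1)` can return a view that shares storage with the original tensor, so flattening costs essentially nothing, and methods 1 and 4 reduce to the same value. The helper names `dice_flat` and `dice_no_flatten` are hypothetical.

```python
import torch

torch.manual_seed(0)
inputs = torch.rand((3, 1, 224, 224))
target = torch.rand((3, 1, 224, 224))

# torch.rand returns a contiguous tensor, so reshape(-1) returns a view:
# only shape/stride metadata changes, no data is copied.
flat = inputs.reshape(-1)
print(flat.data_ptr() == inputs.data_ptr())  # True: same underlying storage

def dice_flat(p, g, eps=1e-8):
    # Method 1: flatten first, then reduce
    p, g = p.reshape(-1), g.reshape(-1)
    return (2.0 * (p * g).sum()) / (p.sum() + g.sum() + eps)

def dice_no_flatten(p, g, eps=1e-8):
    # Method 4: .sum() with no dim already reduces over every element
    return (2.0 * (p * g).sum()) / (p.sum() + g.sum() + eps)

print(torch.allclose(dice_flat(inputs, target), dice_no_flatten(inputs, target)))
```

Method 4's main advantage is therefore brevity rather than speed.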


Solution

  • First, you need to decide which dice score you report: a single dice score computed over all pixels of the whole batch (methods 1 and 4, and method 2 before its extra division by the batch size) or the average of the per-sample dice scores (method 3).
    If I'm not mistaken, you want method 3: you want to optimize the dice score of each sample in the batch rather than a single "global" dice score. Suppose you have one "difficult" sample in an "easy" batch. In the global score, the misclassified pixels of the difficult sample are negligible relative to all the other pixels in the batch. But if you compute the dice score of each sample separately, the difficult sample's low score is not negligible in the average. Method 3 agrees with methods 1 and 4 in your test only because torch.rand gives every sample nearly the same dice score (about 0.5); with real masks the two approaches diverge.
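The "difficult sample in an easy batch" argument can be made concrete with a toy batch. This is a minimal sketch: the helpers `per_sample_dice`, `global_dice` and `dice_loss` are hypothetical names. `pred` is assumed to already hold probabilities in [0, 1] (for example a sigmoid output), and the masks are made-up data. Three samples are predicted perfectly, and one small object is missed entirely.

```python
import torch

def per_sample_dice(pred, target, eps=1e-8):
    # Method 3: flatten everything except the batch dim, one score per sample
    n = target.shape[0]
    pred = pred.reshape(n, -1)
    target = target.reshape(n, -1)
    intersection = (pred * target).sum(dim=1)
    union = pred.sum(dim=1) + target.sum(dim=1)
    return (2.0 * intersection) / (union + eps)  # shape: (n,)

def global_dice(pred, target, eps=1e-8):
    # Methods 1 and 4: a single score over every pixel in the batch
    return (2.0 * (pred * target).sum()) / (pred.sum() + target.sum() + eps)

def dice_loss(pred, target, eps=1e-8):
    # Loss to minimize: 1 - mean of the per-sample dice scores
    return 1.0 - per_sample_dice(pred, target, eps).mean()

# Made-up batch: 3 "easy" samples with a large, perfectly predicted object
# and 1 "difficult" sample whose small object is missed entirely.
target = torch.zeros((4, 1, 64, 64))
target[:3, :, 8:56, 8:56] = 1.0   # 48x48 objects
target[3, :, 30:34, 30:34] = 1.0  # 4x4 object
pred = target.clone()
pred[3] = 0.0                     # nothing predicted for the difficult sample

print(global_dice(pred, target))      # ≈ 0.9988: the miss barely registers
print(per_sample_dice(pred, target))  # tensor([1., 1., 1., 0.])
print(dice_loss(pred, target))        # tensor(0.2500)
```

In a training loop, assuming a model that outputs logits, this would be used as something like `dice_loss(torch.sigmoid(logits), masks).backward()`. The global score hides the failed sample almost completely, while the per-sample mean drops to 0.75.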