deep-learning, neural-network, pytorch, loss-function

How do I compute bootstrapped cross entropy loss in PyTorch?


I have read some papers that use something called "Bootstrapped Cross Entropy Loss" to train their segmentation network. The idea is to take only the hardest k% (say 15%) of the pixels into account when computing the loss, which improves learning, especially when easy pixels dominate.

Currently, I am using the standard cross entropy:

loss = F.binary_cross_entropy(mask, gt)

How do I convert this to the bootstrapped version efficiently in PyTorch?
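
For the binary case, the naive version I can imagine is to compute the per-pixel loss and keep only the hardest fraction, roughly like this (the 15% and the variable names are just placeholders):

    raw = F.binary_cross_entropy(mask, gt, reduction='none').view(-1)
    hard, _ = torch.topk(raw, int(raw.numel() * 0.15), sorted=False)
    loss = hard.mean()

But I'm not sure this is the correct or the most efficient way to do it.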


Solution

  • Often we would also add a "warm-up" period to the loss so that the network can first learn from the easy regions and then transition to the harder ones.

    This implementation starts at k=100% and keeps it there for the first 20000 iterations, then linearly decays it to k=15% over the next 50000 iterations.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class BootstrappedCE(nn.Module):
        def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15):
            super().__init__()

            self.start_warm = start_warm
            self.end_warm = end_warm
            self.top_p = top_p

        def forward(self, input, target, it):
            # Warm-up: before start_warm, use the plain (mean) cross entropy
            # over all pixels, i.e. p = 100%.
            if it < self.start_warm:
                return F.cross_entropy(input, target), 1.0

            # Per-pixel loss, flattened so all pixels can be ranked together.
            raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
            num_pixels = raw_loss.numel()

            if it > self.end_warm:
                this_p = self.top_p
            else:
                # Linearly decay p from 100% at start_warm down to top_p at end_warm.
                this_p = self.top_p + (1 - self.top_p) * ((self.end_warm - it) / (self.end_warm - self.start_warm))

            # Keep only the hardest p% of pixels and average their loss.
            loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
            return loss.mean(), this_p
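
    A minimal usage sketch for context (the names model, loader, optimizer, and the iteration counter it are assumptions; logits should be N x C x H x W raw scores and labels N x H x W class indices, as F.cross_entropy expects):

    criterion = BootstrappedCE(start_warm=20000, end_warm=70000, top_p=0.15)

    for it, (images, labels) in enumerate(loader):
        logits = model(images)                      # N x C x H x W raw scores
        loss, p_used = criterion(logits, labels, it)  # p_used is the fraction of pixels kept
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()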