I have read some papers that use something called "Bootstrapped Cross Entropy Loss" to train their segmentation networks. The idea is to take only the hardest k% (say 15%) of the pixels into account, which improves learning, especially when easy pixels dominate.
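For illustration, the core operation (averaging only the largest k% of the per-pixel losses instead of all of them) might look like this. The helper name `hardest_k_mean` and the toy loss map are hypothetical, chosen only to show the effect.

```python
import torch

def hardest_k_mean(per_pixel_loss: torch.Tensor, k: float) -> torch.Tensor:
    """Average only the largest k fraction (0 < k <= 1) of the per-pixel losses."""
    flat = per_pixel_loss.reshape(-1)
    # Keep at least one pixel so the mean is always defined
    num_keep = max(1, int(flat.numel() * k))
    hardest, _ = torch.topk(flat, num_keep, sorted=False)
    return hardest.mean()

# Toy 4x5 "loss map": 20 pixels, 17 easy ones (small loss) and 3 hard ones
loss_map = torch.full((4, 5), 0.01)
loss_map[0, 0] = 2.0
loss_map[1, 3] = 1.5
loss_map[3, 4] = 1.0

print(hardest_k_mean(loss_map, 0.15))  # tensor(1.5000): mean of the 3 hardest pixels
print(loss_map.mean())                 # roughly 0.23, pulled down by the 17 easy pixels
```

With k = 15% of 20 pixels, exactly the three hard pixels survive, so their gradients are no longer diluted by the easy majority.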
Currently, I am using the standard binary cross entropy:
loss = F.binary_cross_entropy(mask, gt)
How do I convert this to the bootstrapped version efficiently in PyTorch?
Often we also add a "warm-up" period to the loss so that the network first learns to handle the easy regions and then transitions to the harder ones.
This implementation keeps k=100% (i.e., all pixels) for the first 20000 iterations, then linearly decays it to k=15% over the next 50000 iterations.
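The schedule on its own can be written as a plain function, which makes it easy to check a few values. This is a sketch that mirrors the arithmetic in the class below; the function name `top_p_schedule` is hypothetical.

```python
def top_p_schedule(it, start_warm=20000, end_warm=70000, top_p=0.15):
    """Fraction of pixels kept at iteration `it` (same rule as BootstrappedCE)."""
    if it < start_warm:
        return 1.0  # warm-up: every pixel contributes
    if it > end_warm:
        return top_p  # final regime: only the hardest top_p fraction
    # Linear interpolation: 1.0 at start_warm, top_p at end_warm
    progress = (end_warm - it) / (end_warm - start_warm)
    return top_p + (1 - top_p) * progress

for it in (0, 20000, 45000, 70000, 100000):
    print(it, round(top_p_schedule(it), 3))
```

For example, halfway through the decay (iteration 45000) the loss is averaged over the hardest 0.15 + 0.85 × 0.5 = 57.5% of the pixels.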
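import torch
import torch.nn as nn
import torch.nn.functional as F

class BootstrappedCE(nn.Module):
    def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15):
        super().__init__()
        self.start_warm = start_warm  # iteration at which the decay of k starts
        self.end_warm = end_warm      # iteration at which k reaches top_p
        self.top_p = top_p            # final fraction of (hardest) pixels kept

    def forward(self, input, target, it):
        if it < self.start_warm:
            # Warm-up: plain cross entropy averaged over all pixels
            return F.cross_entropy(input, target), 1.0

        # Per-pixel losses, flattened over the batch and spatial dimensions
        raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
        num_pixels = raw_loss.numel()

        if it > self.end_warm:
            this_p = self.top_p
        else:
            # Linearly decay from 1.0 (at start_warm) to top_p (at end_warm)
            this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm))
        # Keep only the hardest this_p fraction of pixels and average them
        loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
        return loss.mean(), this_p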
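The class above uses `F.cross_entropy`, which expects raw logits of shape (N, C, H, W) and class-index targets of shape (N, H, W). For the binary `F.binary_cross_entropy(mask, gt)` call in the question, a minimal sketch, assuming `mask` already holds sigmoid probabilities and `gt` holds 0/1 floats of the same shape, could swap in the per-pixel BCE. The name `BootstrappedBCE` and the `max(1, ...)` guard are my own additions, not part of the original answer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BootstrappedBCE(nn.Module):
    """Hypothetical binary variant: `mask` holds probabilities, `gt` holds 0/1 targets."""

    def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15):
        super().__init__()
        self.start_warm = start_warm
        self.end_warm = end_warm
        self.top_p = top_p

    def forward(self, mask, gt, it):
        if it < self.start_warm:
            # Warm-up: ordinary BCE averaged over every pixel
            return F.binary_cross_entropy(mask, gt), 1.0
        # Per-pixel BCE, flattened so top-k runs over the whole batch at once
        raw_loss = F.binary_cross_entropy(mask, gt, reduction='none').reshape(-1)
        if it > self.end_warm:
            this_p = self.top_p
        else:
            # Same linear decay from 1.0 down to top_p as BootstrappedCE
            this_p = self.top_p + (1 - self.top_p) * (
                (self.end_warm - it) / (self.end_warm - self.start_warm))
        # Guard against selecting zero pixels on very small inputs
        num_keep = max(1, int(raw_loss.numel() * this_p))
        loss, _ = torch.topk(raw_loss, num_keep, sorted=False)
        return loss.mean(), this_p


if __name__ == "__main__":
    torch.manual_seed(0)
    criterion = BootstrappedBCE()
    logits = torch.randn(2, 1, 8, 8, requires_grad=True)
    mask = torch.sigmoid(logits)
    gt = (torch.rand(2, 1, 8, 8) > 0.5).float()
    loss, p = criterion(mask, gt, it=100000)
    loss.backward()  # gradients flow only through the selected hardest pixels
    print(loss.item(), p)
```

If `mask` is raw logits rather than probabilities, `F.binary_cross_entropy_with_logits` is the numerically safer choice. As in the original, the top-k selection is taken over all pixels in the batch jointly, not per image.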