
How to use the BCELoss in PyTorch?


I want to write a simple autoencoder in PyTorch and use BCELoss; however, I get NaN as output, since it expects the targets to be between 0 and 1. Could someone post a simple use case of BCELoss?


Solution

  • Update

    The BCELoss function used to be numerically unstable; see this issue: https://github.com/pytorch/pytorch/issues/751. This has since been addressed by Pull #1792, which added BCEWithLogitsLoss, a numerically stable variant that takes raw logits as input, so a hand-written workaround is no longer needed.
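A simple use case, as the question asks for, might look like the minimal sketch below. It assumes a hypothetical toy fully connected autoencoder (`TinyAutoencoder`, not from the original answer) and random data in [0, 1] standing in for normalized images; real pixel data would normally be scaled to [0, 1] first (for example by dividing uint8 values by 255). The decoder returns logits, so `nn.BCEWithLogitsLoss` is the training loss; `nn.BCELoss` applied after `torch.sigmoid` is computed only for comparison.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy autoencoder: 784 -> 32 -> 784 (e.g. flattened 28x28 images).
class TinyAutoencoder(nn.Module):
    def __init__(self, n_in=784, n_hidden=32):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(n_in, n_hidden), nn.ReLU())
        # The decoder ends in a plain Linear layer, so it outputs logits.
        self.decoder = nn.Linear(n_hidden, n_in)

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = TinyAutoencoder()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# BCE targets must lie in [0, 1]; random values stand in for normalized pixels.
x = torch.rand(64, 784)

criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally, stably
bce = nn.BCELoss()                  # expects probabilities, used for comparison

losses = []
for step in range(100):
    optimizer.zero_grad()
    logits = model(x)
    loss = criterion(logits, x)  # an autoencoder's input is its own target
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# For moderate logits, sigmoid + BCELoss gives (numerically) the same value.
with torch.no_grad():
    logits = model(x)
    loss_a = criterion(logits, x).item()
    loss_b = bce(torch.sigmoid(logits), x).item()
```

If you prefer `nn.BCELoss`, end the decoder with `nn.Sigmoid()` so it outputs probabilities; both the predictions and the targets then have to lie in [0, 1].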


    Old answer

    If you build PyTorch from source, you can use the numerically stable BCEWithLogitsLoss (contributed in https://github.com/pytorch/pytorch/pull/1792), which takes logits as input.
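In current releases `nn.BCEWithLogitsLoss` ships with PyTorch, so no source build is needed. The sketch below contrasts it with the naive route (`torch.sigmoid` followed by `nn.BCELoss`) on a few hand-picked, illustrative large-magnitude logits, where the sigmoid saturates in float32.

```python
import torch
import torch.nn as nn

# Illustrative logits, some large enough to saturate sigmoid in float32.
logits = torch.tensor([-200.0, -20.0, 0.0, 20.0, 200.0])
targets = torch.tensor([1.0, 1.0, 0.5, 0.0, 0.0])

# Naive route: squash to probabilities first, then BCELoss.
naive = nn.BCELoss(reduction="none")(torch.sigmoid(logits), targets)

# Fused route: BCEWithLogitsLoss works on the logits directly.
fused = nn.BCEWithLogitsLoss(reduction="none")(logits, targets)

print(naive)
print(fused)
```

The fused values stay close to the exact losses (about 200, 20, log 2, 20 and 200), whereas the naive route breaks down once `sigmoid` rounds to exactly 0 or 1. Recent releases clamp the log inside BCELoss at -100, so those entries come out capped (for example 100 instead of 200) rather than inf.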

    Otherwise, you can use the following module (contributed by yzgao in the above issue):

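    import torch
    import torch.nn as nn

    class StableBCELoss(nn.Module):
        """BCE on raw logits: mean of max(x, 0) - x*z + log(1 + exp(-|x|))."""
        def __init__(self):
            super().__init__()

        def forward(self, input, target):
            # Only exponentiate -|x| <= 0, so exp() never overflows.
            neg_abs = -input.abs()
            loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
            return loss.mean()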
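For a logit x and target z, the binary cross-entropy -[z log σ(x) + (1 - z) log(1 - σ(x))] simplifies to (1 - z)x + log(1 + e^(-x)), which equals max(x, 0) - xz + log(1 + e^(-|x|)); that last form is what the module above computes. As a quick sanity check, the sketch below re-declares a condensed copy of `StableBCELoss` (so it runs on its own) and compares it with the built-in `nn.BCEWithLogitsLoss` on random, partly large logits and soft targets; the data are arbitrary test values.

```python
import torch
import torch.nn as nn

class StableBCELoss(nn.Module):
    # Condensed copy of the module above: max(x, 0) - x*z + log(1 + exp(-|x|)).
    def forward(self, input, target):
        neg_abs = -input.abs()
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()

torch.manual_seed(0)
logits = torch.randn(8, 10) * 50   # scaled up to include large-magnitude logits
targets = torch.rand(8, 10)        # soft targets in [0, 1]

custom = StableBCELoss()(logits, targets)
builtin = nn.BCEWithLogitsLoss()(logits, targets)
```

Both should agree to within float32 rounding; in new code the built-in loss is the simpler choice, and it also supports options such as `pos_weight` and `reduction`.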