PyTorch BCEWithLogitsLoss: calculating pos_weight


I have the neural network below for binary prediction. My classes are heavily imbalanced: class 1 occurs only 2% of the time. Only the last few layers are shown:

self.batch_norm2 = nn.BatchNorm1d(num_filters)

self.fc2 = nn.Linear(np.sum(num_filters), fc2_neurons)

self.batch_norm3 = nn.BatchNorm1d(fc2_neurons)

self.fc3 = nn.Linear(fc2_neurons, 1)

My loss is below. Is this the correct way to calculate the pos_weight parameter? I looked at the official documentation, and it says that pos_weight needs to have one value for each class in multiclass classification. I'm not sure whether the binary case is a different scenario. I tried passing 2 values and got an error.

My question: for a binary problem, is pos_weight a single value, unlike multiclass classification, where it needs to be a list/array with length equal to the number of classes?

BCE_With_LogitsLoss=nn.BCEWithLogitsLoss(pos_weight=class_wts[0]/class_wts[1])

My y variable is a single variable containing 0 or 1 for the actual class, and the neural network outputs a single value.

--------------------------------------------------
Update 1

Based on the answer by Shai, I have the following questions:

  1. BCEWithLogitsLoss: if it is a multiclass problem, how should the pos_weight parameter be used?
  2. Is there an example of using focal loss in PyTorch? I found some links, but most of them were old, dating back two or more years.
  3. For training, I am oversampling class 1. Is focal loss still appropriate?

Solution

  • The documentation of pos_weight is indeed a bit unclear. For BCEWithLogitsLoss with a single output, pos_weight should be a torch.Tensor with a single element (shape [1]):

    BCE_With_LogitsLoss=nn.BCEWithLogitsLoss(pos_weight=torch.tensor([class_wts[0]/class_wts[1]]))
    

    However, in your case, where the positive class occurs only 2% of the time, I think setting pos_weight alone will not be enough.
    Please consider using focal loss:
    Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming He, Piotr Dollár, "Focal Loss for Dense Object Detection" (ICCV 2017).
    Apart from describing focal loss, this paper gives a very good explanation of why CE loss performs so poorly under class imbalance. I strongly recommend reading it.

    Other alternatives are listed here.
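A minimal sketch of how pos_weight is usually derived. The PyTorch documentation describes it as the ratio of negative to positive examples (for 100 positive and 300 negative examples, pos_weight = 300/100 = 3). The helper names binary_pos_weight and multilabel_pos_weight are hypothetical, and the sketch assumes the labels are available as a tensor with at least one positive per class (otherwise the division yields inf). Regarding question 1: a pos_weight with one entry per class applies to the multi-label setup, where the network has C outputs, each with its own sigmoid. For mutually exclusive multiclass targets, nn.CrossEntropyLoss with its weight argument is the usual tool instead.

```python
import torch
import torch.nn as nn


def binary_pos_weight(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pos_weight for a single-output binary model (#neg / #pos)."""
    y = y.float().flatten()
    num_pos = y.sum()
    num_neg = y.numel() - num_pos
    # Shape [1] so it broadcasts against logits of shape [N, 1].
    return (num_neg / num_pos).reshape(1)


def multilabel_pos_weight(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: per-class pos_weight for a multi-label target of shape [N, C]."""
    y = y.float()
    num_pos = y.sum(dim=0)          # positives per class, shape [C]
    num_neg = y.shape[0] - num_pos  # negatives per class, shape [C]
    return num_neg / num_pos


# Binary case: 2 positives out of 100 samples, so pos_weight = 98 / 2 = 49.
y = torch.zeros(100, 1)
y[:2] = 1.0
pw = binary_pos_weight(y)
criterion = nn.BCEWithLogitsLoss(pos_weight=pw)

logits = torch.randn(100, 1)  # stand-in for the output of fc3
# Targets must be float and have the same shape as the logits.
loss = criterion(logits, y)
```

Whether the question's class_wts[0]/class_wts[1] matches this ratio depends on what class_wts holds: raw class counts give #negatives/#positives, while inverse-frequency class weights give the reciprocal.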
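For question 2, here is a minimal sketch of binary focal loss as defined in the paper cited above, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), written for a single-logit output such as fc3. The class name BinaryFocalLoss is hypothetical; alpha = 0.25 and gamma = 2 are the values the paper reports as working best for its detection setting, not tuned values for this problem. Recent torchvision releases also provide torchvision.ops.sigmoid_focal_loss, which follows the same formula.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BinaryFocalLoss(nn.Module):
    """Hypothetical module: binary focal loss on raw logits (Lin et al., 2017)."""

    def __init__(self, alpha: float = 0.25, gamma: float = 2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        targets = targets.float()
        # Per-element -log(p_t), computed stably from the logits.
        ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
        p = torch.sigmoid(logits)
        # p_t is the predicted probability of the true class.
        p_t = p * targets + (1 - p) * (1 - targets)
        # alpha weights positives, (1 - alpha) weights negatives.
        alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
        # (1 - p_t)^gamma down-weights examples that are already well classified.
        loss = alpha_t * (1 - p_t) ** self.gamma * ce
        return loss.mean()


torch.manual_seed(0)
logits = torch.randn(8, 1)  # stand-in for the output of fc3
targets = torch.tensor([[1.], [0.], [0.], [0.], [0.], [0.], [0.], [1.]])
criterion = BinaryFocalLoss(alpha=0.25, gamma=2.0)
loss = criterion(logits, targets)
```

With gamma = 0 this reduces to alpha-weighted binary cross-entropy. Regarding question 3: if class 1 is already oversampled, the loss sees a less imbalanced distribution, so alpha (or a pos_weight computed on the resampled labels) may need to be re-tuned rather than carried over from the original class ratio.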