python pytorch loss-function bert-language-model cross-entropy

Focal loss for imbalanced data using PyTorch


I want to use focal loss with multiclass imbalanced data in PyTorch. I searched, found this code and tried to use it, but I got an error:


    import tensorflow as tf
    import torch.nn as nn

    class_weights = tf.constant([0.21, 0.45, 0.4, 0.46, 0.48, 0.49])

    loss_fn = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

and used it in the training function:



    preds = model(sent_id, mask, labels)

    # compute the loss between actual and predicted values
    alpha = 0.25
    gamma = 2
    ce_loss = loss_fn(preds, labels)
    pt = torch.exp(-ce_loss)
    focal_loss = (alpha * (1 - pt) ** gamma * ce_loss).mean()
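For reference, the snippet is aiming at the focal loss of Lin et al. (2017, "Focal Loss for Dense Object Detection"). Writing p_t for the predicted probability of the true class of a single example, alpha_t for a balancing weight and gamma for the focusing parameter:

```latex
\mathrm{CE}(p_t) = -\log p_t
\quad\Longrightarrow\quad
p_t = e^{-\mathrm{CE}(p_t)},
\qquad
\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \log p_t = \alpha_t \,(1 - p_t)^{\gamma}\, \mathrm{CE}(p_t)
```

Because p_t is a per-example quantity, recovering it as exp(-CE) only holds when the cross-entropy is unweighted and left unreduced (one value per example). Taking exp of a batch mean gives one number shared by the whole batch, and with a class weight w the result is p_t^w rather than p_t. With gamma = 0 and alpha_t = 1 the expression reduces to ordinary cross-entropy.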

The error is

TypeError: cannot assign 'tensorflow.python.framework.ops.EagerTensor' object to buffer 'weight' (torch Tensor or None required)

on this line:

    loss_fn = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')

Solution

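  • You're mixing TensorFlow and PyTorch objects. tf.constant creates a TensorFlow tensor, but the weight argument of nn.CrossEntropyLoss must be a torch Tensor (or None), as the error message says. Build the weights with torch instead.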

    Try:

    class_weights=torch.tensor([0.21, ...], requires_grad=False)
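With the weights as a torch tensor the original snippet should run, but it still takes exp of the batch-mean, class-weighted cross-entropy, so pt is not the per-example probability of the true class. Below is a minimal sketch of a per-sample multiclass focal loss. The FocalLoss class, its arguments, and the choice to apply the class weights as a multiplier after computing pt (instead of passing them to cross_entropy) are assumptions for illustration, not a PyTorch API; alpha is kept as the scalar constant from the question.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Hypothetical multiclass focal loss: alpha * (1 - p_t)**gamma * CE, per sample.

    class_weights (optional) must be a torch tensor with one entry per class.
    """

    def __init__(self, class_weights=None, alpha=0.25, gamma=2.0, reduction="mean"):
        super().__init__()
        # A buffer moves with the module on .to(device) but is not trained
        self.register_buffer("class_weights", class_weights)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # Unweighted, unreduced cross-entropy: one value -log(p_t) per example
        log_pt = -F.cross_entropy(logits, targets, reduction="none")
        pt = log_pt.exp()
        # Focusing term uses each example's own p_t, not the batch mean
        focal = self.alpha * (1 - pt) ** self.gamma * (-log_pt)
        if self.class_weights is not None:
            # Per-class weighting applied after p_t is computed
            focal = focal * self.class_weights[targets]
        if self.reduction == "mean":
            return focal.mean()
        if self.reduction == "sum":
            return focal.sum()
        return focal


if __name__ == "__main__":
    class_weights = torch.tensor([0.21, 0.45, 0.4, 0.46, 0.48, 0.49])
    loss_fn = FocalLoss(class_weights=class_weights, alpha=0.25, gamma=2.0)
    logits = torch.randn(8, 6, requires_grad=True)  # stand-in for model output
    labels = torch.randint(0, 6, (8,))
    loss = loss_fn(logits, labels)
    loss.backward()
    print(loss.item())
```

In the training loop this would replace the loss_fn / pt / focal_loss lines, e.g. focal_loss = loss_fn(preds, labels), assuming preds are raw logits of shape (batch, num_classes). Using a plain mean (rather than dividing by the sum of the selected weights, as nn.CrossEntropyLoss does with weight and reduction='mean') is another design choice you may want to change.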