Tags: python, pytorch, loss-function, vision-transformer

I want to change my ViT from single-label multi-class classification to multi-label classification. How should I rewrite the evaluation and loss sections?


The new labels look like [0, 1, 1, 0, 0, 1, 0]. The original loss function is torch.nn.CrossEntropyLoss(), and the calculation segment is:

    pred = model(images.to(device))
    loss = loss_function(pred, labels.to(device))

How do I use torch.nn.BCEWithLogitsLoss to replace that?

The answers I have found so far (from GPT chatbots and Google) lack runnable details.


Solution

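  • If your labels are already in that multi-hot format, you can simply swap the loss function. Two things to check: the model's final layer must output one raw score (logit) per label, 7 in your case, and the targets must be floats with the same shape as the logits.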

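    import torch
    import torch.nn as nn

    loss_fn = nn.BCEWithLogitsLoss()  # applies sigmoid internally, so pass raw logits

    # One row per image, one column per label (7 labels here)
    logits = torch.randn(2, 7)
    labels = torch.tensor([[0, 1, 1, 0, 0, 1, 0],
                           [1, 0, 0, 0, 1, 0, 0]]).float()  # targets must be float

    loss = loss_fn(logits, labels)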
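A fuller sketch of the training step and the evaluation for the multi-label case follows, assuming the model's classification head already outputs 7 logits. The `nn.Sequential` stand-in model, the random batch, the SGD optimizer and the 0.5 decision threshold are hypothetical placeholders, not part of the original question; swap in your ViT and data loader. For evaluation, `torch.sigmoid` turns each logit into an independent probability (rather than `softmax` across classes), and thresholding those probabilities gives the predicted label vector.

```python
import torch
import torch.nn as nn

NUM_LABELS = 7  # matches the 7-element label vectors in the question

# Hypothetical stand-in for the ViT: any model whose final layer outputs
# NUM_LABELS raw scores (logits) per image is used the same way.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, NUM_LABELS))
device = torch.device("cpu")
model.to(device)

loss_function = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Fake batch: 4 images with multi-hot labels such as [0, 1, 1, 0, 0, 1, 0]
images = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 2, (4, NUM_LABELS))

# --- training step ---
model.train()
pred = model(images.to(device))                        # shape (4, 7), raw logits
loss = loss_function(pred, labels.float().to(device))  # targets cast to float
optimizer.zero_grad()
loss.backward()
optimizer.step()

# --- evaluation ---
model.eval()
with torch.no_grad():
    pred = model(images.to(device))
    probs = torch.sigmoid(pred)          # independent probability per label
    predicted = (probs > 0.5).long()     # 0.5 threshold is an assumption; tune if needed
    targets = labels.to(device)

    # Exact-match ratio: fraction of images whose whole label vector is correct
    exact_match = (predicted == targets).all(dim=1).float().mean().item()
    # Per-label accuracy: fraction of individual label slots predicted correctly
    label_accuracy = (predicted == targets).float().mean().item()

print(f"loss={loss.item():.4f} exact_match={exact_match:.2f} label_acc={label_accuracy:.2f}")
```

The exact-match ratio is strict (one wrong label makes the whole image count as wrong), so it is usually reported alongside per-label accuracy or a per-label F1 score.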