Tags: python, deep-learning, pytorch, classification

Target and output shape/type for binary classification using PyTorch


I have some annotated images that I want to use to train a binary image classifier, but I have been having trouble creating the dataset and getting a test model to train. Each image either belongs to a certain class or does not, so I want to set up a binary classification dataset and model using PyTorch. I have some questions:

  1. Should labels be float or long?
  2. What shape should my labels be?
  3. I am using the resnet18 model from torchvision; should my final softmax layer have one or two outputs?
  4. What shape should my targets be during training if my batch size is 200?
  5. What shape should my outputs be?

Thanks in advance



Solution

  • Binary classification is slightly different from multi-class classification: in the multi-class case, your model predicts a vector of "logits" per sample and uses softmax to convert the logits to probabilities; in the binary case, the model predicts a scalar "logit" per sample and uses the sigmoid function to convert it to a class probability.

    In PyTorch, the softmax and the sigmoid are "folded" into the loss layer (for numerical stability), and therefore there are different cross-entropy loss layers for the two cases: nn.BCEWithLogitsLoss for the binary case (with sigmoid) and nn.CrossEntropyLoss for the multi-class case (with softmax).
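To make the "folding" concrete, here is a minimal sketch comparing each fused loss with its unfused counterpart on random data. The tensor sizes (8 samples, 5 classes) and variable names are arbitrary choices for illustration, not anything from the question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Sigmoid (binary) case: one logit per sample, float targets of the same shape.
logit = torch.randn(8, 1)
target = torch.randint(0, 2, (8, 1)).float()
fused_bce = nn.BCEWithLogitsLoss()(logit, target)
manual_bce = nn.BCELoss()(torch.sigmoid(logit), target)

# Softmax case: a vector of logits per sample, integer class indices as targets.
logits = torch.randn(8, 5)
classes = torch.randint(0, 5, (8,))
fused_ce = nn.CrossEntropyLoss()(logits, classes)
manual_ce = nn.NLLLoss()(torch.log_softmax(logits, dim=1), classes)

# Each pair should agree up to floating-point error.
print(torch.allclose(fused_bce, manual_bce, atol=1e-6))
print(torch.allclose(fused_ce, manual_ce, atol=1e-6))
```

The practical consequence is that the model itself should output raw logits; adding your own sigmoid or softmax layer before these fused losses would apply the activation twice.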

    In your case you want to use the binary version (with sigmoid): nn.BCEWithLogitsLoss.
    Thus your labels should be of type torch.float32 (the same float type as the output of the network) and not integers. You should have a single label per sample, so if your batch size is 200, the target should have shape (200,1), the same shape as the network's output of one logit per sample.


    I'll leave it here as an exercise to show that training a model with two outputs and CE+softmax is equivalent to binary output+sigmoid ;)
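For anyone who wants a numerical sanity check before attempting the exercise (this is not a proof), the sketch below relies on the identity that a softmax over two logits z0 and z1 gives class 1 the probability sigmoid(z1 - z0). The random logits and labels are placeholders.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two-output formulation: one logit per class, integer (long) class indices.
logits2 = torch.randn(200, 2)
labels = torch.randint(0, 2, (200,))
ce = nn.CrossEntropyLoss()(logits2, labels)

# Single-output formulation: the difference of the two logits acts as the one logit.
diff = (logits2[:, 1] - logits2[:, 0]).unsqueeze(1)   # shape (200, 1)
bce = nn.BCEWithLogitsLoss()(diff, labels.float().unsqueeze(1))

# The two losses should agree up to floating-point error.
print(torch.allclose(ce, bce, atol=1e-6))
```

This also shows the difference in target conventions: nn.CrossEntropyLoss takes long class indices of shape (200,), while nn.BCEWithLogitsLoss takes float targets of shape (200,1).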