python deep-learning pytorch

Mismatch between number of classes and shape of inputs in torchmetrics


I want to calculate some metrics with torchmetrics, but I get the following error:

ValueError: The implied number of classes (from shape of inputs) does not match num_classes.

The output comes from a UNet and the loss function is BCEWithLogitsLoss (binary segmentation).

channels = 1 because the images are grayscale.

Input shape: (batch_size, channels, h, w), torch.float32

Label shape: (batch_size, channels, h, w), torch.float32 (for BCE)

Output shape: (batch_size, channels, h, w), torch.float32

inputs, labels = batch
outputs = model(inputs)
loss = self.loss_function(outputs, labels)
# This call raises the ValueError shown above
prec = torchmetrics.Precision(num_classes=1)(outputs, labels.type(torch.int32))

Solution

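  • It seems that torchmetrics expects a different shape: with 4-D inputs it likely infers the number of classes from the non-batch dimensions, which then disagrees with num_classes=1. Try flattening both the outputs and the labels so that each pixel is treated as one binary sample:

    prec = torchmetrics.Precision(num_classes=1)(outputs.view(-1), labels.type(torch.int32).view(-1))

    Note that with BCEWithLogitsLoss the outputs are raw logits, so depending on your torchmetrics version you may also need to apply torch.sigmoid to them first.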
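To sanity-check the value returned by the flattened call, it can help to compute binary precision by hand on a tiny example. The sketch below is dependency-free and is not torchmetrics' implementation: the helper names `flatten` and `binary_precision`, the 0.5 threshold, and the toy data are all assumptions made for illustration. It mimics the same steps: flatten (batch_size, channels, h, w) to one dimension, turn logits into probabilities with a sigmoid, threshold them, and count true and false positives.

```python
import math


def flatten(nested):
    """Recursively flatten nested lists, similar in spirit to tensor.view(-1)."""
    if isinstance(nested, list):
        return [x for item in nested for x in flatten(item)]
    return [nested]


def binary_precision(logits, targets, threshold=0.5):
    """Hypothetical helper: precision = TP / (TP + FP) over flat sequences."""
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    tp = fp = 0
    for logit, target in zip(logits, targets):
        prob = 1.0 / (1.0 + math.exp(-logit))  # sigmoid: logit -> probability
        if prob >= threshold:  # pixel predicted as foreground
            if target == 1:
                tp += 1
            else:
                fp += 1
    # Assumed convention: return 0.0 when nothing is predicted positive
    return tp / (tp + fp) if (tp + fp) else 0.0


# Toy batch shaped (batch_size=1, channels=1, h=2, w=2); values are made up
logits = [[[[2.0, -1.0], [0.5, -3.0]]]]
labels = [[[[1, 0], [0, 0]]]]

precision = binary_precision(flatten(logits), flatten(labels))
print(precision)
```

Here two pixels are predicted as foreground (logits 2.0 and 0.5) and only the first one is foreground in the labels, so the hand-computed precision is 0.5. If the torchmetrics call on the same flattened data gives a different number, the threshold or the logit handling is the first thing to check.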