I have trained a segmentation NN model for a binary classification problem in PyTorch Lightning, using `BCEWithLogitsLoss` as the loss. Both my ground truth and my predictions have shape (BZ, 640, 256); the ground truth values are in {0, 1} and the predictions are in [0, 1].
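For reference, here is a minimal sketch of how `BCEWithLogitsLoss` fits these shapes. It assumes the network outputs a single-channel logit map of shape (BZ, 1, 640, 256). The batch size and the random tensors are placeholders, not the actual model or data.

```python
import torch
import torch.nn as nn

# Placeholder sizes: batch of 4, one output channel, 640x256 masks.
BZ, H, W = 4, 640, 256
logits = torch.randn(BZ, 1, H, W)          # raw model output, no sigmoid applied
mask_gt = torch.randint(0, 2, (BZ, H, W))  # binary ground truth in {0, 1}

criterion = nn.BCEWithLogitsLoss()
# BCEWithLogitsLoss applies the sigmoid internally and expects a float target
# with the same shape as the input, so only the channel dimension is dropped.
loss = criterion(logits.squeeze(1), mask_gt.float())

# Probabilities in [0, 1], as used later for thresholding.
probs = torch.sigmoid(logits.squeeze(1))
```

Squeezing only dimension 1 keeps the batch dimension intact even when BZ is 1.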
Now I am trying to compute the F1 score batch by batch on my validation set with `F1Score` from torchmetrics, and to accumulate it over the epoch with PyTorch Lightning's `log_dict`. I create the metric like this:

```python
from torchmetrics import F1Score

# Created in the LightningModule's __init__
self.f1 = F1Score(num_classes=2)
```

and my validation step looks like this:
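```python
import torch

def validation_step(self, batch, batch_idx):
    t0, t1, mask_gt = batch
    mask_pred = self.forward(t0, t1)
    # Logits -> probabilities; squeeze only the channel dim (assumes output (BZ, 1, 640, 256)) so BZ=1 keeps its batch dim
    mask_pred = torch.sigmoid(mask_pred).squeeze(1)
    # Threshold at 0.5 to get hard {0, 1} labels
    mask_pred = torch.where(mask_pred > 0.5, 1, 0)
    f1_score_ = self.f1(mask_pred, mask_gt)
    metrics = {
        'val_f1_score': f1_score_,
    }
    self.log_dict(metrics, on_epoch=True)
```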
This gives me ridiculously high F1 scores (~0.99) at the end of each epoch, even in the sanity validation check before training starts, which makes me think I am not using F1Score together with log_dict the right way. I have tried several of its arguments (https://github.com/PyTorchLightning/metrics/blob/master/torchmetrics/classification/f_beta.py#L181-L310) with no luck. What am I doing wrong?
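On the `log_dict` side, logging a plain per-batch tensor with `on_epoch=True` makes Lightning reduce the per-step values with its default `reduce_fx` (a mean), and a mean of per-batch F1 scores is generally not the F1 over the whole validation set. The sketch below is a hypothetical helper written in plain PyTorch, not the torchmetrics API: it accumulates true positives, false positives and false negatives for the positive class (label 1) across batches and computes a single F1 at the end of the epoch. The toy batches are made up for illustration.

```python
import torch


class BinaryF1Accumulator:
    """Hypothetical helper: accumulates confusion counts for label 1 over
    many batches and computes one F1 score from the totals."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def update(self, pred: torch.Tensor, target: torch.Tensor) -> None:
        pred = pred.bool()
        target = target.bool()
        self.tp += int((pred & target).sum())
        self.fp += int((pred & ~target).sum())
        self.fn += int((~pred & target).sum())

    def compute(self) -> float:
        denom = 2 * self.tp + self.fp + self.fn
        # Assumption: return 0.0 when there are no positives at all
        # (libraries differ on this edge case).
        return 2 * self.tp / denom if denom > 0 else 0.0


# Two made-up batches of thresholded predictions and targets.
batches = [
    (torch.tensor([1, 0, 1, 0]), torch.tensor([1, 0, 0, 0])),
    (torch.tensor([0, 0, 0, 0]), torch.tensor([1, 1, 0, 0])),
]

epoch_acc = BinaryF1Accumulator()
per_batch_scores = []
for pred, target in batches:
    epoch_acc.update(pred, target)
    # For comparison: the F1 of this batch on its own
    batch_acc = BinaryF1Accumulator()
    batch_acc.update(pred, target)
    per_batch_scores.append(batch_acc.compute())

epoch_f1 = epoch_acc.compute()  # 2*1 / (2*1 + 1 + 2) = 0.4
mean_of_batches = sum(per_batch_scores) / len(per_batch_scores)  # (2/3 + 0) / 2
```

The epoch-level F1 (0.4) and the mean of the per-batch scores (about 0.33) differ, which is why counts rather than per-batch scores are accumulated here.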
It turns out my dataset is extremely unbalanced: the "False" class outnumbers the "True" class by roughly 40 to 1. The model is very good at detecting the "False" class, so its problems detecting the "True" class are masked when taking the macro F1 average of both classes (with equal weight for each class).
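The effect can be reproduced with a toy example. The sketch below uses synthetic labels with the 40:1 ratio described above and a hypothetical prediction that gets every "False" pixel right but finds only 20 of the 100 "True" pixels. It computes the F1 of each class, their unweighted (macro) mean, and pixel accuracy, which for a single-label problem equals the micro-averaged F1 over both classes. The helper `f1_for_class` is written for illustration and is not part of torchmetrics.

```python
import torch


def f1_for_class(pred: torch.Tensor, target: torch.Tensor, cls: int) -> float:
    """F1 score treating `cls` as the positive label."""
    p = pred == cls
    t = target == cls
    tp = int((p & t).sum())
    fp = int((p & ~t).sum())
    fn = int((~p & t).sum())
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom > 0 else 0.0


# Synthetic labels: 4000 "False" pixels (0) and 100 "True" pixels (1), i.e. 40:1.
target = torch.cat([torch.zeros(4000, dtype=torch.long),
                    torch.ones(100, dtype=torch.long)])

# Hypothetical prediction: all "False" pixels correct, only 20 of 100 "True" found.
pred = target.clone()
pred[-80:] = 0

f1_false = f1_for_class(pred, target, 0)  # 8000 / 8080, about 0.990
f1_true = f1_for_class(pred, target, 1)   # 40 / 120, about 0.333
macro_f1 = (f1_false + f1_true) / 2       # about 0.662
accuracy = (pred == target).float().mean().item()  # 4020 / 4100, about 0.980
```

Here the "False"-class F1 and the accuracy both stay near 0.99 and 0.98, while the "True"-class F1 is only 0.33, so the "True"-class F1 is the number that reveals the problem.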