Googling this gets you nowhere, so I decided to help future me and others by posting this as a searchable question.
def __init__(self):
    ...
    self.val_acc = pl.metrics.Accuracy()

def validation_step(self, batch, batch_index):
    ...
    self.val_acc.update(log_probs, label_batch)
raises

ValueError: preds and target must have same number of dimensions, or one additional dimension for preds

with log_probs.shape == (16, 4) and label_batch.shape == (16, 4).
What's the issue?
pl.metrics.Accuracy() expects a batch of class-index labels (dtype=torch.long), not one-hot encoded labels. So the one-hot targets have to be converted to indices before they are passed in:
self.val_acc.update(log_probs, torch.argmax(label_batch.squeeze(), dim=1))
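To see what that conversion does, here is a minimal sketch in plain PyTorch (no Lightning), assuming a batch of 16 samples and 4 classes as in the question. The tensors are random stand-ins for log_probs and label_batch, the .squeeze() is dropped because these stand-ins are already (16, 4), and accuracy is computed by hand rather than through pl.metrics:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_classes = 16, 4

# Stand-ins for the question's tensors: model log-probabilities and
# one-hot encoded targets, both of shape (16, 4).
log_probs = F.log_softmax(torch.randn(batch_size, num_classes), dim=1)
true_idx = torch.randint(0, num_classes, (batch_size,))
label_batch = F.one_hot(true_idx, num_classes=num_classes).float()

# Collapse the one-hot targets to class indices: shape (16,), dtype torch.long.
target = torch.argmax(label_batch, dim=1)

# Predicted class per sample, then the fraction of samples that match.
preds = torch.argmax(log_probs, dim=1)
accuracy = (preds == target).float().mean().item()

print(target.shape, target.dtype, accuracy)
```

The argmax over the class dimension recovers exactly the index that was one-hot encoded, which is the target format the metric wants.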
This is the same convention torch.nn.CrossEntropyLoss uses: targets are class indices, not one-hot vectors.
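A short sketch of that convention, assuming random logits in place of a real model's output. CrossEntropyLoss takes raw logits plus a (N,) tensor of torch.long class indices and applies log_softmax internally, so it matches NLLLoss applied to log-probabilities (recent PyTorch versions also accept class probabilities as targets, but class indices are the classic form):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(16, 4)            # raw, unnormalized scores for 4 classes
target = torch.randint(0, 4, (16,))    # class indices, dtype torch.long

# CrossEntropyLoss applies log_softmax itself, so it is given raw logits.
ce = nn.CrossEntropyLoss()(logits, target)

# Equivalent formulation: NLLLoss on log-probabilities with the same indices.
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)

print(ce.item(), nll.item())
```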