I want to compute a retrieval metric in a distributed setting with PyTorch Lightning.
Retrieval metrics are computed over groups of inputs. This is a problem in a distributed setting, because the groups get split across batches, so simply logging the metric in validation_step produces incorrect results. I therefore need to aggregate all batch inputs and predictions first, and compute the final metric afterwards.
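To make the failure mode concrete, here is a toy sketch with made-up scores, showing that NDCG computed on a split group does not match NDCG on the full group:

    import torch
    import torchmetrics

    ndcg = torchmetrics.RetrievalNormalizedDCG()
    preds = torch.tensor([0.9, 0.2, 0.8, 0.1])  # model scores
    target = torch.tensor([1, 0, 0, 1])         # relevance labels
    indexes = torch.tensor([0, 0, 0, 0])        # all rows belong to group 0

    full = ndcg(preds, target, indexes=indexes)  # metric over the whole group
    ndcg.reset()
    half = ndcg(preds[:2], target[:2], indexes=indexes[:2])  # first "batch" only
    # full and half generally differ, so per-batch logging is wrong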
However, as of now, Lightning 2.4.0 has no out-of-the-box support for this use case. The documentation only briefly mentions that you need to aggregate the results yourself into a mutable list:
class LightningTransformer(L.LightningModule):
    def __init__(self, vocab_size):
        ...
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx, dataloader_idx):
        x, y, group = batch
        preds = self.model(x, y)
        loss = ...  # computed as usual
        # keep everything the epoch-level metric will need
        self.validation_step_outputs.append((y, group, preds))
        return loss

    def on_validation_epoch_end(self):
        # concatenate the per-batch tensors collected over the epoch
        y, group, preds = (torch.cat(ts) for ts in zip(*self.validation_step_outputs))
        metric = self.retrieval_metric(preds, y, indexes=group)
        self.validation_step_outputs.clear()  # free memory
However, this approach does not account for multiple GPUs or multiple validation datasets. It also seems awfully convoluted for something that should be supported out of the box.
What is the correct way to compute e.g. torchmetrics.RetrievalNormalizedDCG for several validation datasets on 2+ devices?
You need to aggregate the predictions, targets, and group indexes across all devices before calculating the metric. Note that, as far as I know, Lightning does not directly support distributed aggregation out of the box for such group-based metrics, but it can be done manually by gathering the results across GPUs with self.all_gather. To handle several validation datasets, keep a separate buffer per dataloader_idx.
Here's the implementation I suggest:
import torch
import torchmetrics
import lightning as L
from collections import defaultdict

class LightningTransformer(L.LightningModule):
    def __init__(self, vocab_size):
        super().__init__()
        self.model = Transformer(vocab_size=vocab_size)
        # one output buffer per validation dataloader
        self.validation_step_outputs = defaultdict(list)
        self.val_retrieval_metric = torchmetrics.RetrievalNormalizedDCG(top_k=5)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        inputs, target, group = batch
        output = self.model(inputs)
        # Calculate loss or other required values
        loss = torch.nn.functional.nll_loss(output, target.view(-1))
        # NDCG needs float relevance scores, not argmax class ids
        # (a guess on my part; this will depend on your model/task)
        scores = output.max(dim=-1).values
        # Stash predictions, targets and group ids, keyed by dataloader
        self.validation_step_outputs[dataloader_idx].append(
            {"preds": scores, "target": target, "group": group}
        )
        return loss

    def on_validation_epoch_end(self):
        for dataloader_idx, outputs in self.validation_step_outputs.items():
            all_preds, all_targets, all_groups = [], [], []
            for output in outputs:
                # all_gather returns [world_size, batch, ...]; reshape(-1)
                # merges the device dimension back into the batch dimension
                all_preds.append(self.all_gather(output["preds"]).reshape(-1))
                all_targets.append(self.all_gather(output["target"]).reshape(-1))
                all_groups.append(self.all_gather(output["group"]).reshape(-1))
            # Concatenate all predictions, targets and group ids
            preds = torch.cat(all_preds, dim=0)
            targets = torch.cat(all_targets, dim=0)
            groups = torch.cat(all_groups, dim=0)
            # Retrieval metrics need the group ids ("indexes") to regroup rows
            ndcg = self.val_retrieval_metric(preds, targets, indexes=groups)
            self.log(f"val_retrieval_ndcg/dataloader_{dataloader_idx}", ndcg)
            self.val_retrieval_metric.reset()
        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-4)
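As an alternative worth knowing about: TorchMetrics metric objects maintain internal state and synchronize it across processes themselves when compute() is called, so you can often skip the manual all_gather entirely. A minimal sketch, assuming one metric instance per validation dataloader (two assumed here) and that your batches already carry the group ids:

    import torch
    import torchmetrics
    import lightning as L

    class LightningTransformer(L.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)
            # one metric instance per validation dataloader (assumption: two)
            self.val_ndcg = torch.nn.ModuleList(
                [torchmetrics.RetrievalNormalizedDCG(top_k=5) for _ in range(2)]
            )

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            x, y, group = batch
            scores = self.model(x)  # assumed to produce float relevance scores
            # update() only accumulates local state; TorchMetrics gathers the
            # state across all processes when compute() runs
            self.val_ndcg[dataloader_idx].update(scores, y, indexes=group)

        def on_validation_epoch_end(self):
            for i, metric in enumerate(self.val_ndcg):
                self.log(f"val_ndcg/dataloader_{i}", metric.compute())
                metric.reset()

This avoids hand-rolled gathering and keeps each rank's memory bounded to the metric state. The usual caveat still applies: DistributedSampler pads the dataset with duplicated samples so every rank sees the same number of batches, which can skew the aggregated metric slightly.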