Tags: python, pytorch-lightning

How to validate with RetrievalMetric in Lightning on multiple GPUs?


I want to compute a retrieval metric in a distributed setting with PyTorch Lightning.

Retrieval metrics compute scores over groups of inputs. This is a problem in a distributed setting because the groups get split across batches (and devices), so simply logging the metric in the validation step produces incorrect results. I therefore need to aggregate all batch targets and predictions before computing the final metric.
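To make the problem concrete: torchmetrics retrieval metrics group predictions by an indexes tensor. A toy example with made-up numbers:

import torch
import torchmetrics

# Two queries (groups 0 and 1), two documents each
preds = torch.tensor([0.9, 0.1, 0.8, 0.4])   # predicted relevance scores
target = torch.tensor([1, 0, 0, 1])          # true relevance
indexes = torch.tensor([0, 0, 1, 1])         # group (query) id per document

ndcg = torchmetrics.RetrievalNormalizedDCG()
print(ndcg(preds, target, indexes=indexes))  # NDCG averaged over both groups

# If group 1 is split across two batches (or two GPUs), each part is
# scored as an incomplete group and the averaged result is wrong.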

However, as of now, Lightning 2.4.0 has no out-of-the-box support for this use case. The documentation only briefly mentions that you need to aggregate the results yourself into a mutable list:

class LightningTransformer(L.LightningModule):
    def __init__(self, vocab_size):
        ...
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        x, y, group = batch
        preds = self.model(x, y)
        self.validation_step_outputs.append((y, group, preds))
        loss = ...  # compute the validation loss as usual

        return loss

    def on_validation_epoch_end(self):
        # unzip the buffered tuples and concatenate each component
        y, group, preds = (torch.cat(ts) for ts in zip(*self.validation_step_outputs))

        metric = self.retrieval_metric(preds, y, indexes=group)
        self.validation_step_outputs.clear()  # free memory

However, this approach does not account for multiple GPUs or multiple validation datasets. It also seems awfully convoluted for something that should be supported out of the box.

What is the correct way to compute, e.g., torchmetrics.RetrievalNormalizedDCG for several validation datasets on 2+ devices?


Solution

  • You need to aggregate the predictions and targets across all devices before calculating the metric. Note that, as far as I know, Lightning doesn't directly support this kind of distributed aggregation out of the box for such metrics, but it can be done manually by gathering the results across GPUs with self.all_gather.

    Here's the implementation I suggest:

    import lightning as L
    import torch
    import torchmetrics


    class LightningTransformer(L.LightningModule):
        def __init__(self, vocab_size):
            super().__init__()
            self.model = Transformer(vocab_size=vocab_size)
            self.validation_step_outputs = []
            # note: the cutoff argument is top_k, not k
            self.val_retrieval_metric = torchmetrics.RetrievalNormalizedDCG(top_k=5)

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            inputs, target, group = batch
            output = self.model(inputs)

            # Calculate loss or other required values
            loss = torch.nn.functional.nll_loss(output, target.view(-1))

            # Retrieval metrics expect relevance scores, not argmax class ids;
            # this is a placeholder and will depend on your model/task
            scores = output.softmax(dim=-1).max(dim=-1).values

            # Buffer predictions, targets and group ids for later aggregation
            self.validation_step_outputs.append(
                {"preds": scores, "target": target, "indexes": group}
            )

            return loss

        def _gather(self, tensor):
            # self.all_gather returns a tensor with a leading (world_size, ...)
            # dimension on multiple devices; fold it back into the batch dim
            gathered = self.all_gather(tensor)
            return gathered.flatten(0, 1) if self.trainer.world_size > 1 else gathered

        def on_validation_epoch_end(self):
            outputs = self.validation_step_outputs

            # Gather every buffered batch from all devices, then concatenate
            all_preds = torch.cat([self._gather(o["preds"]) for o in outputs])
            all_targets = torch.cat([self._gather(o["target"]) for o in outputs])
            all_indexes = torch.cat([self._gather(o["indexes"]) for o in outputs])

            # Retrieval metrics take the group ids via the indexes argument
            ndcg = self.val_retrieval_metric(all_preds, all_targets, indexes=all_indexes)

            # After all_gather every rank holds the same tensors, so the
            # logged value is identical on every device
            self.log("val_retrieval_ndcg", ndcg)

            self.val_retrieval_metric.reset()  # don't accumulate state across epochs
            self.validation_step_outputs.clear()

        def configure_optimizers(self):
            return torch.optim.Adam(self.model.parameters(), lr=1e-4)
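
    To also cover several validation datasets, one possible extension (my own sketch, not a built-in Lightning mechanism; the subclass name is hypothetical and it reuses the _gather helper from above) is to buffer outputs per dataloader_idx and log one metric per dataloader. Be aware that under DDP the DistributedSampler may duplicate a few samples so that every rank sees the same number of batches; those duplicates end up in the gathered tensors and can slightly skew the metric.

    from collections import defaultdict


    class MultiSetLightningTransformer(LightningTransformer):
        """Variant of the class above with one output buffer per dataloader."""

        def __init__(self, vocab_size):
            super().__init__(vocab_size)
            self.validation_step_outputs = defaultdict(list)

        def validation_step(self, batch, batch_idx, dataloader_idx=0):
            inputs, target, group = batch
            output = self.model(inputs)
            # same placeholder scores as above; adapt to your model/task
            scores = output.softmax(dim=-1).max(dim=-1).values
            self.validation_step_outputs[dataloader_idx].append(
                {"preds": scores, "target": target, "indexes": group}
            )

        def on_validation_epoch_end(self):
            # compute and log one NDCG per validation dataloader
            for dl_idx, outputs in self.validation_step_outputs.items():
                preds = torch.cat([self._gather(o["preds"]) for o in outputs])
                targets = torch.cat([self._gather(o["target"]) for o in outputs])
                indexes = torch.cat([self._gather(o["indexes"]) for o in outputs])
                ndcg = self.val_retrieval_metric(preds, targets, indexes=indexes)
                self.log(f"val_retrieval_ndcg/dataloader_{dl_idx}", ndcg)
                self.val_retrieval_metric.reset()
            self.validation_step_outputs.clear()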