pytorch, pytorch-lightning, multi-gpu

Combining losses and predictions from a multi-GPU setting in PyTorch Lightning


Hi, I'm facing an issue gathering all the losses and predictions in a multi-GPU setting. I'm using PyTorch Lightning 2.0.4 and DeepSpeed, with the distributed strategy deepspeed_stage_2.

I'm adding my skeleton code here for reference.

    def __init__(self):
        super().__init__()
        self.batch_train_preds = []
        self.batch_train_losses = []


    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        train_labels = batch['labels']

        # Model Step
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=train_labels)

        train_preds = torch.argmax(outputs.logits, dim=-1)

        return {'loss': outputs[0],
                'train_preds': train_preds}

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # aggregate metrics or outputs at batch level
        train_batch_loss = outputs["loss"].mean()
        train_batch_preds = torch.cat(outputs["train_preds"])

        self.batch_train_preds.append(train_batch_preds)
        self.batch_train_losses.append(train_batch_loss.item())

        return {'train_batch_loss': train_batch_loss,
                'train_batch_preds': train_batch_preds
                }

    def on_train_epoch_end(self) -> None:
        # Aggregate epoch level training metrics

        epoch_train_preds = torch.cat(self.batch_train_preds)
        epoch_train_loss = np.mean(self.batch_train_losses)

        self.logger.log_metrics({"epoch_train_loss": epoch_train_loss})

In the above code block, I'm trying to combine all the predictions into a single tensor at the end of the epoch by tracking each batch in a global list (defined in init). But in multi-GPU training I hit an error during concatenation, because each GPU handles its batch on its own device and I couldn't combine the results into a single global list.

Question:

What should I be doing in on_train_batch_end, on_train_epoch_end, or training_step to combine the results from all the GPUs into the list created in my init? I want to calculate additional metrics (precision, recall, etc.) in the on_*_epoch_end hooks for train, validation, and test.

(My validation and test hooks mirror the three training functions above, i.e. they also combine losses and predictions.)

I have come across all_gather, which does combine the results the way I want, but it is called on every device (GPU).

So the question is: how do I use the output of all_gather on only one of the devices? A code snippet would be very helpful.


Solution

  • The Lightning documentation suggests using all_gather. Moreover, you do not need to accumulate the loss manually; just log it with self.log(..., on_epoch=True, sync_dist=True) and Lightning will accumulate and log it correctly:

    import torch
    from lightning.pytorch import LightningModule


    class MyLightningModule(LightningModule):
    
        def __init__(self):
            super().__init__()
            self.batch_train_preds = []
    
        def training_step(self, batch, batch_idx):
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']
    
            # Model Step
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
    
            loss = outputs[0]
    
            train_preds = torch.argmax(outputs.logits, dim=-1)
            self.batch_train_preds.append(train_preds)
    
            self.log('train/loss', loss, on_step=True, on_epoch=True, sync_dist=True)
            return loss
    
        def on_train_epoch_end(self) -> None:
    
            # Aggregate epoch level training metrics
            epoch_train_preds = torch.cat(self.batch_train_preds, dim=0)
    
            # the following will stack predictions from all the distributed processes on dim=0
            epoch_train_preds = self.all_gather(epoch_train_preds)
    
            # all_gather returns a tensor of shape (world_size, local_dataset_size, *other_dims),
            # so flatten the first two dimensions to get (dataset_size, *other_dims)
            new_batch_size = self.trainer.world_size * epoch_train_preds.shape[1]
            epoch_train_preds = epoch_train_preds.reshape(new_batch_size, *epoch_train_preds.shape[2:])
    
            # compute here your metrics over `epoch_train_preds`
    
            self.batch_train_preds.clear()  # free memory 
    

    If you want to compute the metric only on a single process, protect the metric computation with if self.trainer.global_rank == 0:.
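
    For example, here is a minimal sketch (not part of the original answer) of such a rank-zero guard inside on_train_epoch_end; compute_my_metrics is a hypothetical placeholder, and all_gather stays outside the guard because it must run on every rank:

        def on_train_epoch_end(self) -> None:
            # gather predictions from every process; the result is identical on all ranks
            epoch_train_preds = self.all_gather(torch.cat(self.batch_train_preds, dim=0))
            epoch_train_preds = epoch_train_preds.reshape(-1, *epoch_train_preds.shape[2:])

            if self.trainer.global_rank == 0:
                # only rank 0 computes and logs the extra metrics
                metrics = compute_my_metrics(epoch_train_preds)  # hypothetical helper
                self.logger.log_metrics(metrics)

            self.batch_train_preds.clear()  # every rank clears its local buffer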

    I also suggest taking a look at torchmetrics, which enables automatic synchronisation of metrics in a distributed setting with a few lines of code.

    Additionally, I've written a framework for easy training and testing of several Transformer models for NLP.

    Additional example using torchmetrics

    import torch
    from torchmetrics.classification import BinaryAccuracy
    from lightning.pytorch import LightningModule
    
    
    class MyLightningModule(LightningModule):
    
        def __init__(self):
            super().__init__()
            self.train_accuracy = BinaryAccuracy()
    
        def training_step(self, batch, batch_idx):
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']
    
            # Model Step
            outputs = self.model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels
            )
    
            loss = outputs[0]
    
            train_preds = torch.argmax(outputs.logits, dim=-1)
            self.train_accuracy(train_preds, labels)  # updates the metric internal state with predictions and labels
    
            self.log('train/loss', loss, on_step=True, on_epoch=True, sync_dist=True)
            self.log('train/acc', self.train_accuracy, on_step=True, on_epoch=True)  # metric objects handle their own cross-process sync, so sync_dist is not needed
            return loss
    
        def on_train_epoch_end(self) -> None:
            pass  # no need to reset the metric as lightning will take care of that after each epoch
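
    The question also asks about precision and recall; the following is a minimal sketch (not part of the original answer) showing how a torchmetrics MetricCollection could track them, assuming a multi-class task with num_classes labels and the same self.model as in the examples above. Validation and test would follow the same pattern with their own cloned collection:

    import torch
    from lightning.pytorch import LightningModule
    from torchmetrics import MetricCollection
    from torchmetrics.classification import MulticlassPrecision, MulticlassRecall


    class MyLightningModule(LightningModule):

        def __init__(self, num_classes: int):
            super().__init__()
            # group the metrics so they are updated, synced and reset together
            self.train_metrics = MetricCollection({
                'precision': MulticlassPrecision(num_classes=num_classes),
                'recall': MulticlassRecall(num_classes=num_classes),
            }).clone(prefix='train/')

        def training_step(self, batch, batch_idx):
            outputs = self.model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                labels=batch['labels'],
            )
            preds = torch.argmax(outputs.logits, dim=-1)

            # accumulate predictions/targets in the per-process metric states
            self.train_metrics.update(preds, batch['labels'])
            return outputs[0]

        def on_train_epoch_end(self) -> None:
            # compute() syncs the states across processes; log the epoch-level values
            self.log_dict(self.train_metrics.compute())
            self.train_metrics.reset()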