Tags: python, deep-learning, pytorch, data-science, pytorch-lightning

Does a LightningDataModule passed to Trainer in PyTorch Lightning automatically run the validation loop?


I'm trying to fight overfitting, which is why I looked through the documentation (https://pytorch-lightning.readthedocs.io/en/stable/common/evaluation_basic.html#train-with-the-validation-loop), where I found that you can pass both a training and a validation dataloader to Trainer.fit. The question is: should I use this method, or can I simply pass my DataModule to Trainer.fit to prevent overfitting?

My DataModule code:

import os

import pandas as pd
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class ClassifierDataModule(pl.LightningDataModule):

    def __init__(self, train_dataset: pd.DataFrame, val_dataset: pd.DataFrame, batch_size: int):
        super().__init__()
        self.prepare_data_per_node = False

        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=os.cpu_count())

    def val_dataloader(self):
        # Validation data does not need to be shuffled
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=os.cpu_count())

data_module_classifier = ClassifierDataModule(train_dataset, val_dataset, BATCH_SIZE)

And here is my Trainer.fit():

model = MulticlassClassificationLIGHT(class_weights)
#trainer.fit(model, data_module_classifier) # SHOULD I USE THIS METHOD TO PREVENT OVERFITTING
trainer.fit(model, data_module_classifier.train_dataloader(), data_module_classifier.val_dataloader()) # OR THIS ONE?

My LightningModule just in case:

import torch
import torch.nn as nn

# `device` must be defined somewhere before the model is built, e.g.:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class MulticlassClassificationLIGHT(pl.LightningModule):
    def __init__(self, class_weights):
        super().__init__()

        self.num_feature = 35
        self.num_class = 36

        self.layer_1 = nn.Linear(self.num_feature, 512)
        self.layer_2 = nn.Linear(512, 128)
        self.layer_3 = nn.Linear(128, 64)
        self.layer_out = nn.Linear(64, self.num_class)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.batchnorm1 = nn.BatchNorm1d(512)
        self.batchnorm2 = nn.BatchNorm1d(128)
        self.batchnorm3 = nn.BatchNorm1d(64)

        self.loss = nn.CrossEntropyLoss(weight=class_weights.to(device))

    def forward(self, x):
        x = self.layer_1(x)
        x = self.batchnorm1(x)
        x = self.relu(x)
        
        x = self.layer_2(x)
        x = self.batchnorm2(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        x = self.layer_3(x)
        x = self.batchnorm3(x)
        x = self.relu(x)
        x = self.dropout(x)
        
        x = self.layer_out(x)
        
        return x

        
    def training_step(self, batch, batch_idx):
        x, y = batch 
        logits = self.forward(x) 
        loss = self.loss(logits, y) 
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss 

    def validation_step(self, batch, batch_idx):
        x, y = batch 
        logits = self.forward(x) 
        loss = self.loss(logits, y)
        self.log("val_loss", loss, prog_bar=True, logger=True) # I ask Trainer to "ModelCheckpoint" this loss
        return loss

Solution

  • Passing a validation dataloader during training does not fix overfitting by itself. It lets you measure how much the model is overfitting or underfitting: for a well-fit model, performance on the validation data should be close to performance on the training data. To actually act on that signal, you can monitor the logged val_loss with callbacks (see the sketch at the end of this answer).

    Regarding the syntax, this should work:

    trainer.fit(model=model, train_dataloaders=data_module_classifier.train_dataloader(), val_dataloaders=data_module_classifier.val_dataloader())

    The documentation for fit is here: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-class-api
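
    Equivalently, Trainer.fit also accepts the whole DataModule via its datamodule keyword, in which case Lightning calls train_dataloader() and val_dataloader() for you. Both calls therefore run the same validation loop; the DataModule form is just more convenient:

    trainer.fit(model=model, datamodule=data_module_classifier)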
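
    And if the goal is to reduce overfitting rather than just measure it, a common pattern is to attach callbacks that watch the val_loss you are already logging. Here is a minimal sketch; the patience, max_epochs, and save_top_k values are illustrative assumptions, not from the original post:

    from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

    # Stop training once val_loss has stopped improving for 5 validation epochs
    early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=5)

    # Keep only the checkpoint with the best (lowest) val_loss
    checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

    trainer = pl.Trainer(max_epochs=100, callbacks=[early_stop, checkpoint])
    trainer.fit(model=model, datamodule=data_module_classifier)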