I am trying to fight overfitting, which is why I looked through the documentation (https://pytorch-lightning.readthedocs.io/en/stable/common/evaluation_basic.html#train-with-the-validation-loop), where I found that you can pass both a training and a validation dataloader to Trainer.fit. My question is: should I use that method, or can I simply pass my DataModule to Trainer.fit to prevent overfitting?
My DataModule code:
import os

import pandas as pd
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class ClassifierDataModule(pl.LightningDataModule):
    def __init__(self, train_dataset: pd.DataFrame, val_dataset: pd.DataFrame, batch_size: int):
        super().__init__()
        self.prepare_data_per_node = False
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=os.cpu_count())

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=True, num_workers=os.cpu_count())
data_module_classifier = ClassifierDataModule(train_dataset, val_dataset, BATCH_SIZE)
And here is my Trainer.fit():
model = MulticlassClassificationLIGHT(class_weights)

# trainer.fit(model, data_module_classifier)  # SHOULD I USE THIS METHOD TO PREVENT OVERFITTING
trainer.fit(model, data_module_classifier.train_dataloader(), data_module_classifier.val_dataloader())  # OR THIS ONE?
Here is my LightningModule, just in case:
import torch.nn as nn
import pytorch_lightning as pl

class MulticlassClassificationLIGHT(pl.LightningModule):
    def __init__(self, class_weights):
        super().__init__()
        self.num_feature = 35
        self.num_class = 36
        self.layer_1 = nn.Linear(self.num_feature, 512)
        self.layer_2 = nn.Linear(512, 128)
        self.layer_3 = nn.Linear(128, 64)
        self.layer_out = nn.Linear(64, self.num_class)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.batchnorm1 = nn.BatchNorm1d(512)
        self.batchnorm2 = nn.BatchNorm1d(128)
        self.batchnorm3 = nn.BatchNorm1d(64)
        self.loss = nn.CrossEntropyLoss(weight=class_weights.to(device))

    def forward(self, x):
        x = self.layer_1(x)
        x = self.batchnorm1(x)
        x = self.relu(x)
        x = self.layer_2(x)
        x = self.batchnorm2(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.layer_3(x)
        x = self.batchnorm3(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.layer_out(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self.forward(x)
        loss = self.loss(logits, y)
        self.log("val_loss", loss, prog_bar=True, logger=True)  # I ask ModelCheckpoint to monitor this loss
        return loss
Passing a validation dataloader during training does not fix overfitting by itself. It lets you measure how much the model is overfitting or underfitting: for a well-fit model, performance on the validation data should stay close to performance on the training data.
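If you want to act on that signal rather than just observe it, a common pattern is to monitor the logged val_loss with callbacks. A minimal sketch, assuming your existing model and DataModule; the patience and max_epochs values are illustrative, not taken from your setup:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

# Stop training once val_loss stops improving, and keep only the best checkpoint.
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)
checkpoint = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1)

trainer = pl.Trainer(max_epochs=100, callbacks=[early_stopping, checkpoint])
trainer.fit(model, data_module_classifier.train_dataloader(), data_module_classifier.val_dataloader())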
Regarding the syntax, this should work:
trainer.fit(model=model, train_dataloaders=data_module_classifier.train_dataloader(), val_dataloaders=data_module_classifier.val_dataloader())
The documentation for fit is here: https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#trainer-class-api
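As for the other half of your question: passing the DataModule itself is equivalent, since the Trainer calls its train_dataloader() and val_dataloader() for you. A sketch using your names:

trainer.fit(model, datamodule=data_module_classifier)

Either way the same dataloaders end up being used, so neither call prevents overfitting more than the other.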