I'm trying to reimplement a training pipeline on top of PyTorch Lightning.
In the documentation they explain how the training/validation loops are executed, and my understanding of the order differs from what I actually observe.
I've implemented some dummy code in order to check this:
import pytorch_lightning as pl
from torchmetrics import MeanMetric
from torch.utils.data import Dataset, DataLoader
import torch
import warnings

warnings.filterwarnings("ignore")


class DummyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, idx):
        return torch.zeros([3, 12, 12]), torch.ones([3, 12, 12])  # Dummy image-like tensors

    def __len__(self):
        return 50


class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1, 1)  # Useless convolution
        self.mean = MeanMetric()  # Single metric, updated in both training and validation

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.sum((y - y_hat) ** 2)
        self.mean.update(2)  # Training updates the metric with 2
        return loss

    def training_epoch_end(self, outputs):
        mean_train = self.mean.compute()
        print(f"\nmean_train is : {mean_train}\n")
        self.mean.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = torch.sum((y - y_hat) ** 2)
        self.mean.update(4)  # Validation updates the metric with 4
        return loss

    def validation_epoch_end(self, outputs):
        mean_val = self.mean.compute()
        print(f"\nmean_val is : {mean_val}\n")
        self.mean.reset()

    def forward(self, x):
        return self.conv(x)


if __name__ == '__main__':
    dataset = DummyDataset()
    train_loader = DataLoader(dataset, batch_size=4, num_workers=0)
    val_loader = DataLoader(dataset, batch_size=4, num_workers=0)
    model = DummyModel()
    # We create the trainer
    trainer = pl.Trainer(val_check_interval=None)
    # We fit the model
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
What I see in the output (and when stepping through with the debugger) is that, within each epoch, the order is: all the training_step calls, then all the validation_step calls, then validation_epoch_end, and finally training_epoch_end.
Is that really the case?
Did I do something wrong?
How does it work?
Thanks!
The sequence you observe is correct. Here is a sketch of how it is implemented:
for epoch in range(max_epochs):
    for i, batch in enumerate(train_dataloader):
        model.training_step(batch, i)
        if should_validate():
            for j, val_batch in enumerate(val_dataloader):
                model.validation_step(val_batch, j)
            model.validation_epoch_end()
    model.training_epoch_end()
As you can see, the validation loop is nested inside the training loop and can potentially be triggered at the batch level. This can be configured in the Trainer via Trainer(val_check_interval=x), where an integer x means "validate every x training batches".
By default, however, it validates once per epoch, i.e. every len(train_dataloader) batches, so the should_validate condition becomes true on the very last training batch of the epoch. This is why you see in your prints, back to back:
validation_epoch_end()
training_epoch_end()
(they basically happen at the same time).
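If you want validation to run at a different point, you can tune the checking frequency from the Trainer. Here is a minimal sketch using the standard Trainer arguments val_check_interval and check_val_every_n_epoch (exact defaults may vary slightly between Lightning versions):

import pytorch_lightning as pl

# Validate 4 times per training epoch (a float is read as a fraction of the epoch).
trainer = pl.Trainer(max_epochs=2, val_check_interval=0.25)

# Validate every 10 training batches (an int is read as a number of batches).
trainer = pl.Trainer(max_epochs=2, val_check_interval=10)

# Run the validation loop only every 2 training epochs.
trainer = pl.Trainer(max_epochs=4, check_val_every_n_epoch=2)

In all of these cases the validation hooks still fire from inside the fit loop, so whenever validation is triggered, validation_epoch_end runs before the training_epoch_end of that epoch, just as in the sketch above.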
I hope this explanation helps.