Tags: python, deep-learning, pytorch, transfer-learning, pytorch-lightning

How to make pytorch lightning module have injected, nested models?


I have some nets, such as the following (augmented) resnet18:

import torch.nn as nn
from torchvision import models

num_classes = 10
resnet = models.resnet18(pretrained=True)
for param in resnet.parameters():
    param.requires_grad = True
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, num_classes)

And I want to use them inside a lightning module, and have it handle all optimization, device placement (to_device), stages, and so on. In other words, I want to register those modules with my lightning module, and still be able to access their public members.

class MyLightning(LightningModule):
    def __init__(self, resnet):
        super().__init__()
        self._resnet = resnet
        self._criterion = lambda x: 1.0

    def forward(self, x):
        resnet_out = self._resnet(x)
        loss = self._criterion(resnet_out)
        return loss


my_lightning = MyLightning(resnet)

The above doesn't optimize any parameters.

Trying

def __init__(self, resnet):
    ...
    _layers = list(resnet.children())[:-1]
    self._resnet = nn.Sequential(*_layers)

This doesn't take resnet.fc into account, and it also doesn't seem like the intended way of nesting models inside pytorch lightning.


How to nest models in pytorch lightning, and have them fully accessible and handled by the framework?


Solution

  • The training loop and optimization process are handled by the Trainer class. You can use it by initializing a new instance:

    >>> trainer = Trainer()
    

    And passing your PyTorch Lightning module to it. This way you can perform fitting, tuning, validating, and testing on that instance, provided a DataLoader or LightningDataModule:

    >>> trainer.fit(my_lightning, train_dataloader, val_dataloader)
    

    You will have to implement the following methods on your Lightning module (i.e. in your case MyLightning):

    Name                   Description
    __init__               Define computations here
    forward                Use for inference only (separate from training_step)
    training_step          The complete training loop
    validation_step        The complete validation loop
    test_step              The complete test loop
    predict_step           The complete prediction loop
    configure_optimizers   Define optimizers and LR schedulers

    Source: LightningModule documentation page.
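    In particular, configure_optimizers typically just builds an optimizer over self.parameters(), which is how the nested model's weights get optimized. The pattern can be sketched with plain torch (no Lightning import needed, since a LightningModule is an nn.Module; the class name, backbone, and loss below are illustrative, not from the question):

```python
import torch
from torch import nn

class MyLightningSketch(nn.Module):  # stands in for LightningModule here
    def __init__(self, backbone):
        super().__init__()
        self._backbone = backbone            # nested model, auto-registered
        self._criterion = nn.CrossEntropyLoss()

    def training_step(self, batch):
        x, y = batch
        return self._criterion(self._backbone(x), y)

    def configure_optimizers(self):
        # self.parameters() already includes every parameter of the backbone
        return torch.optim.SGD(self.parameters(), lr=1e-3)

model = MyLightningSketch(nn.Linear(8, 10))
opt = model.configure_optimizers()
n_params = sum(len(g["params"]) for g in opt.param_groups)
print(n_params)  # 2: the nested Linear's weight and bias
```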

    Keep in mind a LightningModule is a nn.Module, so whenever you assign a nn.Module as an attribute of a LightningModule in its __init__ function, that module will end up being registered as a sub-module of the parent pytorch lightning module.
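    That registration can be verified directly (sketched with a plain nn.Module and a small Sequential standing in for the resnet; MyLightning behaves the same way):

```python
from torch import nn

class Wrapper(nn.Module):
    def __init__(self, resnet):
        super().__init__()
        self._resnet = resnet  # attribute assignment registers the sub-module

wrapper = Wrapper(nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2)))
print(dict(wrapper.named_children()).keys())   # dict_keys(['_resnet'])
print(len(list(wrapper.parameters())))         # 4: two weights, two biases
```

    Because of this, the original MyLightning(resnet) already exposes resnet's parameters to the Trainer; public members remain reachable via my_lightning._resnet.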