
How does pytorch Module collect learnable parameters from modules in its attributes?


When I define a class as a submodule of torch.nn.Module and then I define some class attributes, such as

class Vgg16(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.feature_1 = nn.Sequential()  
    self.classifier = nn.Sequential()
    ...
    my_weight = self.state_dict()

Does the my_weight variable then contain a state_dict that includes the state of the nn.Sequential() modules? I believe the state_dict contains all the parameters required to reconstruct the module, but I have no idea how the module registers them when they are created.

The constructor of the Sequential module has no way of knowing that it is instantiated inside another module, or does it?

I would understand if it were done through torch.nn.Module.add_module(...), but here it is not. I know that Module keeps some private dict of modules and overrides the __getattr__() method so that I can access layers (submodules) as attributes, but how does it work when state_dict() is called?


Solution

  • So I found out that it happens inside the torch.nn.Module class, where def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None: is overridden. See the source code. Each time you do a self.attribute = something assignment, it checks whether the assigned object is a torch.nn.parameter.Parameter instance and, if so, registers it automatically.
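
    A quick way to see this in action (a minimal sketch; the class name Tiny is just for illustration): a plain attribute assignment of an nn.Parameter is enough to get it registered, while an ordinary tensor attribute is ignored by state_dict():

    ```python
    import torch
    import torch.nn as nn

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            # Plain assignment: __setattr__ sees a Parameter instance
            # and files it into self._parameters automatically.
            self.w = nn.Parameter(torch.zeros(3))
            # A plain tensor is NOT registered and never reaches state_dict().
            self.scratch = torch.zeros(3)

    m = Tiny()
    print(list(m.state_dict().keys()))  # ['w']
    ```

    No explicit call to register_parameter() is needed; the assignment itself does the bookkeeping.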

    It also checks for names that clash with already registered Parameters.

    Both of the above also apply to assigned objects of type nn.Module.

    It also checks whether the attribute name clashes with any registered buffer, but there it does not actually check the instance type, since there are only three kinds of objects that nn.Module tracks, in three dictionaries (_modules, _parameters, _buffers).
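
    The whole mechanism can be sketched without torch at all (a simplified illustration only; the real nn.Module does far more checking, and the Parameter/Module classes below are stand-ins, not torch code). __setattr__ routes each assignment into one of the three dicts by type, __getattr__ looks attributes up in those dicts when normal lookup fails, and state_dict() recurses through _modules with dotted prefixes:

    ```python
    class Parameter:
        """Stand-in for nn.Parameter: just a tagged value."""
        def __init__(self, value):
            self.value = value

    class Module:
        """Stand-in for nn.Module's registration bookkeeping."""
        def __init__(self):
            # Bypass our own __setattr__ for the bookkeeping dicts.
            object.__setattr__(self, '_parameters', {})
            object.__setattr__(self, '_modules', {})
            object.__setattr__(self, '_buffers', {})

        def __setattr__(self, name, value):
            # Route the assignment into the matching dict by type.
            if isinstance(value, Parameter):
                self._parameters[name] = value
            elif isinstance(value, Module):
                self._modules[name] = value
            else:
                object.__setattr__(self, name, value)

        def __getattr__(self, name):
            # Only called when normal lookup fails: check the dicts.
            for store_name in ('_parameters', '_modules', '_buffers'):
                store = object.__getattribute__(self, store_name)
                if name in store:
                    return store[name]
            raise AttributeError(name)

        def state_dict(self, prefix=''):
            # Own parameters and buffers first, then recurse into submodules.
            out = {prefix + k: p.value for k, p in self._parameters.items()}
            out.update({prefix + k: v for k, v in self._buffers.items()})
            for k, sub in self._modules.items():
                out.update(sub.state_dict(prefix + k + '.'))
            return out

    class Inner(Module):
        def __init__(self):
            super().__init__()
            self.w = Parameter([1.0, 2.0])

    class Outer(Module):
        def __init__(self):
            super().__init__()
            self.inner = Inner()   # lands in _modules via __setattr__
            self.b = Parameter([0.0])

    print(sorted(Outer().state_dict()))  # ['b', 'inner.w']
    ```

    So the Sequential constructor never needs to know where it is instantiated: the *enclosing* module's __setattr__ does the registration the moment the finished submodule is assigned to an attribute.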