I would like two torch.nn.Module classes to share part of their architecture and weights, as in the example below:
import torch
from torch import nn


class SharedBlock(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.block = nn.Sequential(
            # Define some block architecture here...
        )

    def forward(self, x):
        return self.block(x)


class MyNestedModule(nn.Module):
    def __init__(self, shared_block: nn.Module, *args, **kwargs):
        super().__init__()
        self.linear = nn.Linear(...)
        self.shared_block = shared_block

    def forward(self, x):
        return self.shared_block(self.linear(x))


class MyModule(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # SHOULD THIS BE:
        shared_block = SharedBlock(*args, **kwargs)
        # OR:
        self.shared_block = SharedBlock(*args, **kwargs)  # Note: self.
        # ...AND WHAT IS THE DIFFERENCE, IF ANY?
        self.nested1 = MyNestedModule(shared_block, *args, **kwargs)
        self.nested2 = MyNestedModule(shared_block, *args, **kwargs)

    def forward(self, x):
        x_1, x_2 = torch.split(x, x.shape[0] // 2, dim=0)
        y_1 = self.nested1(x_1)
        y_2 = self.nested2(x_2)
        return y_1, y_2
I would like to know whether shared_block should be an instance attribute of MyModule (i.e. self.shared_block). I assume it does not need to be, since it is already assigned as an attribute of both MyNestedModule instances and should therefore be registered for gradient tracking. But what would happen if I did make it an attribute of MyModule?
It doesn't matter; the parameters are tracked either way. If you use shared_block = ..., the parameters in shared_block will be referenced in your state dict (model.state_dict()) twice, once under self.nested1 and again under self.nested2.

If you use the self.shared_block = ... approach, the state dict will reference the parameters a third time, under MyModule itself.

Either way, the parameters are tracked, and model.parameters() will return a non-duplicated set of parameters.

You can run this code to look at a simplified version:
import torch
from torch import nn


class SharedBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Linear(8, 8)

    def forward(self, x):
        return self.block(x)


class MyNestedModule(nn.Module):
    def __init__(self, shared_block):
        super().__init__()
        self.shared_block = shared_block

    def forward(self, x):
        return self.shared_block(x)


class MyModule1(nn.Module):
    def __init__(self):
        super().__init__()
        shared_block = SharedBlock()
        self.nested1 = MyNestedModule(shared_block)
        self.nested2 = MyNestedModule(shared_block)

    def forward(self, x):
        x_1, x_2 = torch.split(x, x.shape[0] // 2, dim=0)
        y_1 = self.nested1(x_1)
        y_2 = self.nested2(x_2)
        return y_1, y_2


class MyModule2(nn.Module):
    def __init__(self):
        super().__init__()
        self.shared_block = SharedBlock()
        self.nested1 = MyNestedModule(self.shared_block)
        self.nested2 = MyNestedModule(self.shared_block)

    def forward(self, x):
        x_1, x_2 = torch.split(x, x.shape[0] // 2, dim=0)
        y_1 = self.nested1(x_1)
        y_2 = self.nested2(x_2)
        return y_1, y_2


model1 = MyModule1()
print(model1.state_dict())
print(list(model1.parameters()))

model2 = MyModule2()
print(model2.state_dict())
print(list(model2.parameters()))
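If you want to make the sharing explicit rather than reading the printed dicts, here is a small follow-up check (a sketch that continues the script above, so MyModule2 and the torch import are assumed to be in scope): both nested modules hold the very same Parameter objects, the state dict lists them under several key prefixes, parameters() deduplicates them, and gradients from both branches accumulate into the same tensors.

# Sketch: verify weight sharing in model2 (assumes MyModule2 from above is defined).
model2 = MyModule2()

# Both nested modules hold the exact same Parameter object.
print(model2.nested1.shared_block.block.weight is model2.nested2.shared_block.block.weight)  # True

# The state dict lists the shared weight and bias under three prefixes
# (shared_block., nested1.shared_block., nested2.shared_block.) -> 6 keys.
print(len(model2.state_dict()))

# parameters() deduplicates, so an optimizer sees each tensor once -> 2 tensors.
print(len(list(model2.parameters())))

# Gradients from both halves of the batch accumulate into the same shared weights.
x = torch.randn(4, 8)
y_1, y_2 = model2(x)
(y_1.sum() + y_2.sum()).backward()
print(model2.shared_block.block.weight.grad.shape)  # torch.Size([8, 8])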