What is the correct way of sharing weights between two layers (modules) in PyTorch?
Based on my findings in the PyTorch discussion forum, there are several ways of doing this.
As an example, based on this discussion, I thought simply assigning the transposed weights would do it, that is:
self.decoder[0].weight = self.encoder[0].weight.t()
This, however, proved to be wrong and caused an error.
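For reference, the error can be reproduced in isolation (a minimal sketch, using plain nn.Linear layers in place of my self.encoder[0] / self.decoder[0]):
import torch
import torch.nn as nn

encoder = nn.Linear(4, 2, bias=False)
decoder = nn.Linear(2, 4, bias=False)

# weight.t() is a plain Tensor, not an nn.Parameter, so nn.Module refuses the
# assignment to a registered parameter attribute and raises a TypeError
decoder.weight = encoder.weight.t()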
I then tried wrapping the assigned tensor in nn.Parameter():
self.decoder[0].weight = nn.Parameter(self.encoder[0].weight.t())
This eliminates the error, but then again, there is no sharing happening here; by doing this I just initialized a new tensor with the same values as encoder[0].weight.t().
I then found this link, which provides different ways of sharing weights. However, I'm skeptical whether all of the methods given there are actually correct.
For example, one way is demonstrated like this:
# tied autoencoder using off the shelf nn modules
class TiedAutoEncoderOffTheShelf(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.decoder = nn.Linear(out, inp, bias=False)

        # tie the weights
        self.encoder.weight.data = weight.clone()
        self.decoder.weight.data = self.encoder.weight.data.transpose(0, 1)

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = self.decoder(encoded_feats)
        return encoded_feats, reconstructed_output
Basically, it creates a new weight tensor using nn.Parameter() and assigns it to each layer/module like this:
weights = nn.Parameter(torch.randn_like(self.encoder[0].weight))
self.encoder[0].weight.data = weights.clone()
self.decoder[0].weight.data = self.encoder[0].weight.data.transpose(0, 1)
This really confuses me: how is this sharing the same variable between these two layers? Isn't it just cloning the 'raw' data?
When I used this approach and visualized the weights, I noticed the visualizations were different, and that made me even more certain something was not right.
I wasn't sure whether the different visualizations were solely due to one being the transpose of the other, or whether, as I suspected, the weights were being optimized independently (i.e. not shared between the layers).
As it turns out, after further investigation (which was simply re-transposing the decoder's weight and visualizing it), they were indeed shared.
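A quicker way to confirm the same thing (just a sketch, instantiating the class above with made-up dimensions) is to compare the data pointers of the two weights:
import torch
import torch.nn as nn

model = TiedAutoEncoderOffTheShelf(inp=4, out=2, weight=torch.randn(2, 4))

# transpose(0, 1) returns a view, so both weights are backed by the same storage
# and the data pointers match even though the shapes differ
print(model.encoder.weight.data_ptr() == model.decoder.weight.data_ptr())  # True
print(model.encoder.weight.shape, model.decoder.weight.shape)  # torch.Size([2, 4]) torch.Size([4, 2])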
Below is the visualization of the encoder's and decoder's weights:
The non-functional form does indeed share the weights; it does so by sharing/linking the underlying storage using the .data trick.
This effectively shares the underlying storage/memory between the encoder and the decoder, and therefore any change to one is reflected in the other. Note that transposing (.t()) creates a view and does not change the underlying storage!
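As a quick illustration of that last point (a standalone sketch, independent of the model above):
import torch

w = torch.randn(2, 4)
w_t = w.t()                            # a view: same storage, swapped strides

print(w.data_ptr() == w_t.data_ptr())  # True -> both point at the same memory
w_t[0, 0] = 123.0                      # writing through the view...
print(w[0, 0])                         # ...shows up in the original tensor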
There is a catch here, however: assigning to .data (or otherwise directly altering the underlying storage) bypasses PyTorch's autograd system, so autograd cannot track this operation and therefore cannot track gradients for it. This will in turn make the gradient of the shared weight (weights) None!
However, this does not pose any issue in our case, as the .grad properties of the encoder and the decoder will hold the needed gradients during back-propagation (they are populated because the weights/biases are parameters of the linear layers themselves). In this way, the encoder and the decoder apply their updates to the shared weights and the final result comes out perfectly fine.
These nuisances don't exist when we use the functional form, so IMHO it is much safer to use the functional form, especially when the architecture is more complex! That aside, the .data trick works as well, just with a bit of a nuisance!
Here's a simple minimal example that demonstrates this:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SharedWeightsAE(nn.Module):
    def __init__(self, input_dim=4, embedding_dim=2):
        super().__init__()
        self.encoder = nn.Linear(input_dim, embedding_dim)
        self.decoder = nn.Linear(embedding_dim, input_dim)
        # define a single weight and assign it to both encoder and decoder
        self.shared_weight = nn.Parameter(torch.randn(embedding_dim, input_dim))
        # note we use .data to directly access the underlying storage and link the
        # shared weight parameter's storage with the encoder's/decoder's weights.
        # by doing this, we are bypassing PyTorch's autograd system, and therefore
        # PyTorch will not be able to track gradients here.
        # this doesn't pose an issue for us, as the .grad property of each module's
        # weight will be populated properly during training (shared_weight won't have
        # any gradients for this reason, but since the underlying storage is linked,
        # the changes take place in the same storage and everything will be fine).
        self.encoder.weight.data = self.shared_weight
        self.decoder.weight.data = self.shared_weight.t()

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

# here's the functional version
class SharedWeightsAEFunctional(nn.Module):
    def __init__(self, input_dim=4, embedding_dim=2):
        super().__init__()
        # a single weight parameter is used for both encoder and decoder
        self.shared_weight = nn.Parameter(torch.randn(embedding_dim, input_dim))
        # since we use the functional form of the linear layer,
        # we also prepare a separate bias parameter for
        # the encoder and the decoder (they are not shared, obviously!)
        self.encoder_bias = nn.Parameter(torch.zeros(embedding_dim))
        self.decoder_bias = nn.Parameter(torch.zeros(input_dim))

    # instead of modules, we now create methods so we can call them
    # just like in the previous version
    def encoder(self, x):
        return F.linear(x, self.shared_weight, self.encoder_bias)

    def decoder(self, x):
        return F.linear(x, self.shared_weight.t(), self.decoder_bias)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded
torch.manual_seed(5)

def main(use_functional=True):
    print('-' * 40)
    print(f"Using {'Functional' if use_functional else 'Non-Functional'} Form")
    if use_functional:
        model = SharedWeightsAEFunctional(input_dim=4, embedding_dim=2)
    else:
        model = SharedWeightsAE(input_dim=4, embedding_dim=2)

    # our input
    x = torch.randn(3, 4)
    # forward pass
    _, decoded = model(x)

    encoders_weight = model.shared_weight if use_functional else model.encoder.weight
    decoders_weight = model.shared_weight.t() if use_functional else model.decoder.weight

    print('\nBefore update:')
    print(f'Encoders Weight:\n {encoders_weight.detach().numpy()}')
    # note that since transposing (calling .t()) creates a temporary view,
    # the id() will be different (the order of the values obviously differs because
    # the shape changes after transposing!). so, to better show that the underlying
    # data is indeed the same, we transpose it back to get the same view as the
    # original shared_weight used by the encoder.
    print(f'Decoders Weight(transposed):\n {decoders_weight.t().detach().numpy()}')

    # now let's update the shared weight directly!
    # this should be reflected in both the encoder and decoder weights
    model.shared_weight.data += 1.0
    # model.encoder.weight.data += 1.0
    # model.decoder.weight.data += 1.0
    print('\nAfter direct update:')
    print(f'Encoders weight:\n {encoders_weight.detach().numpy()}')
    print(f'Decoders weight(transposed):\n {decoders_weight.t().detach().numpy()}')
    # here's another check to make sure they all match!
    assert torch.eq(encoders_weight, decoders_weight.t()).all(), 'They must match!'

    # let's see how gradients are affected/properly accumulated
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss = F.mse_loss(decoded, x)
    loss.backward()

    print('\nGradient check:')
    # shared_weight only has grads when using the functional form;
    # in the non-functional form its grad is None!
    print(f'shared_weight Gradients:\n{model.shared_weight.grad}')
    if not use_functional:
        # in the non-functional form, the gradients are accumulated properly on the
        # respective parameters, as they belong to the linear layers and the autograd
        # system handles them normally
        print(f'Encoder Gradients:\n{encoders_weight.grad}')
        print(f'Decoder Gradients:\n{decoders_weight.grad.t()}')

    # now let's take one sgd step and see how the shared weights
    # are affected. this shows us whether they are truly shared or not!
    optimizer.step()

    print('\nAfter the optimizer update:')
    print(f'Encoders weight:\n {encoders_weight.detach().numpy()}')
    print(f'Decoders weight(transposed):\n {decoders_weight.t().detach().numpy()}')
    print(f'Weight Norms:')
    print(f' shared_weight: {model.shared_weight.norm()}')
    print(f' encoders_weight: {encoders_weight.norm()}')
    print(f' decoders_weight: {decoders_weight.t().norm()}')
    # here's another check to make sure they all match!
    assert torch.eq(encoders_weight, decoders_weight.t()).all(), 'They must match!'

    # note the difference in param count between the two methods.
    # this is another of those nuisances we face when we bypass the autograd system!
    print(f'\nmodel param count: {sum(p.numel() for p in model.parameters()):,}')
    for name, param in model.named_parameters():
        print(f'{name}:{id(param)} {tuple(param.shape)}')

main(use_functional=True)
main(use_functional=False)
Running this should result in something like this:
----------------------------------------
Using Functional Form
Before update:
Encoders Weight:
[[-0.48678076 -0.6038216 -0.55809623 0.6675243 ]
[-0.1974151 1.9427835 -1.401702 -0.76255715]]
Decoders Weight(transposed):
[[-0.48678076 -0.6038216 -0.55809623 0.6675243 ]
[-0.1974151 1.9427835 -1.401702 -0.76255715]]
After direct update:
Encoders weight:
[[ 0.51321924 0.39617842 0.44190377 1.6675243 ]
[ 0.8025849 2.9427834 -0.40170205 0.23744285]]
Decoders weight(transposed):
[[ 0.51321924 0.39617842 0.44190377 1.6675243 ]
[ 0.8025849 2.9427834 -0.40170205 0.23744285]]
Gradient check:
shared_weight Gradients:
tensor([[ 0.1537, -2.1442, 1.1304, 1.1856],
[-2.6236, 9.8073, -3.8072, -7.1089]])
After the optimizer update:
Encoders weight:
[[ 0.51168174 0.41761994 0.43059948 1.6556679 ]
[ 0.8288213 2.8447099 -0.36362973 0.30853203]]
Decoders weight(transposed):
[[ 0.51168174 0.41761994 0.43059948 1.6556679 ]
[ 0.8288213 2.8447099 -0.36362973 0.30853203]]
Weight Norms:
shared_weight: 3.5170462131500244
encoders_weight: 3.5170462131500244
decoders_weight: 3.5170462131500244
model param count: 14
shared_weight:126809411716512 (2, 4)
encoder_bias:126809411715712 (2,)
decoder_bias:126809411715152 (4,)
----------------------------------------
Using Non-Functional Form
Before update:
Encoders Weight:
[[ 0.7324532 -0.3460281 -2.4620266 0.87129027]
[ 0.89674085 0.31359908 -0.32158604 -0.78190976]]
Decoders Weight(transposed):
[[ 0.7324532 -0.3460281 -2.4620266 0.87129027]
[ 0.89674085 0.31359908 -0.32158604 -0.78190976]]
After direct update:
Encoders weight:
[[ 1.7324532 0.6539719 -1.4620266 1.8712902 ]
[ 1.8967409 1.3135991 0.678414 0.21809024]]
Decoders weight(transposed):
[[ 1.7324532 0.6539719 -1.4620266 1.8712902 ]
[ 1.8967409 1.3135991 0.678414 0.21809024]]
Gradient check:
shared_weight Gradients:
None
Encoder Gradients:
tensor([[-5.1886, -4.1419, -7.3184, -1.5923],
[-0.9814, -0.2428, -0.8761, -0.8501]])
Decoder Gradients:
tensor([[ 1.8538, 0.1696, -5.4423, 2.4453],
[-0.3099, 0.0056, 1.0833, -0.4265]])
After the optimizer update:
Encoders weight:
[[ 1.7658013 0.6936956 -1.3344194 1.8627597]
[ 1.9096545 1.3159711 0.6763415 0.2308565]]
Decoders weight(transposed):
[[ 1.7658013 0.6936956 -1.3344194 1.8627597]
[ 1.9096545 1.3159711 0.6763415 0.2308565]]
Weight Norms:
shared_weight: 3.8391547203063965
encoders_weight: 3.8391547203063965
decoders_weight: 3.8391547203063965
model param count: 30
shared_weight:126809411724432 (2, 4)
encoder.weight:126809423103552 (2, 4)
encoder.bias:126812602040960 (2,)
decoder.weight:126809411715152 (4, 2)
decoder.bias:126809411727952 (4,)