
How to share weights between modules in Pytorch?


What is the correct way of sharing weights between two layers (modules) in PyTorch?
Based on my findings in the PyTorch discussion forum, there are several ways of doing this.
As an example, based on this discussion, I thought simply assigning the transposed weights would do it. That is, doing:

 self.decoder[0].weight = self.encoder[0].weight.t()

This, however, proved to be wrong and caused an error. I then tried wrapping the above line in an nn.Parameter():

self.decoder[0].weight = nn.Parameter(self.encoder[0].weight.t())

This eliminates the error, but then again, there is no sharing happening here. By doing this, I have merely initialized a new tensor with the same values as encoder[0].weight.t().

I then found this link, which provides different ways of sharing weights. However, I'm skeptical whether all the methods given there are actually correct.
For example, one way is demonstrated like this:

# tied autoencoder using off the shelf nn modules
class TiedAutoEncoderOffTheShelf(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.decoder = nn.Linear(out, inp, bias=False)

        # tie the weights
        self.encoder.weight.data = weight.clone()
        self.decoder.weight.data = self.encoder.weight.data.transpose(0,1)

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = self.decoder(encoded_feats)
        return encoded_feats, reconstructed_output

Basically, it creates a new weight tensor using nn.Parameter() and assigns it to each layer/module like this:

weights = nn.Parameter(torch.randn_like(self.encoder[0].weight))
self.encoder[0].weight.data = weights.clone()
self.decoder[0].weight.data = self.encoder[0].weight.data.transpose(0, 1)

This really confuses me: how is this sharing the same variable between these two layers? Is it not just cloning the 'raw' data?
When I used this approach and visualized the weights, I noticed the visualizations were different, and that made me even more certain something was not right.
I'm not sure whether the different visualizations were solely due to one being the transpose of the other, or whether, as I already suspected, they are optimized independently (i.e. the weights are not shared between the layers).

Example weight initialization: [encoder and decoder weight visualizations]


Solution

  • TL;DR:

    As it turns out, after further investigation (which was simply re-transposing the decoder's weight and visualizing it), they were indeed shared.
    Below is the visualization of the encoder's and decoder's weights: [encoder and decoder weight visualizations]

    More explanation:

    The non-functional form does indeed share the weights; it does so by sharing/linking the underlying storage using the .data trick.
    This effectively shares the underlying storage/memory between the encoder and the decoder, and therefore any change to one will be reflected in the other. Note that transposing (.t()) creates a temporary view and does not change the underlying storage!
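
    To make the view behaviour concrete, here is a tiny standalone check (just a sketch); data_ptr() returns the address of the first element of a tensor's storage:

    import torch

    x = torch.randn(2, 4)
    y = x.t()                             # a view: no data is copied

    print(y.data_ptr() == x.data_ptr())   # True -- same underlying storage
    y[0, 0] = 42.0                        # an in-place change through the view...
    print(x[0, 0])                        # ...is visible in the original tensor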

    There is a catch here, however: assigning to .data, or directly altering the underlying storage, bypasses PyTorch's autograd system, which means autograd cannot track this operation and therefore cannot track gradients through it. This will in turn leave the gradient of the shared weight as None!

    However, this does not pose any issue in our case, as the grad properties of the encoder and decoder will hold the needed gradients during back-propagation (they are populated because the weights/biases are parameters of the linear layers themselves). In this way, the encoder and decoder both apply their changes to the shared weights, and the final result comes out perfectly fine.
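
    A tiny standalone sketch of this behaviour (with hypothetical names, separate from the full example below):

    import torch
    import torch.nn as nn

    shared = nn.Parameter(torch.randn(2, 4))
    layer = nn.Linear(4, 2, bias=False)
    layer.weight.data = shared        # link the storages via the .data trick

    layer(torch.randn(3, 4)).sum().backward()

    print(shared.grad is None)        # True  -- autograd never saw `shared` in the graph
    print(layer.weight.grad is None)  # False -- the layer's own parameter gets its gradient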

    These nuisances don't exist when we use the functional form, so IMHO it is much safer to use the functional form, especially when the architecture is more complex! Aside from that, using the .data trick works as well, just with a bit of nuisance!

    Here's a simple minimal example that demonstrates this:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    class SharedWeightsAE(nn.Module):
        def __init__(self, input_dim=4, embedding_dim=2):
            super().__init__()
            self.encoder = nn.Linear(input_dim,embedding_dim)
            self.decoder = nn.Linear(embedding_dim,input_dim)
            # define a single weight and assign it to both encoder and decoder
            self.shared_weight = nn.Parameter(torch.randn(embedding_dim, input_dim))
            # note we use .data to directly access the underlying storage and link
            # the shared weight parameter's storage with the encoder's/decoder's weights.
            # by doing this, we bypass PyTorch's autograd system, so it cannot
            # track gradients for shared_weight itself.
            # this doesn't pose an issue for us: the grad property of each module's
            # weight will still be populated properly during training (only shared_weight
            # won't have any gradients), and since the underlying storage is linked,
            # the updates land in the same storage and everything works out fine
            self.encoder.weight.data = self.shared_weight
            self.decoder.weight.data = self.shared_weight.t()
            
        def forward(self, x):
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return encoded, decoded
    
    # here's the functional version
    class SharedWeightsAEFunctional(nn.Module):
        def __init__(self, input_dim=4, embedding_dim=2):
            super().__init__()
            # a single weight parameter is used for both encoder and decoder
            self.shared_weight = nn.Parameter(torch.randn(embedding_dim, input_dim))
            # since we use the functional form of the linear layer,
            # we also prepare a separate bias parameter for
            # the encoder and the decoder (the biases are not shared, obviously!)
            self.encoder_bias = nn.Parameter(torch.zeros(embedding_dim))
            self.decoder_bias = nn.Parameter(torch.zeros(input_dim))
    
        # instead of separate modules, we now define methods that can be called
        # just like in the previous version
        def encoder(self, x):
            return F.linear(x, self.shared_weight, self.encoder_bias)
    
        def decoder(self, x):
            return F.linear(x, self.shared_weight.t(), self.decoder_bias)
    
        def forward(self, x):
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return encoded, decoded
    
    torch.manual_seed(5)
    def main(use_functional=True):
        print('-'*40)
        print(f"Using {'Functional' if use_functional else 'Non-Functional'} Form")
    
        if use_functional:
            model = SharedWeightsAEFunctional(input_dim=4, embedding_dim=2) 
        else:
            model = SharedWeightsAE(input_dim=4, embedding_dim=2)
    
        # our input
        x = torch.randn(3, 4)
    
        # forward pass
        _, decoded = model(x)
    
        encoders_weight = model.shared_weight if use_functional else model.encoder.weight
        decoders_weight = model.shared_weight.t() if use_functional else model.decoder.weight
    
        print('\nBefore update:')
        print(f'Encoders Weight:\n {encoders_weight.detach().numpy()}')
        # note that since transposing (calling .t()) creates a temporary view,
        # the id() will be different (and the order of values obviously differs
        # because the shape changes after transposing!), so to better show that the
        # underlying data is indeed the same, we transpose it back
        # to get the same view as the original shared_weight used by the encoder
        print(f'Decoders Weight(transposed):\n {decoders_weight.t().detach().numpy()}')
    
        # now let's update the shared weight directly!
        # this should be reflected in both the encoder and decoder weights
        model.shared_weight.data += 1.0
        # model.encoder.weight.data += 1.0
        # model.decoder.weight.data += 1.0
    
        print('\nAfter direct update:')
        print(f'Encoders weight:\n {encoders_weight.detach().numpy()}')
        print(f'Decoders weight(transposed):\n {decoders_weight.t().detach().numpy()}')
        # here's another check to make sure they all match!
        assert torch.eq(encoders_weight, decoders_weight.t()).all(),'They must match!'
    
        # let's see how the gradients are affected/properly accumulated
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        loss = F.mse_loss(decoded, x)
        loss.backward()
    
        print('\nGradient check:')
        # shared_weight only has grads when using the functional form;
        # in the non-functional form its grad is None!
        print(f'shared_weight Gradients:\n{model.shared_weight.grad}')
        if not use_functional:
            # in the non-functional form, the gradients are accumulated properly for the
            # respective parameters, as they are part of the linear layers and the
            # autograd system handles them normally
            print(f'Encoder Gradients:\n{encoders_weight.grad}')
            print(f'Decoder Gradients:\n{decoders_weight.grad.t()}')
            
        # now let's take one sgd step and see how the shared weights
        # are affected. this shows us whether they are truly shared or not!
        optimizer.step()
    
        print('\nAfter the optimizer update:')
        print(f'Encoders weight:\n {encoders_weight.detach().numpy()}')
        print(f'Decoders weight(transposed):\n {decoders_weight.t().detach().numpy()}')
        
        print(f'Weight Norms:')
        print(f' shared_weight:   {model.shared_weight.norm()}')
        print(f' encoders_weight: {encoders_weight.norm()}')
        print(f' decoders_weight: {decoders_weight.t().norm()}')
        
        # here's another check to make sure they all match!
        assert torch.eq(encoders_weight, decoders_weight.t()).all(),'They must match!'
    
        # note the difference in param count between the two methods;
        # this is another of those nuisances we face when we bypass the autograd system!
        print(f'\nmodel param count: {sum(p.numel() for p in model.parameters()):,}')
        for name,param in model.named_parameters():
            print(f'{name}:{id(param)} {tuple(param.shape)}')
    
    main(use_functional=True)
    main(use_functional=False)
    
    

    This should result in something like this:

    ----------------------------------------
    Using Functional Form
    
    Before update:
    Encoders Weight:
     [[-0.48678076 -0.6038216  -0.55809623  0.6675243 ]
     [-0.1974151   1.9427835  -1.401702   -0.76255715]]
    Decoders Weight(transposed):
     [[-0.48678076 -0.6038216  -0.55809623  0.6675243 ]
     [-0.1974151   1.9427835  -1.401702   -0.76255715]]
    
    After direct update:
    Encoders weight:
     [[ 0.51321924  0.39617842  0.44190377  1.6675243 ]
     [ 0.8025849   2.9427834  -0.40170205  0.23744285]]
    Decoders weight(transposed):
     [[ 0.51321924  0.39617842  0.44190377  1.6675243 ]
     [ 0.8025849   2.9427834  -0.40170205  0.23744285]]
    
    Gradient check:
    shared_weight Gradients:
    tensor([[ 0.1537, -2.1442,  1.1304,  1.1856],
            [-2.6236,  9.8073, -3.8072, -7.1089]])
    
    After the optimizer update:
    Encoders weight:
     [[ 0.51168174  0.41761994  0.43059948  1.6556679 ]
     [ 0.8288213   2.8447099  -0.36362973  0.30853203]]
    Decoders weight(transposed):
     [[ 0.51168174  0.41761994  0.43059948  1.6556679 ]
     [ 0.8288213   2.8447099  -0.36362973  0.30853203]]
    Weight Norms:
     shared_weight:   3.5170462131500244
     encoders_weight: 3.5170462131500244
     decoders_weight: 3.5170462131500244
    
    model param count: 14
    shared_weight:126809411716512 (2, 4)
    encoder_bias:126809411715712 (2,)
    decoder_bias:126809411715152 (4,)
    ----------------------------------------
    Using Non-Functional Form
    
    Before update:
    Encoders Weight:
     [[ 0.7324532  -0.3460281  -2.4620266   0.87129027]
     [ 0.89674085  0.31359908 -0.32158604 -0.78190976]]
    Decoders Weight(transposed):
     [[ 0.7324532  -0.3460281  -2.4620266   0.87129027]
     [ 0.89674085  0.31359908 -0.32158604 -0.78190976]]
    
    After direct update:
    Encoders weight:
     [[ 1.7324532   0.6539719  -1.4620266   1.8712902 ]
     [ 1.8967409   1.3135991   0.678414    0.21809024]]
    Decoders weight(transposed):
     [[ 1.7324532   0.6539719  -1.4620266   1.8712902 ]
     [ 1.8967409   1.3135991   0.678414    0.21809024]]
    
    Gradient check:
    shared_weight Gradients:
    None
    Encoder Gradients:
    tensor([[-5.1886, -4.1419, -7.3184, -1.5923],
            [-0.9814, -0.2428, -0.8761, -0.8501]])
    Decoder Gradients:
    tensor([[ 1.8538,  0.1696, -5.4423,  2.4453],
            [-0.3099,  0.0056,  1.0833, -0.4265]])
    
    After the optimizer update:
    Encoders weight:
     [[ 1.7658013  0.6936956 -1.3344194  1.8627597]
     [ 1.9096545  1.3159711  0.6763415  0.2308565]]
    Decoders weight(transposed):
     [[ 1.7658013  0.6936956 -1.3344194  1.8627597]
     [ 1.9096545  1.3159711  0.6763415  0.2308565]]
    Weight Norms:
     shared_weight:   3.8391547203063965
     encoders_weight: 3.8391547203063965
     decoders_weight: 3.8391547203063965
    
    model param count: 30
    shared_weight:126809411724432 (2, 4)
    encoder.weight:126809423103552 (2, 4)
    encoder.bias:126812602040960 (2,)
    decoder.weight:126809411715152 (4, 2)
    decoder.bias:126809411727952 (4,)
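
    As a final sanity check (a small sketch, assuming the SharedWeightsAE class defined above), you can also compare the data pointers directly to confirm that all three weights live in the same storage:

    model = SharedWeightsAE(input_dim=4, embedding_dim=2)
    # data_ptr() returns the address of the first element; a transpose is just a
    # view, so it points at the same storage as the original tensor
    print(model.encoder.weight.data_ptr() == model.shared_weight.data_ptr())  # True
    print(model.decoder.weight.data_ptr() == model.shared_weight.data_ptr())  # True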