python, pytorch, artificial-intelligence

Do I need to load the weights of another class I use in my NN class?


I have a model that needs to implement self-attention, and this is how I wrote my code:

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(args)

    def forward(self, x):
        # nn.MultiheadAttention returns (attn_output, attn_weights); keep only the output
        return self.multihead_attn(x, x, x)[0]
    
class ActualModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)
    
    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x

After loading a checkpoint of ActualModel, should I also load a saved checkpoint of the SelfAttention class in ActualModel.__init__ when continuing training or at prediction time?

If I create an instance of the SelfAttention class, will the trained weights of SelfAttention.multihead_attn be loaded when I call torch.load('actual_model.pth'), or will they be reinitialized?

In other words, is this necessary?

class ActualModel(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)
        
    def pred_or_continue_train(self):
        self.self_attention = torch.load('self_attention.pth')

actual_model = torch.load('actual_model.pth')
actual_model.pred_or_continue_train()
actual_model.eval()

Solution

  • In other words, is this necessary?

    In short, No.

    The SelfAttention module will be loaded automatically, because it is registered as a submodule of ActualModel. In general, anything registered as an nn.Module, an nn.Parameter, or a manually registered buffer becomes part of the parent model's state and is saved and restored along with it.

    A quick example:

    import torch
    import torch.nn as nn
    
    class SelfAttention(nn.Module):
        def __init__(self, fin, n_h):
            super(SelfAttention, self).__init__()
            self.multihead_attn = torch.nn.MultiheadAttention(fin, n_h)
            
        def forward(self, x):
            # nn.MultiheadAttention returns (attn_output, attn_weights); keep only the output
            return self.multihead_attn(x, x, x)[0]
        
    class ActualModel(nn.Module):
        def __init__(self):
            super(ActualModel, self).__init__()
            self.inp_layer = nn.Linear(10, 20)
            self.self_attention = SelfAttention(20, 1)
            self.out_layer = nn.Linear(20, 1)
        
        def forward(self, x):
            x = self.inp_layer(x)
            x = self.self_attention(x)
            x = self.out_layer(x)
            return x
    
    m = ActualModel()
    for k, v in m.named_parameters():
        print(k)
    

    You will get the following output, which shows that self_attention has been registered successfully:

    inp_layer.weight
    inp_layer.bias
    self_attention.multihead_attn.in_proj_weight
    self_attention.multihead_attn.in_proj_bias
    self_attention.multihead_attn.out_proj.weight
    self_attention.multihead_attn.out_proj.bias
    out_layer.weight
    out_layer.bias
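
    Consequently, loading a checkpoint of ActualModel restores the SelfAttention weights together with everything else; no separate self_attention.pth is needed. Below is a minimal sketch verifying this with a state_dict round trip, reusing ActualModel from above (the checkpoint filename is only illustrative):

    import torch

    m = ActualModel()
    # Saving the state_dict is the recommended checkpoint format.
    torch.save(m.state_dict(), 'actual_model.pth')

    m2 = ActualModel()  # freshly initialized, so its weights differ from m's
    m2.load_state_dict(torch.load('actual_model.pth'))

    # The submodule's weights were restored along with everything else.
    assert torch.equal(
        m.self_attention.multihead_attn.in_proj_weight,
        m2.self_attention.multihead_attn.in_proj_weight,
    )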