I have a model that needs to implement self-attention, and this is how I wrote the code:
class SelfAttention(nn.Module):
    def __init__(self, args):
        self.multihead_attn = torch.nn.MultiheadAttention(args)

    def forward(self, x):
        return self.multihead_attn.forward(x, x, x)

class ActualModel(nn.Module):
    def __init__(self):
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x
After loading a checkpoint of ActualModel, should I also load a saved checkpoint of the SelfAttention class in ActualModel.__init__ when continuing training or at prediction time? If I create an instance of SelfAttention, will the trained weights corresponding to SelfAttention.multihead_attn be loaded when I call torch.load('actual_model.pth'), or will they be reinitialized? In other words, is this necessary?
class ActualModel(nn.Module):
    def __init__(self):
        self.inp_layer = nn.Linear(arg1, arg2)
        self.self_attention = SelfAttention(some_args)
        self.out_layer = nn.Linear(arg2, 1)

    def pred_or_continue_train(self):
        self.self_attention = torch.load('self_attention.pth')

actual_model = torch.load('actual_model.pth')
actual_model.pred_or_continue_train()
actual_model.eval()
In other words, is this necessary?
In short, no.

The SelfAttention class will be loaded automatically, as long as it has been registered as an nn.Module, an nn.Parameter, or a manually registered buffer.
A quick example:
import torch
import torch.nn as nn
class SelfAttention(nn.Module):
    def __init__(self, fin, n_h):
        super(SelfAttention, self).__init__()
        self.multihead_attn = torch.nn.MultiheadAttention(fin, n_h)

    def forward(self, x):
        # MultiheadAttention returns (attn_output, attn_weights); keep only the output
        attn_output, _ = self.multihead_attn(x, x, x)
        return attn_output
class ActualModel(nn.Module):
    def __init__(self):
        super(ActualModel, self).__init__()
        self.inp_layer = nn.Linear(10, 20)
        self.self_attention = SelfAttention(20, 1)
        self.out_layer = nn.Linear(20, 1)

    def forward(self, x):
        x = self.inp_layer(x)
        x = self.self_attention(x)
        x = self.out_layer(x)
        return x
m = ActualModel()

for k, v in m.named_parameters():
    print(k)
You will get the following output, where self_attention is successfully registered.
inp_layer.weight
inp_layer.bias
self_attention.multihead_attn.in_proj_weight
self_attention.multihead_attn.in_proj_bias
self_attention.multihead_attn.out_proj.weight
self_attention.multihead_attn.out_proj.bias
out_layer.weight
out_layer.bias
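Because the nested module's parameters are part of the parent's state, loading the parent's checkpoint restores them as well. A minimal sketch of the round trip, assuming the checkpoint is saved as a state_dict (the file name actual_model.pth is just a placeholder):

# Save the whole model's parameters, nested SelfAttention included.
torch.save(m.state_dict(), 'actual_model.pth')

# Rebuild the architecture and load the checkpoint in one call;
# self_attention.multihead_attn.* come back with their trained values,
# so no separate SelfAttention checkpoint is needed.
restored = ActualModel()
restored.load_state_dict(torch.load('actual_model.pth'))

restored.eval()    # for prediction
# ...or call restored.train() and keep optimizing to continue training.

If instead the entire module object was pickled with torch.save(m, 'actual_model.pth'), then torch.load('actual_model.pth') already returns the full model with all nested submodules and their trained weights (provided the class definitions are importable), so nothing like pred_or_continue_train is needed in that case either.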