I looked through different implementations of BERT's masked language model (MLM) prediction head. For pre-training, there are two common versions:
import torch
import torch.nn as nn

# Version 1: the output projection has its own weight matrix.
class LMPrediction(nn.Module):
    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        # Project hidden states to vocabulary logits; the bias is kept as a
        # separate parameter and attached to the linear layer.
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias

    def forward(self, x):
        return self.decoder(x)
# Version 2: the output projection reuses the input embedding matrix.
class LMPrediction(nn.Module):
    def __init__(self, hidden_size, vocab_size, embeddings):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        self.decoder.weight = embeddings.weight  ## <- THIS LINE
        self.decoder.bias = self.bias

    def forward(self, x):
        return self.decoder(x)
Which one is correct? I mostly see the first implementation, but the second one makes sense as well. However, I cannot find it mentioned in any papers, and I would like to know whether the second version is somehow superior to the first one.
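To make the difference concrete, here is a small sketch of what the second version does (the sizes are just BERT-base-style placeholders for illustration, and LMPrediction refers to the second class above): the decoder reuses the embedding matrix as its weight, so both point to the same parameter tensor.

import torch
import torch.nn as nn

hidden_size, vocab_size = 768, 30522  # assumed BERT-base sizes, illustration only
embeddings = nn.Embedding(vocab_size, hidden_size)

head = LMPrediction(hidden_size, vocab_size, embeddings)  # the second (tied) version

# Both attributes refer to the same Parameter object, so the model stores one
# (vocab_size x hidden_size) matrix instead of two, and gradients from the MLM
# loss also flow into the input embeddings.
assert head.decoder.weight is embeddings.weight

x = torch.randn(2, 5, hidden_size)  # (batch, seq_len, hidden)
logits = head(x)                    # (batch, seq_len, vocab_size)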
For those who are interested, the technique in the second version is called weight tying or joint input-output embedding. There are two papers that argue for the benefits of this approach: