I'm currently trying to implement the CBOW model in PyTorch. I have managed to get training and testing working, but I'm confused about the "proper" way to finally extract the weights from the model to use as our word embeddings.
import torch
import torch.nn as nn


class CBOW(nn.Module):
    def __init__(self, config, vocab):
        super().__init__()  # Must run before submodules are assigned on an nn.Module.
        self.config = config  # Basic config file to hold arguments.
        self.vocab = vocab
        self.vocab_size = len(self.vocab.token2idx)
        self.window_size = self.config.window_size
        # Lookup table: one embed_dim-sized vector per vocabulary entry.
        self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.config.embed_dim)
        # Projects the averaged context vector to one score per vocabulary entry.
        self.linear = nn.Linear(in_features=self.config.embed_dim, out_features=self.vocab_size)

    def forward(self, x):
        x = self.embed(x)
        x = torch.mean(x, dim=0)  # Average out the embedding values.
        x = self.linear(x)
        return x
After running my model through a Solver with the training and testing data, I had the train and test functions also return the model that was used. Then I assigned the embedding weights to a separate variable and used those as the word embeddings.
Training and testing were conducted using cross-entropy loss, and each training and testing sample is of the form ([context words], target word).
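To make the sample format concrete, here is a minimal sketch of a single training step on one ([context words], target word) pair. The sizes and index values are made up for illustration, and the two layers are standalone stand-ins for the ones inside the model, not the actual Solver code.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration.
vocab_size, embed_dim = 10, 4

embed = nn.Embedding(vocab_size, embed_dim)
linear = nn.Linear(embed_dim, vocab_size)
criterion = nn.CrossEntropyLoss()

# One sample: ([context words], target word), already mapped to indices.
context = torch.tensor([1, 2, 4, 5])  # four context word indices
target = torch.tensor([3])            # index of the target word

hidden = embed(context).mean(dim=0)   # average the context embeddings -> (embed_dim,)
logits = linear(hidden).unsqueeze(0)  # add a batch dimension -> (1, vocab_size)
loss = criterion(logits, target)      # cross entropy against the target index
loss.backward()                       # fills embed.weight.grad and linear.weight.grad
```

After the backward pass, embed.weight.grad holds the gradient for the embedding table, which is what an optimizer step would then use.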
def run(solver, config, vocabulary):
    for epoch in range(config.num_epochs):
        loss_train, model_train = solver.train()
        loss_test, model_test = solver.test()

    embeddings = model_train.embed.weight
I'm not sure if this is the correct way of going about extracting and using the embeddings. Is there usually another way to do this? Thanks in advance.
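Yes, model_train.embed.weight will give you a torch tensor that stores the embedding weights. Note, however, that this tensor is a Parameter that still tracks gradients: it has requires_grad=True and carries the latest gradients in its .grad attribute. If you don't want or need them, model_train.embed.weight.data will give you the weights only. In current PyTorch versions, model_train.embed.weight.detach() is the recommended way to get the same result.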
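As a rough sketch of what that extraction can look like in practice: the layer below is a freshly initialized stand-in for model_train.embed, and the token2idx dictionary is a made-up vocabulary shaped like vocab.token2idx from the question.

```python
import torch
import torch.nn as nn

# Stand-in for the trained layer (model_train.embed in the question).
embed = nn.Embedding(num_embeddings=5, embedding_dim=3)

# Hypothetical vocabulary mapping, shaped like vocab.token2idx.
token2idx = {"the": 0, "cat": 1, "sat": 2, "on": 3, "mat": 4}

# detach() returns a tensor cut off from autograd that still shares storage
# with the parameter; clone() makes an independent copy so that further
# training cannot change the extracted embeddings.
embeddings = embed.weight.detach().clone()

vector = embeddings[token2idx["cat"]]  # embedding for one word, shape (3,)
as_numpy = embeddings.cpu().numpy()    # NumPy array, shape (5, 3)
```

The clone() call is optional. Without it, the extracted tensor keeps pointing at the same memory as the layer's weights.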
A more generic option is to call model_train.embed.parameters(). This will give you a generator over all the weight tensors of the layer. In general, a layer can have multiple weight tensors (nn.Linear, for example, has both weight and bias), and weight will give you only one of them. Embedding happens to have only one, so here it doesn't matter which option you use.
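To illustrate the difference, the sketch below lists the parameters of an Embedding layer and a Linear layer using named_parameters(), which yields the same tensors as parameters() together with their names. The layer sizes are arbitrary.

```python
import torch.nn as nn

embed = nn.Embedding(num_embeddings=5, embedding_dim=3)
linear = nn.Linear(in_features=3, out_features=5)

# Map each parameter name to its shape.
embed_params = {name: tuple(p.shape) for name, p in embed.named_parameters()}
linear_params = {name: tuple(p.shape) for name, p in linear.named_parameters()}

print(embed_params)   # {'weight': (5, 3)}
print(linear_params)  # {'weight': (5, 3), 'bias': (5,)}
```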