I am quite a newbie to PyTorch, and I am trying to implement a sort of "post-training" procedure on embeddings.
I have a vocabulary with a set of items, and I have learned one vector for each of them. I keep the learned vectors in an nn.Embedding object. What I'd like to do now is add a new item to the vocabulary without updating the already learned vectors. The embedding for the new item would be initialized randomly and then trained while all the other embeddings stay frozen.
I know that to prevent an nn.Embedding from being trained, I need to set the requires_grad attribute of its weight to False. I have also found this other question that is similar to mine. The best answer proposes to
either store the frozen vectors and the vector to train in different nn.Embedding objects, the former with requires_grad = False and the latter with requires_grad = True,
or store the frozen vectors and the new one in the same nn.Embedding object, computing the gradient on all vectors but applying the descent step only to the dimensions of the new item's vector. This, however, leads to a significant degradation in performance (which I want to avoid, of course).
My problem is that I really need to store the vector for the new item in the same nn.Embedding object as the frozen vectors of the old items. The reason for this constraint is the following: when building my loss function with the embeddings of the items (old and new), I need to look up the vectors based on the ids of the items, and for performance reasons I need to use Python slicing. In other words, given a list of item ids item_ids, I need to do something like vecs = embedding[item_ids]. If I used two different nn.Embedding objects for the old items and the new one, I would need to use an explicit for-loop with if-else conditions, which would lead to worse performance.
Is there any way I can do this?
If you look at the implementation of nn.Embedding, it uses the functional form of embedding (F.embedding) in the forward pass. Therefore, I think you could implement a custom module that does something like this:
import torch
from torch.nn.parameter import Parameter
import torch.nn.functional as F
weights_freeze = torch.rand(10, 5)  # plain tensor, not a Parameter, so it receives no gradient
weights_train = Parameter(torch.rand(2, 5))
weights = torch.cat((weights_freeze, weights_train), 0)
idx = torch.tensor([[11, 1, 3]])
lookup = F.embedding(idx, weights)
# Desired result
print(lookup)
lookup.sum().backward()
# Index 11 is row 1 of weights_train (11 - 10), so only that row has a nonzero gradient
print(weights_train.grad)
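To make the "custom module" idea concrete, here is a minimal sketch of how the snippet above could be wrapped in an nn.Module. The class name PartiallyFrozenEmbedding and its arguments are hypothetical, and it assumes the new items always get the ids that come right after the frozen ones. The frozen rows are registered as a buffer, so they are not returned by parameters() and an optimizer never sees them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PartiallyFrozenEmbedding(nn.Module):
    """Hypothetical module: frozen rows first, trainable rows appended after them."""

    def __init__(self, frozen_weights: torch.Tensor, num_new: int):
        super().__init__()
        # A buffer is saved in state_dict and moved by .to(), but it is not
        # a parameter, so the optimizer cannot update it.
        self.register_buffer("weights_freeze", frozen_weights.detach().clone())
        # Randomly initialised vectors for the new items (assumed init scheme).
        self.weights_train = nn.Parameter(
            torch.randn(num_new, frozen_weights.size(1))
        )

    def forward(self, item_ids: torch.Tensor) -> torch.Tensor:
        # Rebuild the full table on each call so gradients flow to weights_train.
        weights = torch.cat((self.weights_freeze, self.weights_train), dim=0)
        return F.embedding(item_ids, weights)


torch.manual_seed(0)
emb = PartiallyFrozenEmbedding(torch.rand(10, 5), num_new=2)
frozen_before = emb.weights_freeze.clone()
train_before = emb.weights_train.detach().clone()

optimizer = torch.optim.SGD(emb.parameters(), lr=0.1)
item_ids = torch.tensor([11, 1, 3])  # one new item (11) and two old items
vecs = emb(item_ids)                 # single batched lookup, no Python loop
vecs.sum().backward()
optimizer.step()
```

After the step, only the row for id 11 has changed. One caveat: torch.cat copies the whole table on every forward pass, so for a very large vocabulary this may cost more than it saves, and it would be worth measuring.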