
Freeze only some rows of a torch.nn.Embedding object


I am quite new to PyTorch, and I am trying to implement a sort of "post-training" procedure on embeddings.

I have a vocabulary with a set of items, and I have learned one vector for each of them. I keep the learned vectors in a nn.Embedding object. What I'd like to do now is to add a new item to the vocabulary without updating the already learned vectors. The embedding for the new item would be initialized randomly, and then trained while keeping all the other embeddings frozen.
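For concreteness, the setup can be sketched as below. This is a minimal sketch, assuming the learned vectors live in an nn.Embedding named old_emb (a hypothetical name) with 10 items of dimension 5: a larger table is created, the learned rows are copied into it, and the extra row keeps nn.Embedding's default random (standard normal) initialization.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for the already-trained embedding: 10 items, dim 5.
old_emb = nn.Embedding(10, 5)

# Larger table with one extra row for the new item.
new_emb = nn.Embedding(11, 5)
with torch.no_grad():
    # Copy the learned vectors into rows 0..9; row 10 keeps its random init.
    new_emb.weight[:10] = old_emb.weight

new_item_id = 10  # id of the new item in the enlarged vocabulary
print(new_emb.weight.shape)  # torch.Size([11, 5])
```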

I know that to prevent an nn.Embedding from being trained, I need to set the requires_grad attribute of its weight to False. I have also found another question that is similar to mine. The best answer there proposes to

  1. either store the frozen vectors and the vector to train in different nn.Embedding objects, the former with requires_grad = False and the latter with requires_grad = True

  2. or store the frozen vectors and the new one in the same nn.Embedding object, computing the gradient for all vectors but applying the update only to the dimensions of the new item's vector. This, however, leads to a significant performance degradation (which, of course, I want to avoid).
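Option 2 above can be sketched with a gradient hook on the weight that zeroes the gradient of every frozen row. This is a minimal sketch, not the linked answer's code; the sizes, the squared-norm loss, and the use of plain SGD are assumptions for illustration. The full dense gradient is still computed for the whole table, which is where the extra cost comes from.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_old, dim = 10, 5
emb = nn.Embedding(num_old + 1, dim)  # row 10 is the (hypothetical) new item

# Mask that is 1 only for the new item's row; broadcasts over the embedding dim.
grad_mask = torch.zeros(num_old + 1, 1)
grad_mask[num_old] = 1.0

# Replace the weight's gradient with a masked copy before the optimizer sees it.
emb.weight.register_hook(lambda grad: grad * grad_mask)

opt = torch.optim.SGD(emb.parameters(), lr=0.1)
before = emb.weight.detach().clone()

ids = torch.tensor([0, 3, 10])
loss = emb(ids).pow(2).sum()  # toy loss touching old and new rows
opt.zero_grad()
loss.backward()
opt.step()

after = emb.weight.detach()
print(torch.equal(before[:num_old], after[:num_old]))  # True: frozen rows unchanged
```

Note that this only keeps the frozen rows fixed because the optimizer here has no weight decay; with weight decay the optimizer would still modify rows whose gradient is zero.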

My problem is that I really need to store the vector for the new item in the same nn.Embedding object as the frozen vectors of the old items. The reason for this constraint is the following: when building my loss function with the embeddings of the items (old and new), I need to look up the vectors by item id, and for performance reasons I need to use Python indexing. In other words, given a list of item ids item_ids, I need to do something like vecs = embedding[item_ids]. If I used two different nn.Embedding objects for the old items and the new one, I would need an explicit for-loop with if-else conditions, which would hurt performance.
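For reference, nn.Embedding itself does not support subscripting, so the single-call lookup described above is written either as a module call or as indexing into its weight tensor. A minimal sketch, assuming 11 items of dimension 5 (illustrative sizes), showing that both forms gather all requested rows at once with no Python loop:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embedding = nn.Embedding(11, 5)

item_ids = torch.tensor([10, 2, 7, 2])  # any mix of old and new ids, repeats allowed

# Both forms gather every requested row in one vectorized call.
vecs_module = embedding(item_ids)          # shape (4, 5)
vecs_indexed = embedding.weight[item_ids]  # same values, via tensor indexing

print(torch.equal(vecs_module, vecs_indexed))  # True
```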

Is there any way I can do this?


Solution

  • If you look at the implementation of nn.Embedding, you will see that its forward pass simply calls the functional form, F.embedding. Therefore, I think you could implement a custom module that does something like this:

    import torch
    from torch.nn.parameter import Parameter
    import torch.nn.functional as F
    
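    weights_freeze = torch.rand(10, 5)  # Plain tensor, not a Parameter, so it gets no gradient
    weights_train = Parameter(torch.rand(2, 5))  # Trainable rows (ids 10 and 11)
    weights = torch.cat((weights_freeze, weights_train), 0)  # Full 12 x 5 lookup table
    
    idx = torch.tensor([[11, 1, 3]])
    lookup = F.embedding(idx, weights)  # Same lookup that nn.Embedding.forward performs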
    
    # Desired result
    print(lookup)
    lookup.sum().backward()
    # 11 corresponds to idx 1 in weights_train so this has grad
    print(weights_train.grad)
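Wrapped into a module, the idea might look like the sketch below. PartiallyFrozenEmbedding is a hypothetical name, and the toy loss, sizes and SGD optimizer are assumptions for illustration. The frozen rows are registered as a buffer, so they follow .to(device) and appear in state_dict but are not returned by .parameters(); the concatenation happens inside forward so every call sees the current trainable rows.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PartiallyFrozenEmbedding(nn.Module):
    """Hypothetical module: rows [0, n_frozen) are fixed, the rest are trained."""

    def __init__(self, frozen_weights: torch.Tensor, num_new: int):
        super().__init__()
        # Buffer: saved and moved with the module, but never given to an optimizer.
        self.register_buffer("frozen", frozen_weights.detach().clone())
        dim = frozen_weights.size(1)
        self.trainable = nn.Parameter(torch.randn(num_new, dim))

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # Rebuild the full table on each call so updated trainable rows are used.
        weight = torch.cat((self.frozen, self.trainable), dim=0)
        return F.embedding(ids, weight)


torch.manual_seed(0)
old_vectors = torch.rand(10, 5)  # stand-in for the already learned vectors
emb = PartiallyFrozenEmbedding(old_vectors, num_new=1)
opt = torch.optim.SGD(emb.parameters(), lr=0.1)

item_ids = torch.tensor([[10, 1, 3]])  # id 10 is the new item
new_before = emb.trainable.detach().clone()
for _ in range(3):
    opt.zero_grad()
    loss = emb(item_ids).pow(2).sum()  # toy loss
    loss.backward()
    opt.step()

print(torch.equal(emb.frozen, old_vectors))    # True: old rows untouched
print(torch.equal(emb.trainable, new_before))  # False: new row was trained
```

Lookups still take a single tensor of mixed old and new ids, as in the question. One trade-off of this design is that torch.cat allocates a fresh copy of the whole table on every forward call, which may matter for very large vocabularies.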