
Relation between model.parameters() and nn.Embedding()


I have the following code:


import torch


class Model(torch.nn.Module):
    def __init__(self, num, dim):
        super(Model, self).__init__()
        self.node_embed = torch.nn.Embedding(num, dim)


def update_node_embed(B):
    first_ten_embed = B[0:10]  # this raises an error that B is not callable
    first_ten_embed *= 0.1


A = Model(1000, 16)
B = list(A.parameters())

As we know, B will contain the 'weight' of self.node_embed. The function update_node_embed takes B as input, and inside it I need to update the lookup table self.node_embed.

My question is: how can I get self.node_embed from B? When I try, it tells me that B is a list of parameters and is not callable.
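
Printing what B holds confirms it is a plain Python list with a single entry, the weight of self.node_embed, so B[0:10] slices the list rather than the first ten embedding rows:

for name, p in A.named_parameters():
    print(name, tuple(p.shape))  # node_embed.weight (1000, 16)

print(type(B), len(B))  # <class 'list'> 1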


Solution

  • The first parameter in A.parameters() is your embedding weight. Here is a working update_node_embed():

    def update_node_embed(B):
        embed_params = B[0]  # the embedding weight is the first (and only) parameter
        with torch.no_grad():  # avoid an in-place operation on a leaf variable that requires grad
            first_ten_embed = embed_params[:10]  # a view into the weight tensor
            first_ten_embed *= 0.1  # updates self.node_embed.weight in place
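
    For instance, you can check that the update actually reaches the lookup table, and equivalently index A.node_embed.weight directly without going through B (a quick sketch reusing the Model class above):

    A = Model(1000, 16)
    B = list(A.parameters())

    before = A.node_embed.weight[:10].clone()
    update_node_embed(B)
    print(torch.allclose(A.node_embed.weight[:10], before * 0.1))  # True

    # Equivalent, without going through parameters():
    with torch.no_grad():
        A.node_embed.weight[:10] *= 0.1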