I have the following code:
```python
import torch

class Model(torch.nn.Module):
    def __init__(self, num, dim):
        super(Model, self).__init__()
        self.node_embed = torch.nn.Embedding(num, dim)

def update_node_embed(B):
    first_ten_embed = B[0:10]  # raises an error saying B is not callable
    first_ten_embed *= 0.1

A = Model(1000, 16)
B = list(A.parameters())
```
As we know, `B` will contain the `weight` of `self.node_embed`. The function `update_node_embed` will take `B` as input, and I need to update the lookup table `self.node_embed`.
My question is: how can I get `self.node_embed` from `B`? I tried, but it told me that `B` is parameters and is not callable.
The first (and, in this model, the only) parameter in `A.parameters()` is the weight of your embedding.
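A quick way to confirm this is to list the registered parameters by name. The sketch below restates the question's `Model` class so it runs on its own. Because `nn.Module.parameters()` yields the registered `Parameter` objects themselves rather than copies, `B[0]` is the very same tensor as `A.node_embed.weight`, which is also a more direct way to reach it.

```python
import torch

class Model(torch.nn.Module):
    def __init__(self, num, dim):
        super().__init__()
        self.node_embed = torch.nn.Embedding(num, dim)

A = Model(1000, 16)
B = list(A.parameters())

# The model registers exactly one parameter: the embedding's lookup table.
for name, p in A.named_parameters():
    print(name, tuple(p.shape))  # node_embed.weight (1000, 16)

# B[0] is the same tensor object as A.node_embed.weight, not a copy.
print(B[0] is A.node_embed.weight)  # True
```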
Here is your `update_node_embed()`:
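```python
def update_node_embed(B):
    embed_params = B[0]  # the embedding's weight tensor
    with torch.no_grad():  # avoid an in-place operation on a leaf variable that requires grad
        first_ten_embed = embed_params[:10]  # a view, so changes write through to the lookup table
        first_ten_embed *= 0.1
```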
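To check that the lookup table itself is modified, a minimal sketch can snapshot the weights, call the function, and compare. The function is restated here (using the equivalent in-place `mul_`) so the snippet is self-contained; the names `before` and `after` are only for illustration.

```python
import torch

class Model(torch.nn.Module):
    def __init__(self, num, dim):
        super().__init__()
        self.node_embed = torch.nn.Embedding(num, dim)

def update_node_embed(B):
    embed_params = B[0]
    with torch.no_grad():
        # Slicing gives a view, so mul_ scales rows 0-9 of the real weight tensor.
        embed_params[:10].mul_(0.1)

A = Model(1000, 16)
before = A.node_embed.weight.detach().clone()  # independent snapshot
update_node_embed(list(A.parameters()))
after = A.node_embed.weight.detach()  # shares storage with the updated weight

# Rows 0-9 are scaled by 0.1; rows 10 onward are untouched.
print(torch.allclose(after[:10], before[:10] * 0.1))  # True
print(torch.equal(after[10:], before[10:]))  # True
```

Since `B[0]` and `A.node_embed.weight` are the same object, passing `list(A.parameters())` or indexing `A.node_embed.weight` directly has the same effect.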