I am trying to access the gradient of the loss in DistilBERT with respect to each attention weight in the first layer. I can access the computed gradient of the output projection's weight matrix with the following code (when requires_grad=True):
loss.backward()
for name, param in model.named_parameters():
    if name == 'transformer.layer.0.attention.out_lin.weight':
        print(param.grad)  # shape is [768, 768]
where model is the loaded DistilBERT model.
My question is: how do I get the gradient of the loss with respect to the attention paid to [SEP], [CLS], or other tokens? I need it to reproduce the figure "Gradient-based feature importance estimates for attention to [SEP]" from the following link:
https://medium.com/analytics-vidhya/explainability-of-bert-through-attention-7dbbab8a7062
A similar question for the same purpose has been asked here, but it does not address my issue: BERT token importance measuring issue. Grad is none
By default, gradients are retained only for parameters, basically to save memory. If you need the gradient of an inner node of the computation graph, you need a reference to that tensor before calling backward() and must register a hook on it that will be executed during the backward pass.
A minimal example, adapted from the PyTorch forum:
import torch

yGrad = torch.zeros(1, 1)

def extract(grad):
    # Called during the backward pass with the gradient flowing into yy.
    global yGrad
    yGrad = grad

xx = torch.randn(1, 1, requires_grad=True)  # Variable is deprecated; tensors carry requires_grad directly
yy = 3 * xx
zz = yy ** 2
yy.register_hook(extract)

# Run the backprop:
print(yGrad)   # still zeros: the hook has not fired yet
zz.backward()
print(yGrad)   # now holds the correct dz/dy = 2 * yy
In this case, the gradient is stored in a global variable, where it persists after PyTorch discards it from the graph itself.
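Applied to your case, here is a minimal sketch of the same idea. It assumes the HuggingFace transformers library, a DistilBertForSequenceClassification checkpoint, and that the tensors returned with output_attentions=True are the attention probabilities used inside the graph; the input sentence and the dummy label are placeholders.

import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
model.eval()

inputs = tokenizer("Hello world", return_tensors="pt")
outputs = model(**inputs, labels=torch.tensor([0]), output_attentions=True)

# outputs.attentions[0] is the first layer's attention tensor with shape
# [batch, heads, seq_len, seq_len]. It is an inner node of the graph, so we
# register a hook on it to capture its gradient during backward().
captured = {}
outputs.attentions[0].register_hook(lambda grad: captured.update(layer0=grad))

outputs.loss.backward()

# Gradient of the loss w.r.t. the attention every token/head pays to [SEP].
sep_pos = inputs["input_ids"][0].tolist().index(tokenizer.sep_token_id)
grad_to_sep = captured["layer0"][0, :, :, sep_pos]  # [heads, seq_len]
print(grad_to_sep.shape)

From grad_to_sep (and the corresponding slice of outputs.attentions[0]) you can then aggregate over heads and positions however the figure in the linked article requires.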