Tags: pytorch, transformer-model, attention-model, huggingface-transformers, bert-language-model

Gradient of the loss of DistilBERT for measuring token importance


I am trying to access the gradient of the loss in DistilBERT with respect to each attention weight in the first layer. I can access the computed gradient of the output projection weight matrix via the following code (when requires_grad=True):

loss.backward()
for name, param in model.named_parameters():
    if name == 'transformer.layer.0.attention.out_lin.weight':
        print(param.grad)  # shape is [768, 768]

where model is the loaded DistilBERT model. My question is: how do I get the gradient of the loss with respect to the attention paid to [SEP], [CLS], or other tokens? I need it to reproduce the figure "Gradient-based feature importance estimates for attention to [SEP]" from the following post: https://medium.com/analytics-vidhya/explainability-of-bert-through-attention-7dbbab8a7062

A similar question has been asked for the same purpose, but it does not address my issue: BERT token importance measuring issue. Grad is none


Solution

  • By default, gradients are retained only for leaf tensors such as model parameters, mainly to save memory. If you need the gradient of an inner node of the computation graph, you need a reference to that tensor before calling backward() and you have to register a hook on it that will be executed during the backward pass.

    A minimal example from the PyTorch forum:

    import torch

    yGrad = torch.zeros(1, 1)

    def extract(grad):
        # The hook receives d(zz)/d(yy) during the backward pass
        global yGrad
        yGrad = grad

    xx = torch.randn(1, 1, requires_grad=True)  # Variable is deprecated; tensors take requires_grad directly
    yy = 3 * xx
    zz = yy ** 2

    yy.register_hook(extract)

    #### Run the backprop:
    print(yGrad)  # Still the initial zeros tensor
    zz.backward()
    print(yGrad)  # Now holds the correct dz/dy = 2 * yy
    

    In this case, the gradient is stored in a global variable, where it persists even after PyTorch discards it from the graph itself. The same trick can be applied to DistilBERT's attention tensors, as sketched below.
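
    Applied to the original question, here is a minimal sketch (not part of the original answer) of how the hook mechanism can give the gradient of the loss with respect to DistilBERT's attention probabilities, and from there the columns corresponding to [SEP]. It assumes a recent Hugging Face transformers version in which passing output_attentions=True makes the model return per-layer attention tensors that are still attached to the computation graph; the model name, dummy label, and helper names (attn_grads, make_hook) are illustrative.

    import torch
    from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    model = DistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", output_attentions=True
    )
    model.eval()

    inputs = tokenizer("A short example sentence.", return_tensors="pt")
    labels = torch.tensor([0])  # dummy label, only needed so the model returns a loss

    attn_grads = {}  # layer index -> gradient of the loss w.r.t. that layer's attention

    def make_hook(layer_idx):
        def hook(grad):
            # grad has shape [batch, heads, seq_len, seq_len]
            attn_grads[layer_idx] = grad
        return hook

    outputs = model(**inputs, labels=labels)
    # outputs.attentions is a tuple with one attention tensor per layer,
    # each of shape [batch, heads, seq_len, seq_len]
    for i, attn in enumerate(outputs.attentions):
        attn.register_hook(make_hook(i))

    outputs.loss.backward()

    # Gradient of the loss w.r.t. the attention paid to [SEP] in the first layer
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    sep_idx = tokens.index("[SEP]")
    print(attn_grads[0][0, :, :, sep_idx].shape)  # [heads, seq_len]

    Aggregating the magnitude of attn_grads[layer][..., sep_idx] over heads and query positions then gives a per-layer importance score for attention to [SEP], which is roughly the kind of quantity plotted in the linked post.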