
How to find the words/tokens (embeddings) most responsible for the predicted label of a text classification model in PyTorch


Let us suppose I have a model like:

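import torch.nn as nn
from transformers import AutoModel

class BERT_Subject_Classifier(nn.Module):

    def __init__(self, out_classes, hidden1=128, hidden2=32, dropout_val=0.2, unfreeze_n=None):
        super().__init__()

        self.hidden1 = hidden1
        self.hidden2 = hidden2
        self.dropout_val = dropout_val
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        self.out_classes = out_classes
        self.unfreeze_n = unfreeze_n  # make only the last n encoder layers trainable (None = train all)

        if unfreeze_n is not None:
            # Freeze all of BERT, then re-enable gradients for the last n encoder layers
            for p in self.bert.parameters():
                p.requires_grad = False
            n_layers = len(self.bert.encoder.layer)
            for layer in self.bert.encoder.layer[n_layers - unfreeze_n:]:
                for p in layer.parameters():
                    p.requires_grad = True

        self.dropout = nn.Dropout(self.dropout_val)
        self.relu = nn.ReLU()
        hidden_size = self.bert.config.hidden_size  # 768 for bert-base
        self.fc1 = nn.Linear(hidden_size, self.hidden1)
        self.fc2 = nn.Linear(self.hidden1, self.hidden2)
        self.fc3 = nn.Linear(self.hidden2, self.out_classes)

    def forward(self, sent_id, mask):
        # Recent transformers versions return a ModelOutput object; tuple-unpacking it
        # (`_, cls_hs = self.bert(...)`) yields key names, not tensors.
        cls_hs = self.bert(sent_id, attention_mask=mask).pooler_output  # pooled [CLS] vector
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.relu(x)  # non-linearity between fc2 and fc3
        x = self.dropout(x)
        return self.fc3(x)  # raw logits, one per class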

I train my model, and for a new data point x = ['My Name is Slim Shady'] it predicts label 3.

My question is: how can I check which words in the sentence were responsible for the classification? It could be any combination of words. Is there a library or method for this? In the paper Show, Attend and Tell and its TensorFlow implementation, you can see which regions of the image the model pays attention to. How can I do the same for text?
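BERT can return its own attention weights, which are the closest text analogue to the attention maps in Show, Attend and Tell. The minimal sketch below assumes a recent transformers release; to run offline it builds a tiny, randomly initialised BERT from a hypothetical config, and the token ids are hypothetical too. On the real model you would call model.bert(...) with the tokenizer's output instead. Note that attention shows where the [CLS] position attends, which is not necessarily the same as which tokens drove the predicted label.

```python
import torch
from transformers import BertConfig, BertModel

# Hypothetical tiny config so the sketch runs without downloading weights.
# attn_implementation="eager" asks newer transformers releases to use the
# attention code path that returns attention weights.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64,
                    attn_implementation="eager")
bert = BertModel(config)
bert.eval()  # disables attention dropout

# Hypothetical ids standing in for: [CLS] my name is slim shady [SEP]
input_ids = torch.tensor([[2, 10, 11, 12, 13, 14, 3]])
mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = bert(input_ids, attention_mask=mask, output_attentions=True)

# out.attentions: one tensor per layer, each (batch, heads, seq_len, seq_len)
last_layer = out.attentions[-1]
# Attention paid by the [CLS] position (row 0) to every token, averaged over heads
cls_attention = last_layer[0, :, 0, :].mean(dim=0)
print(cls_attention)
```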


Solution

  • Absolutely. One way to show which words have the greatest impact on the prediction is to use an attribution method such as Integrated Gradients. For PyTorch, one package you can use is Captum. This page has a good example: https://captum.ai/tutorials/IMDB_TorchText_Interpret

    For TensorFlow, you can use Alibi, an explainability library from Seldon. This page has a good example: https://docs.seldon.io/projects/alibi/en/stable/examples/integrated_gradients_imdb.html
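The idea behind Integrated Gradients can be sketched in plain PyTorch: interpolate from a baseline to the input embeddings, average the gradients of the target logit along that path, and multiply by (input minus baseline). This is a minimal sketch under stated assumptions, not Captum's implementation: ToyTextClassifier and its forward_from_embeddings helper are hypothetical stand-ins for a real classifier, the all-zero embedding baseline is an assumption (a [PAD] or [MASK] embedding is another common choice), and the token ids are made up. Summing over the embedding dimension gives one score per token.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyTextClassifier(nn.Module):
    """Hypothetical classifier: embeddings -> mean pool -> small MLP."""

    def __init__(self, vocab_size=20, emb_dim=8, n_classes=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.head = nn.Sequential(nn.Linear(emb_dim, 16), nn.Tanh(),
                                  nn.Linear(16, n_classes))

    def forward_from_embeddings(self, emb):
        # emb: (batch, seq_len, emb_dim) -> logits: (batch, n_classes)
        return self.head(emb.mean(dim=1))

    def forward(self, token_ids):
        return self.forward_from_embeddings(self.embedding(token_ids))


def integrated_gradients(model, token_ids, target, steps=256):
    """Midpoint Riemann approximation of Integrated Gradients in embedding space."""
    model.eval()
    with torch.no_grad():
        x = model.embedding(token_ids)                        # (1, seq_len, emb_dim)
    baseline = torch.zeros_like(x)                            # assumed all-zero baseline
    alphas = (torch.arange(steps, dtype=x.dtype) + 0.5) / steps
    # All interpolation points at once: (steps, seq_len, emb_dim)
    path = baseline + alphas.view(-1, 1, 1) * (x - baseline)
    path.requires_grad_(True)
    logits = model.forward_from_embeddings(path)
    # Each path point's logit depends only on that point, so one backward
    # pass on the sum gives the gradient at every point.
    grads = torch.autograd.grad(logits[:, target].sum(), path)[0]
    avg_grads = grads.mean(dim=0, keepdim=True)               # (1, seq_len, emb_dim)
    attributions = (x - baseline) * avg_grads
    return attributions.sum(dim=-1).squeeze(0)                # one score per token


model = ToyTextClassifier()
tokens = ["my", "name", "is", "slim", "shady"]
token_ids = torch.tensor([[1, 2, 3, 4, 5]])  # hypothetical vocabulary ids

with torch.no_grad():
    pred = model(token_ids).argmax(dim=-1).item()

scores = integrated_gradients(model, token_ids, target=pred)
for tok, s in zip(tokens, scores.tolist()):
    print(f"{tok:>6s} {s:+.4f}")

# Completeness check: attributions should sum to F(input) - F(baseline)
with torch.no_grad():
    emb = model.embedding(token_ids)
    gap = (model(token_ids)[0, pred]
           - model.forward_from_embeddings(torch.zeros_like(emb))[0, pred]).item()
print(f"sum of attributions {scores.sum().item():+.4f} vs logit gap {gap:+.4f}")
```

For BERT_Subject_Classifier the same computation would happen on the output of model.bert.embeddings, with the attention mask held fixed and the model in eval() mode so dropout does not add noise. Captum's LayerIntegratedGradients, built from the model's forward function and model.bert.embeddings and called with additional_forward_args=(mask,), does this for you; its return_convergence_delta=True option reports how far the attributions are from the completeness property checked above.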