huggingface-transformers  bert-language-model  attention-model  self-attention  multihead-attention

How to read a BERT attention weight matrix?


I extracted the attention score/weight matrix from the last layer and the last attention head of my BERT model, but I am not sure how to read it. The matrix is shown below. I looked for more information in the literature, without success. Since the matrix is not symmetric and each row sums to 1, I am confused. Any insights? Thanks a lot!

[Image: the attention weight matrix]

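  from transformers import BertModel, BertTokenizer

  tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert')
  # Assumption: the model is loaded like this; the original post does not show it.
  model = BertModel.from_pretrained('Rostlab/prot_bert', output_attentions=True)
  inputs = tokenizer(input_text, return_tensors='pt')  # input_text is defined elsewhere
  outputs = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
  attention = outputs.attentions  # tuple of length 30, one tensor per model layer
  layer_attention = attention[-1]  # last layer: (batch, heads, seq_len, seq_len)
  head_attention = layer_attention[0, -1]  # first batch item, last head: (seq_len, seq_len)
  #... code to read it as a matrix with token labels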

Solution

  • The attention matrix is asymmetric because query and key matrices differ.

    At its core, leaving normalization constants and the multi-head trick aside, dot-product self-attention is computed as follows:

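    1. Compute the key-query affinities e_ij, where T is the sequence length and q_i and k_j are the query and key vectors:

       $$e_{ij} = q_i^\top k_j, \qquad i, j = 1, \dots, T$$

    2. Compute the attention weights alpha_ij from the affinities:

       $$\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{j'=1}^{T} \exp(e_{ij'})}$$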

    The affinities are normalized by summing over all keys for a given query. In other words, the denominator sums the exponentiated affinities along a row, so the weights in each row sum to 1.

    Read the attention matrix as follows: row tokens (queries) attend to column tokens (keys), and each weight is a probabilistic measure of where a query directs its attention among the keys (i.e. which key, and so which token of the sentence, each query token mainly focuses on). This interaction is unidirectional: think of each query as looking for information somewhere in the keys, with the opposite interaction being irrelevant. That is why entry (i, j) generally differs from entry (j, i). I found the interpretation of the attention matrix as a directed graph within this blogpost very effective.

    Finally, I'd also suggest the first BertViz Medium post, which distinguishes different attention patterns. According to it, your example falls in the case where attention is mostly directed to the delimiter token [CLS].
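
To make the two steps and the row-wise reading concrete, here is a minimal sketch in plain Python. It is not how BERT computes attention internally: the query and key vectors are random stand-ins, the token labels and the vector size are made up for illustration, and scaling and multiple heads are left out, as in the explanation above. It only shows that each row sums to 1, that the matrix is generally not symmetric, and how to read off which key each query weights most.

```python
import math
import random

def attention_weights(queries, keys):
    """Return alpha, where row i is the softmax over j of e_ij = q_i . k_j."""
    alpha = []
    for q in queries:
        # Step 1: affinities of this query with every key (one matrix row).
        affinities = [sum(a * b for a, b in zip(q, k)) for k in keys]
        # Step 2: softmax over the row. Subtracting the row maximum only
        # improves numerical stability; it does not change the result.
        row_max = max(affinities)
        exps = [math.exp(e - row_max) for e in affinities]
        total = sum(exps)
        alpha.append([x / total for x in exps])
    return alpha

random.seed(0)
tokens = ["[CLS]", "M", "K", "[SEP]"]  # hypothetical tokens, for labelling only
dim = 3  # toy vector size (assumption)
# Random stand-ins for the query and key vectors a real model would compute.
queries = [[random.gauss(0, 1) for _ in range(dim)] for _ in tokens]
keys = [[random.gauss(0, 1) for _ in range(dim)] for _ in tokens]

alpha = attention_weights(queries, keys)
n = len(tokens)
row_sums = [sum(row) for row in alpha]
is_symmetric = all(
    abs(alpha[i][j] - alpha[j][i]) < 1e-9 for i in range(n) for j in range(n)
)
print("row sums:", [round(s, 6) for s in row_sums])
print("symmetric:", is_symmetric)

# Read the matrix row by row: each query token and the key token it weights most.
strongest = []
for i, row in enumerate(alpha):
    j = max(range(n), key=lambda idx: row[idx])
    strongest.append((tokens[i], tokens[j], row[j]))
    print(f"{tokens[i]:>5} attends most to {tokens[j]} (weight {row[j]:.2f})")
```

With a real model, the same row-by-row loop could be applied to the extracted head matrix, using the tokenizer's tokens as row and column labels.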