Tags: keras, deep-learning, nlp, recurrent-neural-network, attention-model

How to visualize attention weights?


Using this implementation, I have added attention to my RNN (which classifies input sequences into two classes) as follows.

import keras
import keras.backend as K
from keras.models import Model
from keras.layers import (Input, Embedding, Dense, GRU, TimeDistributed,
                          Flatten, Activation, RepeatVector, Permute, Lambda)

visible = Input(shape=(250,))
embed = Embedding(vocab_size, 100)(visible)  # vocab_size: size of the character vocabulary
activations = GRU(250, return_sequences=True)(embed)

# attention: score each timestep, softmax over time, re-weight the GRU activations
attention = TimeDistributed(Dense(1, activation='tanh'))(activations)
attention = Flatten()(attention)
attention = Activation('softmax')(attention)
attention = RepeatVector(250)(attention)
attention = Permute([2, 1])(attention)

sent_representation = keras.layers.multiply([activations, attention])
sent_representation = Lambda(lambda xin: K.sum(xin, axis=1))(sent_representation)
predictions = Dense(1, activation='sigmoid')(sent_representation)

model = Model(inputs=visible, outputs=predictions)

I have trained the model and saved the weights to the file weights.best.hdf5.

I am dealing with a binary classification problem, and the input to my model is one-hot vectors (character-based).

How can I visualize the attention weights for a specific test case with the current implementation?


Solution

  • Visualizing attention is not complicated, but it requires a few tricks. While constructing the model, you need to give your attention layer a name.

    (...)
    attention = keras.layers.Activation('softmax', name='attention_vec')(attention)
    (...)
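
    Applied to the model in the question, that means changing only the softmax line of the attention block (a sketch; the name argument is the only addition):

    attention = TimeDistributed(Dense(1, activation='tanh'))(activations)
    attention = Flatten()(attention)
    attention = Activation('softmax', name='attention_vec')(attention)  # named so it can be looked up later
    attention = RepeatVector(250)(attention)
    attention = Permute([2, 1])(attention)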
    

    After loading the saved model, build a second model that also returns the attention layer's output at prediction time.

    model = load_model("./saved_model.h5")
    model.summary()
    model = Model(inputs=model.input,
                  outputs=[model.output, model.get_layer('attention_vec').output])
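
    Note that the question saves only the weights (weights.best.hdf5) rather than a full model, so load_model will not work on that file. In that case, rebuild the architecture first and load the weights into it; a minimal sketch, where build_model() is a hypothetical helper that recreates the architecture from the question:

    model = build_model()  # hypothetical helper: rebuilds the architecture shown in the question
    model.load_weights("weights.best.hdf5")
    model = Model(inputs=model.input,
                  outputs=[model.output, model.get_layer('attention_vec').output])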
    

    Now a single call to predict returns both the model output and the attention vector.

    outputs = model.predict(encoded_input_text)
    model_outputs = outputs[0]
    attention_outputs = outputs[1]
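
    With the architecture from the question (inputs padded to length 250), the shapes should be as follows; a quick sanity check, assuming encoded_input_text is a batch of encoded sequences:

    print(model_outputs.shape)      # (batch_size, 1): sigmoid class probability per sample
    print(attention_outputs.shape)  # (batch_size, 250): one softmax weight per timestep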
    

    There are many ways to visualize an attention vector. Since the attention output comes from a softmax, the values lie between 0 and 1, so you can map them to RGB color codes. If you are working in a Jupyter notebook, the following snippet helps you understand the concept and visualize it:

    import pandas as pd

    class CharVal(object):
        """Pairs a character with its attention value for display."""
        def __init__(self, char, val):
            self.char = char
            self.val = val

        def __str__(self):
            return self.char

    def rgb_to_hex(rgb):
        return '#%02x%02x%02x' % rgb

    def color_charvals(s):
        # map an attention value in [0, 1] to a shade of red (higher attention = deeper red)
        r = 255 - int(s.val * 255)
        color = rgb_to_hex((255, r, r))
        return 'background-color: %s' % color

    # if you are predicting in batches, the outputs will be batched too;
    # take the attention weights that correspond to the input characters
    an_attention_output = attention_outputs[0][-len(encoded_input_text):]

    # assuming the text was tokenized before prediction,
    # match each character with its attention weight
    char_vals = [CharVal(c, v) for c, v in zip(tokenized_text, an_attention_output)]
    char_df = pd.DataFrame(char_vals).transpose()
    # apply the background coloring to each cell
    char_df = char_df.style.applymap(color_charvals)
    char_df
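
    Outside a notebook, a plain bar chart of the weights works as well. A minimal sketch using matplotlib, reusing tokenized_text and an_attention_output from the snippet above:

    import matplotlib.pyplot as plt

    plt.figure(figsize=(12, 2))
    plt.bar(range(len(an_attention_output)), an_attention_output)
    plt.xticks(range(len(tokenized_text)), tokenized_text)
    plt.xlabel('input character')
    plt.ylabel('attention weight')
    plt.tight_layout()
    plt.show()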
    

    To summarize: get the attention outputs from the model, match them with the input characters, convert the values to RGB or hex colors, and visualize them. I hope this is clear.