
How to get words from output of XLNet using Transformers library


I am using Hugging Face's Transformers library to work with different NLP models. The following code does masking with XLNet. It outputs a tensor of numbers. How do I convert the output back to words?

import torch
from transformers import XLNetTokenizer, XLNetLMHeadModel

tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
model = XLNetLMHeadModel.from_pretrained('xlnet-base-cased')

# We show how to set up inputs to predict a next token using a bi-directional context.
input_ids = torch.tensor(tokenizer.encode("I went to <mask> York and saw the <mask> <mask> building.")).unsqueeze(0)  # We will predict the masked token
print(input_ids)

perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float)
perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token

target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float)  # Shape [1, 1, seq_length] => let's predict one token
target_mapping[0, 0, -1] = 1.0  # Our first (and only) prediction will be the last token of the sequence (the masked token)

outputs = model(input_ids, perm_mask=perm_mask, target_mapping=target_mapping)
next_token_logits = outputs[0]  # Output has shape [target_mapping.size(0), target_mapping.size(1), config.vocab_size]

The current output I get is:

tensor([[[ -5.1466, -17.3758, -17.3392, ..., -12.2839, -12.6421, -12.4505]]], grad_fn=<AddBackward0>)


Solution

  • The output you have is a tensor of shape 1 by 1 by vocabulary size. The nth number in this tensor is the logit (an unnormalized log-probability score) for the nth vocabulary item. So, if you want the word that the model predicts as most likely in the final position (the position you specified with target_mapping), all you need to do is find the vocabulary item with the largest logit.

    Just add the following to the code you have:

    predicted_index = torch.argmax(next_token_logits[0][0]).item()
    predicted_token = tokenizer.convert_ids_to_tokens(predicted_index)
    

    So predicted_token is the token the model predicts as most likely in that position.
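
To look at more than the single best candidate, the logits can be ranked and turned into probabilities with a softmax. The following is only a minimal sketch in plain Python: toy_vocab and toy_logits are made-up stand-ins for the tokenizer's vocabulary and for next_token_logits[0][0].tolist(), and top_k_tokens is a hypothetical helper, not part of the Transformers library.

```python
import math

def top_k_tokens(logits, vocab, k=3):
    """Return the k highest-scoring (token, probability) pairs."""
    # Softmax, subtracting the maximum logit for numerical stability.
    highest = max(logits)
    exps = [math.exp(x - highest) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank vocabulary indices by logit, highest first.
    ranked = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return [(vocab[i], probs[i]) for i in ranked[:k]]

# Hypothetical toy data, NOT real XLNet output.
toy_vocab = ["▁New", "▁Old", "▁the", "▁York", "."]
toy_logits = [4.2, 1.3, -0.5, 0.7, -3.0]

best = top_k_tokens(toy_logits, toy_vocab, k=2)
print(best)  # the first entry is the argmax, here '▁New'
```

With real model output, the first entry of the ranked list should be the same token that torch.argmax followed by convert_ids_to_tokens returns.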


    Note that, by default, XLNetTokenizer.encode() adds the special tokens '<sep>' and '<cls>' to the end of a string of tokens when it encodes it. The code you have given masks and predicts the final token, which, after running through tokenizer.encode(), is the special token '<cls>'. That is probably not what you want.

    That is, when you run

    tokenizer.encode("I went to <mask> York and saw the <mask> <mask> building.")

    the result is a list of token ids:

    [35, 388, 22, 6, 313, 21, 685, 18, 6, 6, 540, 9, 4, 3]

    If you convert these back to tokens (by calling tokenizer.convert_ids_to_tokens() on the above id list), you will see that two extra tokens have been added at the end:

    ['▁I', '▁went', '▁to', '<mask>', '▁York', '▁and', '▁saw', '▁the', '<mask>', '<mask>', '▁building', '.', '<sep>', '<cls>']

    So, if the word you mean to predict is 'building', you should use perm_mask[:, :, -4] = 1.0 and target_mapping[0, 0, -4] = 1.0.
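
Rather than counting positions by hand, the index can be looked up from the token list. The following is a small sketch under the assumption that the tokens are exactly those shown above; positions_of is a hypothetical helper introduced here for illustration, not a Transformers function.

```python
# Token list copied from the tokenizer output shown above.
tokens = ['▁I', '▁went', '▁to', '<mask>', '▁York', '▁and', '▁saw', '▁the',
          '<mask>', '<mask>', '▁building', '.', '<sep>', '<cls>']

def positions_of(tokens, target):
    """Return every index at which `target` occurs in `tokens`."""
    return [i for i, tok in enumerate(tokens) if tok == target]

# Positions of the '<mask>' placeholders in the encoded sequence.
mask_positions = positions_of(tokens, '<mask>')
print(mask_positions)  # [3, 8, 9]

# Negative offset of '▁building', counted from the end of the sequence.
building_offset = positions_of(tokens, '▁building')[0] - len(tokens)
print(building_offset)  # -4
```

Each index found this way can be used in place of the hard-coded -4 when filling perm_mask and target_mapping. To predict several positions at once, target_mapping would need one row per predicted position, i.e. shape [1, number of predictions, seq_length].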