python nlp nltk spacy bert-language-model

How to generate a list of tokens that are most likely to occupy the place of a missing token in a given sentence?


I've found this StackOverflow answer; however, it only generates a single possible word, not a list of words that fit the sentence. I tried printing out every variable to see whether the answer's code might have generated all the possible words, but had no luck.

For example,

>>> sentence = 'Cristiano Ronaldo dos Santos Aveiro GOIH ComM is a Portuguese professional [].' # [] marks the missing word
>>> generate(sentence)
['soccer', 'basketball', 'tennis', 'rugby']

Solution

  • You can do essentially the same as in that answer, but instead of keeping only the best-fitting token, take, for example, the five best-fitting tokens:

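    import torch
    # This code targets the legacy pytorch_pretrained_bert package, whose
    # BertForMaskedLM returns the prediction scores directly when no labels are passed.
    from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM

    # Assumption: bert-base-uncased; other BERT masked-LM checkpoints work the same way.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    model.eval()  # disable dropout so predictions are deterministic

    def fill_the_gaps(text, k=5):
        text = '[CLS] ' + text + ' [SEP]'
        tokenized_text = tokenizer.tokenize(text)
        indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
        segments_ids = [0] * len(tokenized_text)  # single sentence: every token is in segment 0
        tokens_tensor = torch.tensor([indexed_tokens])
        segments_tensors = torch.tensor([segments_ids])
        with torch.no_grad():
            # prediction scores, shape [1, seq_len, vocab_size]
            predictions = model(tokens_tensor, segments_tensors)
        results = []
        for i, t in enumerate(tokenized_text):
            if t == '[MASK]':
                # instead of argmax, use argsort to rank all tokens by how well they fit
                predicted_index = torch.argsort(predictions[0, i], descending=True)
                # take the k best-fitting tokens and add them to the list
                tokens = tokenizer.convert_ids_to_tokens(predicted_index[:k].tolist())
                results.append(tokens)
        return results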
    

    For your sentence (with [] replaced by [MASK]), this returns: [['footballer', 'golfer', 'football', 'cyclist', 'boxer']]
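The ranking step itself does not depend on BERT. Below is a minimal, framework-free sketch of it, assuming you already have one row of prediction scores (the logits at a [MASK] position) as a plain list plus a vocabulary list with the same indexing. `top_k_tokens` is a hypothetical helper name, and the scores used in the check are made up for illustration. It also turns the raw scores into softmax probabilities, which is handy if you want to show how likely each candidate is. In PyTorch the same selection can be written with `torch.topk(predictions[0, i], k)` and `torch.softmax`.

```python
import heapq
import math
from typing import List, Sequence, Tuple


def top_k_tokens(scores: Sequence[float], vocab: Sequence[str], k: int = 5) -> List[Tuple[str, float]]:
    """Return the k highest-scoring tokens with their softmax probabilities.

    `scores` stands in for one row of a masked-LM's prediction scores
    (hypothetical input); `vocab[i]` is the token for index i.
    """
    if len(scores) != len(vocab):
        raise ValueError("scores and vocab must have the same length")
    # Softmax over the whole vocabulary, subtracting the max for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    # heapq.nlargest picks the k best indices without sorting the full vocabulary.
    best = heapq.nlargest(k, range(len(scores)), key=lambda i: scores[i])
    return [(vocab[i], exps[i] / total) for i in best]
```

For example, with the made-up scores [4.0, 2.5, -1.0, 2.0, 0.5] over ['footballer', 'golfer', 'banana', 'cyclist', 'the'] and k=3, it returns 'footballer', 'golfer' and 'cyclist', in that order. `heapq.nlargest` costs O(n log k) instead of the O(n log n) of a full argsort, though for a BERT-sized vocabulary either is fast.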