Do I Need to Fine-Tune a BERT Model to Predict Missing Words?


I know that BERT can predict a missing word in a sentence so that the result is syntactically correct and semantically coherent. Below is some sample code:

import torch
from pytorch_pretrained_bert import BertTokenizer, BertForMaskedLM

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()  # turn off dropout for inference

def fill_the_gaps(text):
    # BERT expects [CLS] at the start and [SEP] at the end of the input
    text = '[CLS] ' + text + ' [SEP]'
    tokenized_text = tokenizer.tokenize(text)
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    # Single-sentence input: every token belongs to segment 0
    segments_ids = [0] * len(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    with torch.no_grad():
        predictions = model(tokens_tensor, segments_tensors)
    results = []
    for i, t in enumerate(tokenized_text):
        if t == '[MASK]':
            # Take the highest-scoring vocabulary entry at the masked position
            predicted_index = torch.argmax(predictions[0, i]).item()
            predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
            results.append(predicted_token)
    return results

print(fill_the_gaps(text='I bought an [MASK] because its rainy .'))
print(fill_the_gaps(text='Im sad because you are [MASK] .'))
print(fill_the_gaps(text='Im worried because you are [MASK] .'))
print(fill_the_gaps(text='Im [MASK] because you are [MASK] .'))

Can someone explain whether I need to fine-tune a BERT model to predict missing words, or whether I can just use the pre-trained BERT model? Thanks.


Solution

  • BERT is a masked language model, meaning it was pre-trained on exactly this task: predicting tokens that have been masked out of a sentence. That is why it can fill in missing words out of the box, so in that sense no fine-tuning is needed.

    However, if the text you will see at runtime differs from the text BERT was trained on, performance may be much better if you fine-tune on the type of text you expect to see.
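
    To make "trained on exactly this task" concrete, here is a minimal sketch of how masked-LM training examples are typically built, which is also what fine-tuning on your own text would feed the model. It follows the masking recipe described in the original BERT paper (select about 15% of tokens; of those, replace 80% with [MASK], 10% with a random token, and leave 10% unchanged). The function name mask_tokens, the toy vocabulary, and the example sentence are hypothetical and only for illustration; real pipelines work on token ids and normally use a library utility instead.

```python
import random

MASK_TOKEN = "[MASK]"
SPECIAL_TOKENS = {"[CLS]", "[SEP]", "[PAD]"}


def mask_tokens(tokens, vocab, mask_prob=0.15, rng=None):
    """Return (masked_tokens, labels) for one masked-LM training example.

    labels[i] holds the original token at positions the model must predict
    and None everywhere else (those positions are ignored by the loss).
    This is an illustrative helper, not part of any library.
    """
    if rng is None:
        rng = random.Random()
    masked = list(tokens)
    labels = [None] * len(tokens)
    for i, token in enumerate(tokens):
        # Never mask special tokens; skip most ordinary tokens as well
        if token in SPECIAL_TOKENS or rng.random() >= mask_prob:
            continue
        labels[i] = token
        roll = rng.random()
        if roll < 0.8:
            masked[i] = MASK_TOKEN          # 80%: replace with [MASK]
        elif roll < 0.9:
            masked[i] = rng.choice(vocab)   # 10%: replace with a random token
        # remaining 10%: keep the original token unchanged
    return masked, labels


# Toy domain sentence and vocabulary (assumptions for illustration only)
tokens = "[CLS] the patient was given 5 mg of the drug [SEP]".split()
vocab = ["the", "patient", "drug", "dose", "mg", "was", "given", "of"]

masked, labels = mask_tokens(tokens, vocab, rng=random.Random(0))
print(masked)
print(labels)
```

    During fine-tuning, the model sees masked as input and is penalized only at positions where labels is not None. Running this over text from your own domain is what adapts the pre-trained model to the kind of sentences you expect at runtime.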