I am trying to learn how to use the transformers library to predict the next word of a given sentence. My code always predicts a period (".") as the next token. Can someone help me see what I am doing wrong?
import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM
# Load the pre-trained model and tokenizer
model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForMaskedLM.from_pretrained(model_name)
# Example sentence for predicting the next word
sentence = "I want to go to the"
# Tokenize the sentence and convert the tokens to token IDs
tokens = tokenizer.tokenize(sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
# DistilBERT is a *masked* language model: it predicts the tokens at [MASK]
# positions, not "whatever comes after the last token". So add a [MASK] where
# the next word should go, giving:  [CLS] i want to go to the [MASK] [SEP]
token_ids = (
    [tokenizer.cls_token_id]
    + token_ids
    + [tokenizer.mask_token_id, tokenizer.sep_token_id]
)
mask_index = len(token_ids) - 2  # [MASK] sits just before the final [SEP]
# Create tensor input with the token IDs (a batch containing one sequence)
input_ids = torch.tensor([token_ids])
# Inference only: no gradients needed
with torch.no_grad():
    outputs = model(input_ids)
# Read the logits at the [MASK] position. Position -1 is the [SEP] token, whose
# output is not a prediction for a new word -- that is where the constant "."
# came from.
predictions = outputs.logits[0, mask_index]
# Keep the top-k most probable tokens (deterministic top-k selection, not sampling)
top_k = 5  # Number of top-k predictions to consider
probabilities = torch.softmax(predictions, dim=-1)
top_k_predictions = torch.topk(probabilities, k=top_k)
predicted_token_ids = top_k_predictions.indices.tolist()
# Convert predicted token IDs to actual words
predicted_words = tokenizer.convert_ids_to_tokens(predicted_token_ids)
# Print the predicted next words
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
@steve-landiss
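The main issue is how the model is being queried. DistilBERT is a masked language model: it is trained to predict masked (missing) tokens from the context on both sides, not to predict the token that follows a sentence. In your code, `outputs.logits[0, -1]` is the model's output at the final `[SEP]` position. That is a prediction for `[SEP]` itself rather than for a new word, and it is why you keep getting a period. To ask DistilBERT for the next word, append a `[MASK]` token (`[CLS] i want to go to the [MASK] [SEP]`) and read the logits at the `[MASK]` position. Alternatively, use a causal language model such as GPT-2, which is trained specifically to predict the next token. Beyond that, keep in mind that these models are not guaranteed to always produce meaningful results: they generate outputs from probabilities learned during training and can still produce nonsensical output. To improve quality, you can fine-tune the model on your own dataset. There are also several other ways to get better results:
1. Increasing the value of `top_k` gives you a broader range of predicted words.
2. Ensembling: instead of relying on a single language model, combine the predictions of several models.
3. Using larger or different models: consider a larger model such as BERT, or a causal model such as GPT-2.
4. Post-processing: apply post-processing techniques to refine the model's outputs, for example by filtering out unwanted tokens such as the period you mentioned.
5. Context window: adjust how much of the preceding context is fed to the model.

Below is code that demonstrates some of these adjustments so you can experiment with them: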
import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM
import string
import nltk
model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForMaskedLM.from_pretrained(model_name)
####################################################################
# DistilBERT is a *masked* language model: it fills in [MASK] tokens using
# context from both sides; it does not natively predict a "next" token. To ask
# it for the next word, put a [MASK] where that word should go:
#     [CLS] i want to go to [MASK] [SEP]
sentence = "I want to go to"
context_window = 10  # Adjust the context window size
tokens = tokenizer.tokenize(sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids = (
    [tokenizer.cls_token_id]
    + token_ids
    + [tokenizer.mask_token_id, tokenizer.sep_token_id]
)
mask_index = len(token_ids) - 2  # [MASK] sits just before the final [SEP]
input_ids = torch.tensor([token_ids])
####################################################################
with torch.no_grad():
    outputs = model(input_ids)
# Logits for the [MASK] slot. The last position (-1) is [SEP], whose output is
# not a prediction for a new word (it typically comes out as ".").
predictions = outputs.logits[0, mask_index]
####################################################################
temperature = 0.8  # Adjust the temperature value (<1 sharpens, >1 flattens the distribution)
num_samples = 5  # How many distinct tokens to draw (top_k is only defined further down)
# Sampling: scale the logits by 1/temperature, then draw tokens in proportion
# to their probability.
probabilities = torch.softmax(predictions / temperature, dim=-1)
# torch.multinomial draws without replacement by default, so the IDs are distinct.
sampled_token_ids = torch.multinomial(probabilities, num_samples=num_samples)
predicted_token_ids = sampled_token_ids.tolist()
predicted_words = tokenizer.convert_ids_to_tokens(predicted_token_ids)
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
####################################################################
# Top-k Sampling
# Restrict the draw to the `top_k` most probable tokens, then pick one of them
# at random in proportion to its (temperature-scaled) probability.
top_k = 15  # Adjust the top-k value
topk_probabilities, topk_indices = torch.topk(probabilities, k=top_k)
# torch.multinomial accepts unnormalized non-negative weights, so the top-k
# slice does not have to be renormalized before sampling from it.
sampled_token_ids = torch.multinomial(topk_probabilities, num_samples=1)
# The sample is a position inside the top-k slice; map it back to a vocabulary ID.
predicted_token_ids = topk_indices[sampled_token_ids].tolist()
predicted_words = tokenizer.convert_ids_to_tokens(predicted_token_ids)
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
####################################################################
# Beam Search
# For a single next token, a beam of width k keeps the k highest-scoring
# candidates, i.e. the top-k. (Drawing `beam_width` independent samples is
# plain repeated sampling, not beam search, and returns duplicate tokens.)
# Multi-token beam search is only meaningful when generating several tokens
# in sequence.
beam_width = 10  # Adjust the beam width
predicted_token_ids = torch.topk(probabilities, k=beam_width).indices.tolist()
predicted_words = tokenizer.convert_ids_to_tokens(predicted_token_ids)
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
####################################################################
# Promoting context preservation
# Keep only the last `context_window` word-piece tokens of the sentence (assumes
# context_window > 0), then re-wrap them as [CLS] ... [MASK] [SEP]. Slicing the
# already-wrapped ID tensor instead would drop [CLS] for long inputs, and its
# last position is [SEP], not a slot for the next word.
context_token_ids = (
    [tokenizer.cls_token_id]
    + tokenizer.convert_tokens_to_ids(tokens[-context_window:])
    + [tokenizer.mask_token_id, tokenizer.sep_token_id]
)
context_ids = torch.tensor([context_token_ids])
context_mask_index = len(context_token_ids) - 2  # [MASK] sits just before [SEP]
with torch.no_grad():
    outputs = model(context_ids)
predictions = outputs.logits[0, context_mask_index]  # Predictions for the [MASK] slot
####################################################################
# Cleaning or filtering predictions: you can filter out the special tokens that you may have in your vocabulary
# this is a typical way to narrow down the vocabulary.
# Ask the tokenizer for its special tokens instead of hard-coding a partial
# list, so [MASK], [UNK], etc. are removed as well.
special_tokens = set(tokenizer.all_special_tokens)
filtered_predictions = [
    predicted_word
    for predicted_word in tokenizer.convert_ids_to_tokens(predicted_token_ids)
    if predicted_word not in special_tokens
]
print("Filtered Predicted Words:")
for word in filtered_predictions:
    print(word)
####################################################################
# Experimenting with different models
# Here is the same idea with GPT-2, a causal (left-to-right) language model.
# It is trained to predict the *next* token, so the logits at the last input
# position really are next-token predictions and no [MASK] is needed. Other
# causal LMs can be swapped in the same way. Masked LMs such as BERT or RoBERTa
# need the [MASK] approach instead.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

nltk.download('words')  # NLTK English word list, used later for post-processing
model_name = 'gpt2'  # TODO: try to use different (causal) models here
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# Token IDs produced by the DistilBERT tokenizer are meaningless to GPT-2 (a
# different vocabulary), so re-tokenize the sentence with GPT-2's tokenizer
# and keep the last `context_window` tokens as context.
gpt2_input_ids = tokenizer(sentence, return_tensors="pt")["input_ids"][:, -context_window:]
with torch.no_grad():
    outputs = model(gpt2_input_ids)
predictions = outputs.logits[0, -1]  # Next-token logits after the last input token
####################################################################
top_k = 20  # Number of top-k predictions to consider
probabilities = torch.softmax(predictions, dim=-1)
# Take the k most probable next tokens (sampling `top_k` tokens from the full
# distribution would not be top-k and can return unlikely tokens).
predicted_token_ids = torch.topk(probabilities, k=top_k).indices.tolist()
####################################################################
# GPT-2 uses byte-level BPE: convert_ids_to_tokens returns raw pieces such as
# "Ġbeach" (Ġ encodes a leading space), so decode each ID back to plain text.
predicted_words = [tokenizer.decode([token_id]).strip() for token_id in predicted_token_ids]
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
# It is even possible to do post-processing on the outputs,
# e.g. filtering out non-English words and punctuation.
# Lower-case the NLTK word list once so membership checks are case-insensitive.
english_words = {w.lower() for w in nltk.corpus.words.words()}
punctuation = set(string.punctuation)


def normalize_token(token):
    """Return *token* as a bare lower-case word.

    Removes the "Ġ" marker that GPT-2's byte-level BPE puts on tokens preceded
    by a space, surrounding whitespace, and case, so the token can be compared
    against the English word list.
    """
    return token.replace("Ġ", "").strip().lower()


filtered_predictions = []
for word in predicted_words:
    # Normalize *before* checking; raw tokens such as "Ġbeach" or "Paris"
    # would otherwise never match the word list.
    candidate = normalize_token(word)
    # Check if the word is an English word and not punctuation. TODO: you can add more conditions here
    if candidate in english_words and candidate not in punctuation:
        filtered_predictions.append(word)
# Apply additional post-processing rules if needed
# (here: lower-case, strip whitespace and tokenizer markers).
modified_predictions = [normalize_token(word) for word in filtered_predictions]
# Print the modified predictions
print(f"Original Sentence: {sentence}")
print("Modified Predicted Words:")
for word in modified_predictions:
    print(word)