@steve-landiss
The DistilBERT model is trained to predict masked or missing words in a sentence. However, it's important to note that such models are not guaranteed to produce meaningful results: DistilBERT generates outputs based on the probabilities it learned during training, and it can still produce nonsensical predictions.
To improve quality, you can fine-tune it on a dataset of your own. Beyond that, there are a few ways to get better results:
1. Top-k: increasing the value of top_k gives you a broader range of predicted words.
2. Ensembling: instead of relying on a single language model, use an ensemble of multiple models (see the sketch at the end of this post).
3. Larger models: consider using a larger language model, like BERT or GPT-2.
4. Post-processing: apply post-processing techniques to refine the model's outputs; you can even drop unwanted predictions, like the "period" you mentioned.
5. Context window: adjust the context window size used for generating predictions.
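As a quick baseline before the manual version: the transformers library ships a fill-mask pipeline that does masked-word prediction in a few lines (a minimal sketch, using the same distilbert-base-uncased checkpoint):

from transformers import pipeline

fill_mask = pipeline('fill-mask', model='distilbert-base-uncased')
# top_k controls how many candidate fillers are returned for the [MASK] slot
for prediction in fill_mask("I want to go to [MASK].", top_k=10):
    print(prediction['token_str'], prediction['score'])

Below is the full script with some of these adjustments, so you get a deeper understanding of how to play with the knobs: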
import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM, GPT2Tokenizer, GPT2LMHeadModel
import string
import nltk
model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForMaskedLM.from_pretrained(model_name)
####################################################################
sentence = "I want to go to"
context_window = 10  # Adjust the context window size
tokens = tokenizer.tokenize(sentence)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
# Append a [MASK] slot for the missing word, then wrap with [CLS]/[SEP]
token_ids = [tokenizer.cls_token_id] + token_ids + [tokenizer.mask_token_id, tokenizer.sep_token_id]
input_ids = torch.tensor([token_ids])
####################################################################
with torch.no_grad():
    outputs = model(input_ids)
predictions = outputs.logits[0, -2]  # Predictions at the [MASK] position (just before [SEP])
####################################################################
temperature = 0.8  # Adjust the temperature value
top_k = 15  # Adjust the top-k value (defined here because the sampling below uses it)
# Sampling
probabilities = torch.softmax(predictions / temperature, dim=-1)
sampled_token_ids = torch.multinomial(probabilities, num_samples=top_k)
predicted_token_ids = sampled_token_ids.tolist()
predicted_words = tokenizer.convert_ids_to_tokens(predicted_token_ids)
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
####################################################################
# Top-k Sampling: restrict sampling to the top_k most probable tokens
topk_probabilities, topk_indices = torch.topk(probabilities, k=top_k)
sampled_position = torch.multinomial(topk_probabilities, num_samples=1)
predicted_token_ids = topk_indices[sampled_position].tolist()
predicted_words = tokenizer.convert_ids_to_tokens(predicted_token_ids)
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
####################################################################
# Drawing multiple candidates (note: this samples independently; true beam
# search would keep the highest-scoring partial sequences instead)
num_candidates = 10  # Adjust how many candidates to draw
predicted_token_ids = []
for _ in range(num_candidates):
    sampled_token_ids = torch.multinomial(probabilities, num_samples=1)
    predicted_token_ids.append(sampled_token_ids.item())
predicted_words = tokenizer.convert_ids_to_tokens(predicted_token_ids)
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
####################################################################
# Promoting context preservation
context_ids = input_ids[:, -context_window:]  # Keep only the last few tokens as context
with torch.no_grad():
    outputs = model(context_ids)
predictions = outputs.logits[0, -2]  # Predictions at the [MASK] position
####################################################################
# Cleaning or filtering predictions: filter out the special tokens in the vocabulary
# (a typical way to narrow down the candidate set)
filtered_predictions = []
for token_id in predicted_token_ids:
    predicted_word = tokenizer.convert_ids_to_tokens([token_id])[0]
    if predicted_word not in ["[CLS]", "[SEP]", "[PAD]", "[MASK]"]:
        filtered_predictions.append(predicted_word)
####################################################################
# Experimenting with different models
# GPT-2 is used as the example here, but you can also try BERT, RoBERTa, etc.
model_name = 'gpt2'  # TODO: try different models here
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# GPT-2 is a causal LM with its own vocabulary, so re-encode the sentence
# with the GPT-2 tokenizer instead of reusing the DistilBERT token ids
input_ids = tokenizer(sentence, return_tensors='pt').input_ids
with torch.no_grad():
    outputs = model(input_ids)
predictions = outputs.logits[0, -1]  # For a causal LM, the last position predicts the next token
####################################################################
top_k = 20  # Number of candidate words to sample
probabilities = torch.softmax(predictions, dim=-1)
sampled_token_ids = torch.multinomial(probabilities, num_samples=top_k)
predicted_token_ids = sampled_token_ids.tolist()
####################################################################
predicted_words = [tokenizer.decode([token_id]).strip() for token_id in predicted_token_ids]  # Decode GPT-2 BPE pieces into readable words
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
# It is even possible to do post-processing on the outputs,
# e.g. filtering out non-English words and punctuation:
nltk.download('words')  # Word list used by the English filter below
english_words = set(nltk.corpus.words.words())
punctuation = set(string.punctuation)
filtered_predictions = []
for word in predicted_words:
    # Keep the word only if it is an English word and not punctuation
    # TODO: you can add more conditions here
    if word.lower() in english_words and word not in punctuation:
        filtered_predictions.append(word)
# Apply additional post-processing rules if needed
modified_predictions = []
for word in filtered_predictions:
    # Apply specific rules to modify the word if necessary,
    # for example convert to lowercase, remove leading/trailing whitespace, etc.
    modified_word = word.lower().strip()
    modified_predictions.append(modified_word)
# Print the modified predictions
print(f"Original Sentence: {sentence}")
print("Modified Predicted Words:")
for word in modified_predictions:
    print(word)
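Finally, since ensembling is on the list above but not in the script, here is a minimal sketch of one simple form of it: averaging the mask-fill distributions of two masked LMs that share a vocabulary (the checkpoint pairing is just an example; it works because distilbert-base-uncased reuses the bert-base-uncased WordPiece vocabulary):

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Two masked LMs with the same vocabulary, so their output
# distributions can be averaged position-by-position
names = ['distilbert-base-uncased', 'bert-base-uncased']
tokenizer = AutoTokenizer.from_pretrained(names[0])
models = [AutoModelForMaskedLM.from_pretrained(n).eval() for n in names]

inputs = tokenizer("I want to go to [MASK].", return_tensors='pt')
mask_pos = (inputs.input_ids[0] == tokenizer.mask_token_id).nonzero().item()

with torch.no_grad():
    # Average the per-model probability distributions at the [MASK] position
    probs = torch.stack([
        torch.softmax(m(**inputs).logits[0, mask_pos], dim=-1) for m in models
    ]).mean(dim=0)

for token_id in probs.topk(10).indices.tolist():
    print(tokenizer.convert_ids_to_tokens(token_id), round(probs[token_id].item(), 4))

Note that the averaging only makes sense when the models share an identical vocabulary; for models with different vocabularies you would have to merge the candidates at the word level instead.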