Tags: python, pytorch, huggingface-transformers, large-language-model, gpt-2

How to use past_key_values in PyTorch for caching


Why does using the past_key_values from "Hello, my dog is" plus the input id of "cute" not produce the same output as passing "Hello, my dog is cute" in one go?

My understanding is that past_key_values holds the results of earlier attention computations, so in theory it can be used as a kind of cache, but I don't understand how to use it so that the results match.

Thanks for your reply.

from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import torch.nn.functional as F
import random

torch.manual_seed(42)
random.seed(42)


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
with torch.no_grad():
    ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")
    print("ids : ", ids)
    output = model.generate(
        ids, max_length=1, top_k=50, do_sample=True, num_return_sequences=1
    )
    print("with all ids : ", tokenizer.decode(output[0][-1], skip_special_tokens=True))


output = model(input_ids=ids)
logits = output.logits
probabilities = F.softmax(logits, dim=-1)
next_word_index = torch.multinomial(probabilities.squeeze(0), 1)
next_word = tokenizer.decode(next_word_index.tolist()[0])
print("with all ids using multinomial : ", next_word)

uncomplete_ids = ids[:, :-1]
print("uncomplete ids : ", uncomplete_ids)
output = model(input_ids=uncomplete_ids, use_cache=True)
past_key_values = output.past_key_values

last_id = ids[:, -1:]
print("last id : ", last_id)
output = model(input_ids=last_id, past_key_values=past_key_values)
logits = output.logits
probabilities = F.softmax(logits, dim=-1)
next_word_index = torch.multinomial(probabilities.squeeze(0), 1)
next_word = tokenizer.decode(next_word_index.tolist()[0])
print("using past_key_values : ", next_word)

Solution

  • Two things are going on here. When you run the model on the full sequence, output.logits contains one next-token distribution per position, so you have to sample from the last position (logits[:, -1, :]); when you pass past_key_values together with only the last token, the returned logits already correspond to that final position, so your two code paths end up sampling from different distributions. And even when both paths compute identical logits, torch.multinomial will only draw the same token if the random-number-generator state is identical at the moment of sampling (model.generate also consumes randomness differently from a manual multinomial call). Here's how you can use past_key_values for caching:
    
    
    from transformers import GPT2Tokenizer, GPT2LMHeadModel
    import torch
    import torch.nn.functional as F
    import random
    
    torch.manual_seed(42)
    random.seed(42)
    
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)
    
    # Encode the initial text
    ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")
    
    with torch.no_grad():
        # Generate a continuation with the complete context
        # (max_new_tokens, unlike max_length, is counted on top of the prompt length)
        output = model.generate(
            ids, max_new_tokens=5, top_k=50, do_sample=True, num_return_sequences=1
        )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        print("Generated text with complete context:", generated_text)
    
        # Split the input into two parts: "Hello, my dog is" and " cute"
        context_ids = ids[:, :-1]
        next_word_ids = ids[:, -1:]
    
        # Run the context through the model once and keep the cache
        output = model(input_ids=context_ids, use_cache=True)
        past_key_values = output.past_key_values
    
        # Feed only the last token together with the cached context
        output = model(input_ids=next_word_ids, past_key_values=past_key_values)
    
        # The returned logits only cover the token(s) passed in this call, so
        # logits[:, -1, :] is the distribution over the word that follows "cute"
        logits = output.logits[:, -1, :]
        probabilities = F.softmax(logits, dim=-1)
        next_word_index = torch.multinomial(probabilities, 1)
        next_word = tokenizer.decode(next_word_index[0].tolist())
        print("Generated text using past_key_values:", next_word)