I am planning to deploy a fine-tuned version of OpenOrca-Platypus2-13B. It takes around 13.5 GB of GPU memory. I tried a g4dn.12xlarge instance on AWS, which has 4 GPUs, but inference still takes around 40 seconds. I also tried it on an A100 GPU provided by Colab, with the same result.
What am I doing wrong? Do I need more computational power, or is something wrong with my code?
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
)
# Set the model to evaluation mode
model.eval()
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("Open-Orca/OpenOrca-Platypus2-13B", trust_remote_code=True)
def ask_bot(question):
    with torch.no_grad():
        # Tokenize input question
        input_ids = tokenizer.encode(question, return_tensors="pt").cuda()

        # Generate output
        output = model.module.generate(
            input_ids,
            max_length=200,
            num_return_sequences=1,
            do_sample=True,
            top_k=50,
        )

        # Decode and extract the response
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        response = generated_text.split("->:")[-1]
        return response
It does take a long time to generate an output even on powerful GPUs. My use case was a chatbot, so I figured it would be ideal to stream the output token by token as the model generates it. This reduces the perceived latency, even though the total generation time stays the same.
from threading import Thread

from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

def ask_bot(question):
    with torch.no_grad():
        # Tokenize input question
        input_ids = tokenizer.encode(question, return_tensors="pt", truncation=True).cuda()

        # The streamer collects tokens as they are generated and exposes
        # them through an iterator
        streamer = TextIteratorStreamer(
            tokenizer=tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
        )

        # Run generation in a background thread so the streamer can be
        # consumed in the foreground
        def generate_and_signal_complete():
            model.generate(
                input_ids,
                max_length=1500,
                num_return_sequences=1,
                do_sample=True,
                top_k=50,
                streamer=streamer,
            )

        t1 = Thread(target=generate_and_signal_complete)
        t1.start()

        # Yield decoded text chunks as they arrive
        for new_text in streamer:
            yield new_text
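To consume the stream, iterate over the generator and flush each chunk as soon as it arrives, for example by printing it to the console (or pushing it to your chatbot frontend). A minimal sketch, assuming the generator above is named ask_bot and the prompt string is just a placeholder for whatever format your fine-tuned model expects:

# Print streamed chunks as they arrive instead of waiting for the full output
for chunk in ask_bot("Tell me about Lake Titicaca."):
    print(chunk, end="", flush=True)
print()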