I am working on a project where I want to train/fine-tune a ChatGPT-like model on my custom data, and for that I am using the code below. I am able to get the output; however, I want to use a GPU to make it faster.
```python
from gpt_index import (SimpleDirectoryReader, GPTListIndex,
                       readers, GPTSimpleVectorIndex, LLMPredictor, PromptHelper)
from langchain import OpenAI
from llama_index import ServiceContext, GPTVectorStoreIndex
from llama_index.node_parser import SimpleNodeParser
import os
import time

os.environ["OPENAI_API_KEY"] = "key-here"
parser = SimpleNodeParser()


def construct_index(directory_path):
    max_input_size = 4096
    num_outputs = 256
    max_chunk_overlap = 20
    chunk_size_limit = 600
    llm_predictor = LLMPredictor(llm=OpenAI(temperature=0, model_name="gpt-3.5-turbo"))
    # prompt_helper = PromptHelper(max_input_size, num_outputs, max_chunk_overlap,
    #                              chunk_size_limit=chunk_size_limit)
    service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor)
    documents = SimpleDirectoryReader(directory_path).load_data()
    nodes = parser.get_nodes_from_documents(documents)
    index = GPTVectorStoreIndex.from_documents(documents, service_context=service_context)
    return index


# index = construct_index("docs")
index = GPTVectorStoreIndex.load_from_disk('./jsons/json-schema-gpt-3.5-turbo.json')
conversation_history = []
while True:
    user_input = input("You: ")
    input_text = "\n".join(conversation_history + [user_input])
    start = time.time()
    response = index.query(input_text)
    response_text = response.response
    print(f"Query took {time.time() - start:.2f} s")
    # Print the response
    print("Bot:", response_text)
    if len(conversation_history) > 10:
        # Drop the oldest entries so the query text stays bounded
        conversation_history = conversation_history[-10:]
    # Append the current input and response to the conversation history
    conversation_history.append(user_input)
    conversation_history.append(response_text)
```
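Because every call sends the whole joined history as the query, the input grows with each turn, and longer inputs typically mean more tokens for the API to process, which adds latency. A minimal sketch, assuming you only want to keep the last few exchanges, is to hold the history in a `collections.deque` with `maxlen`, which discards the oldest entries automatically. `MAX_TURNS`, `build_query` and `record_turn` are hypothetical names for illustration, and the placeholder strings stand in for real `index.query` results.

```python
from collections import deque

# Hypothetical limit: keep the last 5 exchanges (user + bot = 2 entries each)
MAX_TURNS = 5
conversation_history = deque(maxlen=2 * MAX_TURNS)


def build_query(history, user_input):
    """Join the retained history and the new input into one query string."""
    return "\n".join([*history, user_input])


def record_turn(history, user_input, response_text):
    """Append one exchange; deque(maxlen=...) drops the oldest entries automatically."""
    history.append(user_input)
    history.append(response_text)


# Example with placeholder answers instead of real index.query() calls
for i in range(8):
    query = build_query(conversation_history, f"question {i}")
    # In the real loop: response_text = index.query(query).response
    record_turn(conversation_history, f"question {i}", f"answer {i}")
```

In the original loop, `index.query(build_query(conversation_history, user_input))` would replace the manual `"\n".join(...)` and the length check.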
As can be seen, I am using the `index.query` method to obtain results. Is there a way to update my code so that the computation is faster?
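One way to see whether the time is spent on your own machine or waiting on the API is to time each stage separately. This sketch uses `time.perf_counter`, which is intended for measuring short durations (unlike `time.time`, it is monotonic); the `timed` helper and the stage labels are hypothetical, and `time.sleep` merely stands in for the network round trip of `index.query(...)`.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, timings):
    """Record the wall-clock duration of the with-block under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Stored even if the block raises, so slow failures are visible too
        timings[label] = time.perf_counter() - start


timings = {}
with timed("build_prompt", timings):
    prompt = "\n".join(["previous turn", "new question"])
with timed("simulated_api_call", timings):
    time.sleep(0.05)  # stand-in for the network round trip of index.query(...)

for label, seconds in timings.items():
    print(f"{label}: {seconds:.3f} s")
```

If the API stage dominates, local hardware (CPU or GPU) has little to speed up.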
I figured out that, since we are only making API calls to OpenAI (both the embeddings and the gpt-3.5-turbo completions are computed on OpenAI's servers), a local GPU cannot speed the process up. To benefit from a GPU, we might need to use this and call open-source embedding models running locally instead.
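For context on what runs locally: with an in-memory vector store, retrieval boils down to comparing the query embedding against the stored chunk embeddings, which is cheap; the expensive steps are computing the embeddings and generating the answer, and those are what a local model on a GPU would take over. The sketch below shows only the lookup step with NumPy on toy vectors, assuming the embeddings were already produced by whatever model you choose; `top_k_similar` is a hypothetical helper, not a llama_index API.

```python
import numpy as np


def top_k_similar(query_vec, doc_vecs, k=2):
    """Return indices of the k rows of doc_vecs most cosine-similar to query_vec."""
    q = query_vec / np.linalg.norm(query_vec)
    d = doc_vecs / np.linalg.norm(doc_vecs, axis=1, keepdims=True)
    scores = d @ q  # cosine similarity of each document to the query
    # argsort is ascending, so reverse it to get the highest scores first
    return np.argsort(scores)[::-1][:k]


# Toy 3-dimensional "embeddings"; real models produce hundreds or thousands of dimensions
doc_vecs = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.7, 0.7, 0.0],
])
query_vec = np.array([1.0, 0.1, 0.0])
print(top_k_similar(query_vec, doc_vecs))
```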