I'm having difficulties working with llama_index. I want to load a custom LLM and use it. Fortunately, their documentation has an example for exactly my need; unfortunately, it does not work!
They have these imports in their example:
from llama_index.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata
And when I run it, I get this error:
ImportError: cannot import name 'CustomLLM' from 'llama_index.llms'
My llama_index version is 0.7.1 (the latest version at the time of writing). Do you know any workaround for me to use a custom LLM in llama_index?
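For reference, this is how I checked the installed version (the package exposes a version string at the top level):

import llama_index
print(llama_index.__version__)  # prints 0.7.1 for me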
P.S. If their full code is needed, here it is:
import torch
from transformers import pipeline
from typing import Optional, List, Mapping, Any
from llama_index import (
    ServiceContext,
    SimpleDirectoryReader,
    LangchainEmbedding,
    ListIndex,
)
from llama_index.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata
# set context window size
context_window = 2048
# set number of output tokens
num_output = 256
# store the pipeline/model outside of the LLM class to avoid memory issues
model_name = "facebook/opt-iml-max-30b"
pipeline = pipeline("text-generation", model=model_name, device="cuda:0", model_kwargs={"torch_dtype":torch.bfloat16})
class OurLLM(CustomLLM):
    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window, num_output=num_output
        )

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt_length = len(prompt)
        response = pipeline(prompt, max_new_tokens=num_output)[0]["generated_text"]

        # only return newly generated tokens
        text = response[prompt_length:]
        return CompletionResponse(text=text)

    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()
# define our LLM
llm = OurLLM()
service_context = ServiceContext.from_defaults(
    llm=llm,
    context_window=context_window,
    num_output=num_output,
)
# Load your data
documents = SimpleDirectoryReader('./data').load_data()
index = ListIndex.from_documents(documents, service_context=service_context)
# Query and print response
query_engine = index.as_query_engine()
response = query_engine.query("<query_text>")
print(response)
You need to change your imports. In llama_index 0.7.1, these classes live in submodules rather than being exposed at llama_index.llms, so import them from there instead.
Change
from llama_index.llms import CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata
to this:
from llama_index.llms.custom import CustomLLM
from llama_index.llms.base import CompletionResponse, CompletionResponseGen, LLMMetadata
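To sanity-check the fix without loading a 30B model, here is a minimal sketch, assuming llama_index 0.7.1; note that EchoLLM and its echo-the-prompt behavior are my own stand-ins to verify the imports, not anything from the library or your model:

from typing import Any
from llama_index.llms.custom import CustomLLM
from llama_index.llms.base import (
    CompletionResponse,
    CompletionResponseGen,
    LLMMetadata,
)

# Hypothetical toy subclass: it just echoes the prompt back, so you can
# check that the imports resolve and the CustomLLM interface is satisfied
# without loading a real model.
class EchoLLM(CustomLLM):
    @property
    def metadata(self) -> LLMMetadata:
        return LLMMetadata(context_window=2048, num_output=256)

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        # A real implementation would call your model here.
        return CompletionResponse(text=prompt)

    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()

llm = EchoLLM()
print(llm.complete("hello").text)  # -> "hello"

If this runs, you can drop your OurLLM class (with the pipeline call) back in unchanged; only the imports needed to move.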