python nlp llama-index large-language-model

ImportError: cannot import name 'CustomLLM' from 'llama_index.llms'


I'm having trouble working with llama_index. I want to load a custom LLM and use it. Fortunately, the documentation has an example for exactly this; unfortunately, it does not work. The example uses this import:

from llama_index.llms import CustomLLM, CompletionResponse, LLMMetadata

When I run it, I get this error:

ImportError: cannot import name 'CustomLLM' from 'llama_index.llms'

My llama_index version is 0.7.1 (the latest version at the time of writing). Is there a workaround that lets me use a custom LLM in llama_index?

P.S. In case it is needed, here is their full code:

import torch
from transformers import pipeline
from typing import Optional, List, Mapping, Any

from llama_index import (
    ServiceContext, 
    SimpleDirectoryReader, 
    LangchainEmbedding, 
    ListIndex
)
from llama_index.llms import CustomLLM, CompletionResponse, LLMMetadata


# set context window size
context_window = 2048
# set number of output tokens
num_output = 256

# store the pipeline/model outside of the LLM class to avoid memory issues
model_name = "facebook/opt-iml-max-30b"
pipeline = pipeline("text-generation", model=model_name, device="cuda:0", model_kwargs={"torch_dtype":torch.bfloat16})

class OurLLM(CustomLLM):

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(
            context_window=context_window, num_output=num_output
        )

    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        prompt_length = len(prompt)
        response = pipeline(prompt, max_new_tokens=num_output)[0]["generated_text"]

        # only return newly generated tokens
        text = response[prompt_length:]
        return CompletionResponse(text=text)
    
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        raise NotImplementedError()

# define our LLM
llm = OurLLM()

service_context = ServiceContext.from_defaults(
    llm=llm, 
    context_window=context_window, 
    num_output=num_output
)

# Load your data
documents = SimpleDirectoryReader('./data').load_data()
index = ListIndex.from_documents(documents, service_context=service_context)

# Query and print response
query_engine = index.as_query_engine()
response = query_engine.query("<query_text>")
print(response)

Solution

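  • You need to change your import paths. As the error shows, llama_index.llms does not export CustomLLM in version 0.7.1, so import each name from the submodule that defines it.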

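    Change

    from llama_index.llms import CustomLLM, CompletionResponse, LLMMetadata

    to

    from llama_index.llms.custom import CustomLLM
    from llama_index.llms.base import CompletionResponse, CompletionResponseGen, LLMMetadata

    The example also annotates stream_complete with CompletionResponseGen without importing it, which raises a NameError once the ImportError is out of the way. It is added to the second import here on the assumption that llama_index.llms.base defines it in this version.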
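
    If the same script has to run against several llama_index releases, one option is to try the candidate import paths in order and keep the first that works. The helper below is a minimal sketch of that idea, not something llama_index provides: import_first is a hypothetical name, and the llama_index module paths shown in the trailing comment are just the two paths from this question.

```python
import importlib
from typing import Any, Sequence


def import_first(name: str, modules: Sequence[str]) -> Any:
    """Return attribute `name` from the first module in `modules` that provides it.

    Hypothetical helper for illustration; not part of llama_index.
    """
    errors = []
    for module_path in modules:
        try:
            # import_module raises ImportError (or its subclass
            # ModuleNotFoundError) when the module path does not exist.
            module = importlib.import_module(module_path)
            # getattr raises AttributeError when the module exists
            # but does not define or re-export the requested name.
            return getattr(module, name)
        except (ImportError, AttributeError) as exc:
            errors.append(f"{module_path}: {exc}")
    # Nothing worked: report every path that was tried.
    raise ImportError(f"cannot import {name!r}; tried: " + "; ".join(errors))


# Assumed usage for this question (needs llama_index installed):
# CustomLLM = import_first(
#     "CustomLLM", ["llama_index.llms", "llama_index.llms.custom"]
# )
```

    Pinning the llama_index version in your requirements is usually simpler; a fallback like this mainly helps when you cannot control which version is installed.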