
Haystack: save InMemoryDocumentStore and load it in retriever later to save embedding generation time


I am using an InMemoryDocumentStore and an EmbeddingRetriever for a Q/A pipeline:

from haystack.document_stores import InMemoryDocumentStore
document_store = InMemoryDocumentStore(embedding_dim=768, use_bm25=True)
document_store.write_documents(docs_processed)
     
from haystack.nodes import EmbeddingRetriever
# Raw string so the backslashes in the Windows path are not treated as escape sequences
retriever_model_path = r'downloaded_models\local\my_local_multi-qa-mpnet-base-dot-v1'
retriever = EmbeddingRetriever(document_store=document_store,
                              embedding_model=retriever_model_path,
                              use_gpu=True)

document_store.update_embeddings(retriever=retriever)

Because computing the embeddings takes a while, I want to save them and load them again later in the retriever (on the REST API side). I don't want to use Elasticsearch or FAISS. How can I achieve this with the InMemoryDocumentStore? I tried pickle, but the embeddings are not stored, and EmbeddingRetriever has no load function either.

I tried to do the following:

import pickle

with open("document_store_res.pkl", "wb") as f:
    pickle.dump(document_store.get_all_documents(), f)

And in the REST API, I try to load the document store:

def reader_retriever():
    # Load the pickled document store
    with open(os.path.join(settings.BASE_DIR, 'downloaded_models', 'document_store_res.pkl'), 'rb') as f:
        document_store_new = pickle.load(f)

        retriever_model_path = os.path.join(settings.BASE_DIR, 'downloaded_models', 'my_local_multi-qa-mpnet-base-dot-v1')

        retriever = EmbeddingRetriever(document_store=document_store_new,
                                       embedding_model=retriever_model_path,
                                       use_gpu=True)

        document_store_new.update_embeddings(retriever=retriever,
                                             batch_size=100)
        farm_reader_path = os.path.join(settings.BASE_DIR, 'downloaded_models', 'my_local_bert-large-uncased-whole-word-masking-squad2')

        reader = FARMReader(model_name_or_path=farm_reader_path,
                            use_gpu=True)

        return reader, retriever

Solution

  • InMemoryDocumentStore: features and limitations

    From the Haystack docs:

    Use the InMemoryDocumentStore, if you are just giving Haystack a quick try on a small sample and are working in a restricted environment that complicates running Elasticsearch or other databases.

    • Slow retrieval on larger datasets.
    • No Approximate Nearest Neighbours (ANN).
    • Not recommended for production.
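Without an ANN index, every query has to be compared with every stored embedding, which is why retrieval slows down as the collection grows. The following is a simplified pure-Python sketch of such exact (brute-force) search, not Haystack's actual implementation; `exact_top_k` and the toy vectors are hypothetical, and dot-product scoring is assumed because the model used here (multi-qa-mpnet-base-dot-v1) is intended for it.

```python
import heapq
from typing import Dict, List, Tuple


def dot(a: List[float], b: List[float]) -> float:
    """Dot-product similarity between two vectors of equal length."""
    return sum(x * y for x, y in zip(a, b))


def exact_top_k(query: List[float],
                embeddings: Dict[str, List[float]],
                k: int = 3) -> List[Tuple[str, float]]:
    """Exact (brute-force) search: score the query against EVERY stored
    embedding, so each query costs O(number of documents)."""
    scored = ((doc_id, dot(query, emb)) for doc_id, emb in embeddings.items())
    return heapq.nlargest(k, scored, key=lambda pair: pair[1])


# Toy 3-dimensional embeddings; multi-qa-mpnet-base-dot-v1 produces
# 768-dimensional vectors instead.
store = {
    "doc-a": [0.1, 0.9, 0.0],
    "doc-b": [0.8, 0.1, 0.1],
    "doc-c": [0.4, 0.4, 0.2],
}

results = exact_top_k([1.0, 0.0, 0.0], store, k=2)
print(results)  # [('doc-b', 0.8), ('doc-c', 0.4)]
```

ANN indexes, such as those offered by FAISS or Qdrant, avoid this full scan at the cost of returning approximate results.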

    Possible lightweight alternatives

    To overcome the limitations of InMemoryDocumentStore without using FAISS or Elasticsearch, you could also consider Qdrant, which is lightweight to run and integrates with Haystack.

    Pickling InMemoryDocumentStore

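    Given these limitations, I do not recommend this approach for production. If you still want to go this way, pickle the document store object itself, which also contains the embeddings, rather than the list returned by get_all_documents() (in Haystack 1.x that method omits the embeddings unless you pass return_embedding=True, which is why your pickle did not contain them):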

    import pickle

    with open("document_store_res.pkl", "wb") as f:
        pickle.dump(document_store, f)
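To illustrate the difference between the two pickling attempts without depending on Haystack, here is a minimal sketch that uses a plain dict as a hypothetical stand-in for the document store. The `get_all_documents` function below is a simplified model of the Haystack 1.x default (embeddings dropped unless `return_embedding=True`); check the behaviour against the Haystack version you actually run.

```python
import copy
import pickle

# Hypothetical stand-in for InMemoryDocumentStore: a plain dict that keeps
# every document together with its embedding.
store = {
    "documents": [
        {"content": "hello", "embedding": [0.1, 0.2]},
        {"content": "world", "embedding": [0.3, 0.4]},
    ]
}


def get_all_documents(store, return_embedding=False):
    """Simplified model of the Haystack 1.x default: the returned copies
    carry no embedding unless return_embedding=True."""
    docs = copy.deepcopy(store["documents"])
    if not return_embedding:
        for doc in docs:
            doc["embedding"] = None
    return docs


# Question's attempt: pickle the default get_all_documents() output.
restored_docs = pickle.loads(pickle.dumps(get_all_documents(store)))

# Answer's approach: pickle the store object itself.
restored_store = pickle.loads(pickle.dumps(store))

print([d["embedding"] for d in restored_docs])                # [None, None]
print([d["embedding"] for d in restored_store["documents"]])  # [[0.1, 0.2], [0.3, 0.4]]
```

Writing to a file with `pickle.dump`/`pickle.load` behaves the same way as the in-memory `dumps`/`loads` used here. Keep in mind that pickle files should only be loaded from trusted sources and usually need a compatible version of the library installed on the loading side.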
    

    In the REST API, you can change your method as follows:

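    import os
    import pickle

    from django.conf import settings  # assumption: settings.BASE_DIR comes from a Django project
    from haystack.nodes import EmbeddingRetriever, FARMReader

    def reader_retriever():
        # Load the pickled document store, embeddings included.
        # Pass path parts separately: a component starting with a backslash would discard BASE_DIR.
        store_path = os.path.join(settings.BASE_DIR, 'downloaded_models', 'document_store_res.pkl')
        with open(store_path, 'rb') as f:
            document_store_new = pickle.load(f)

        # Use the same model that produced the stored embeddings
        retriever_model_path = os.path.join(settings.BASE_DIR, 'downloaded_models', 'my_local_multi-qa-mpnet-base-dot-v1')
        retriever = EmbeddingRetriever(document_store=document_store_new,
                                       embedding_model=retriever_model_path,
                                       use_gpu=True)

        # Do NOT call update_embeddings(): the embeddings were already computed and pickled

        farm_reader_path = os.path.join(settings.BASE_DIR, 'downloaded_models', 'my_local_bert-large-uncased-whole-word-masking-squad2')
        reader = FARMReader(model_name_or_path=farm_reader_path,
                            use_gpu=True)

        return reader, retriever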
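If reader_retriever() is called on every request, it will still unpickle the store and load both models each time. A minimal sketch, assuming you want that work done only once per worker process, wraps the loader in `functools.lru_cache`; `load_document_store` and `LOAD_COUNT` are hypothetical names, and a small dict stands in for the real pickled store.

```python
import functools
import pickle
import tempfile
from pathlib import Path

LOAD_COUNT = 0  # only here to show how often the expensive load actually runs


@functools.lru_cache(maxsize=None)
def load_document_store(path: str):
    """Unpickle the store once per process; later calls reuse the cached object."""
    global LOAD_COUNT
    LOAD_COUNT += 1
    with open(path, "rb") as f:
        return pickle.load(f)


# Demo with a tiny stand-in object instead of a real document store.
with tempfile.TemporaryDirectory() as tmp:
    pkl_path = str(Path(tmp) / "document_store_res.pkl")
    with open(pkl_path, "wb") as f:
        pickle.dump({"doc-1": [0.1, 0.2]}, f)

    first = load_document_store(pkl_path)
    second = load_document_store(pkl_path)  # served from the cache, no disk read
    print(first is second, LOAD_COUNT)  # True 1
```

Because reader_retriever() takes no arguments, the same decorator could also be applied to it directly, so the reader and retriever models are loaded only once as well.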