
Performing vector search with sqlalchemy and pgvector


I am trying to implement hybrid search in PostgreSQL with pgvector and SQLAlchemy, ranking rows by a weighted sum of two vector distances.

Below is the table schema:

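from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, Integer, Sequence, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Project_images(Base):
    __tablename__ = "project_images"
    id = Column(Integer, Sequence("project_image_id_seq"), primary_key=True)
    image_link = Column(String(255))
    image_vector = Column(Vector(512))
    keywords = Column(String(255))
    keyword_vector = Column(Vector(768))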

And this is the function I call to perform the search:

def query_db(
        image_encoding,
        image_search_weight,
        keyword_encoding,
        keyword_search_weight,
    ):
        search_query = text(
            """
                SELECT *, 
                ((:image_encoding <=> image_vector) * :image_search_weight + (:keyword_encoding <=> keyword_vector) * :keyword_search_weight) 
                AS vector_sum
                FROM project_images
                ORDER BY vector_sum
                LIMIT 50

            """
        )

        params = {
            "image_encoding": image_encoding,
            "image_search_weight": image_search_weight,
            "keyword_encoding": keyword_encoding,
            "keyword_search_weight": keyword_search_weight,
        }

        with session_class() as session:
            result = session.execute(search_query, params)
            return result

Trying to call query_db gives me the following error:

sqlalchemy.exc.ProgrammingError: (psycopg2.errors.UndefinedFunction) operator does not exist: record <=> vector
LINE 3: ...02891638, 0.08573333, -0.011385784, -0.020549707) <=> image_...
                                                             ^
HINT: No operator matches the given name and argument types. You might need to add explicit type casts.

I have tried session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")), but I still run into the same error.


Solution

  • I don't use this library, but I was able to get your code to run with the libraries you appear to be using. It wasn't clear to me whether <=> is commutative. The Python package for the extension (pgvector-python) exposes the operator as a method, cosine_distance(), on Vector columns, and using that solved the problem for me.
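For reference, pgvector's <=> operator returns the cosine distance, i.e. one minus the cosine similarity. That expression is symmetric in its two arguments, so swapping the operands does not change the value; the error above is about the left operand's type (record instead of vector), not operand order. With q for a query encoding, v for a stored column vector and w for a weight (symbols introduced here for illustration), the score the query orders by is:

```latex
d_{\cos}(\mathbf{a}, \mathbf{b}) = 1 - \frac{\mathbf{a} \cdot \mathbf{b}}{\lVert \mathbf{a} \rVert \, \lVert \mathbf{b} \rVert} = d_{\cos}(\mathbf{b}, \mathbf{a})

\text{vector\_sum} = w_{\text{img}} \, d_{\cos}(\mathbf{q}_{\text{img}}, \mathbf{v}_{\text{img}}) + w_{\text{kw}} \, d_{\cos}(\mathbf{q}_{\text{kw}}, \mathbf{v}_{\text{kw}})
```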

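    from pgvector.sqlalchemy import Vector
    from sqlalchemy import Column, Integer, Sequence, String, create_engine, select, text
    from sqlalchemy.orm import Session, declarative_base

    # Placeholder connection URL: replace with your own database.
    engine = create_engine("postgresql+psycopg2://user:password@localhost/dbname", echo=True)
    Base = declarative_base()


    class ProjectImage(Base):
        __tablename__ = "project_images"
        id = Column(Integer, Sequence("project_image_id_seq"), primary_key=True)
        image_link = Column(String(255))
        image_vector = Column(Vector(512))
        keywords = Column(String(255))
        keyword_vector = Column(Vector(768))


    # The tables live on the declarative Base's metadata.
    Base.metadata.create_all(engine)


    def query_db(
        image_encoding,
        image_search_weight,
        keyword_encoding,
        keyword_search_weight,
    ):
        # You can put this in the select(), but I define it
        # first so you can see it. cosine_distance() renders as <=>.
        computed_col = (
            (ProjectImage.image_vector.cosine_distance(image_encoding) * image_search_weight)
            + (ProjectImage.keyword_vector.cosine_distance(keyword_encoding) * keyword_search_weight)
        ).label("vector_sum")

        search_query = (
            select(computed_col)
            # Reference the computed column by its label.
            .order_by(text("vector_sum"))
            .limit(50)
        )

        with Session(engine) as session:
            # Fetch the rows while the session is still open.
            return session.execute(search_query).all()


    # This is just dummy data I put in which appears to match what would be passed in.
    query_db([1] * 512, 1, [1] * 768, 1)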
    
    

    This is the SQL the query emits, as logged with echo=True set on the engine:

    SELECT (project_images.image_vector <=> %(image_vector_1)s) * %(param_1)s + (project_images.keyword_vector <=> %(keyword_vector_1)s) * %(param_2)s AS vector_sum 
    FROM project_images ORDER BY vector_sum 
     LIMIT %(param_3)s
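To make it concrete what that SQL computes, here is a plain-Python sketch of the same ranking over in-memory rows: each row's vector_sum is the weighted sum of two cosine distances, rows are sorted ascending (closest first), and the first `limit` are kept. `hybrid_rank` and the dict-shaped rows are hypothetical stand-ins for illustration, not part of SQLAlchemy or pgvector, and this sketch does not handle all-zero vectors.

```python
import math


def cosine_distance(a, b):
    """1 - cosine similarity, the quantity pgvector's <=> returns."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    # Raises ZeroDivisionError if either vector is all zeros.
    return 1.0 - dot / (norm_a * norm_b)


def hybrid_rank(rows, image_encoding, image_search_weight,
                keyword_encoding, keyword_search_weight, limit=50):
    """Mimic ORDER BY vector_sum LIMIT n over a list of dict rows.

    Each row is a hypothetical dict with 'image_vector' and
    'keyword_vector' keys standing in for a project_images row.
    Returns (vector_sum, row) pairs, smallest vector_sum first.
    """
    scored = []
    for row in rows:
        vector_sum = (
            cosine_distance(row["image_vector"], image_encoding) * image_search_weight
            + cosine_distance(row["keyword_vector"], keyword_encoding) * keyword_search_weight
        )
        scored.append((vector_sum, row))
    # Sort on the score only so the dicts themselves are never compared.
    scored.sort(key=lambda pair: pair[0])
    return scored[:limit]


rows = [
    {"id": 1, "image_vector": [1, 0], "keyword_vector": [1, 0]},
    {"id": 2, "image_vector": [0, 1], "keyword_vector": [1, 0]},
    {"id": 3, "image_vector": [0, 1], "keyword_vector": [0, 1]},
]
for vector_sum, row in hybrid_rank(rows, [1, 0], 1.0, [1, 0], 1.0):
    print(row["id"], vector_sum)
```

Raising a weight makes that column's distance count for more in the ordering; a weight of 0 removes it from the score entirely.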
    

    If you want both the ProjectImage object and vector_sum, you should be able to do this:

        search_query = select(
            ProjectImage, computed_col
        )
        # Then later, inside the session:
        with Session(engine) as session:
            for project_image, vector_sum in session.execute(search_query):
                print(project_image, vector_sum)
    
    

    I used the pgvector-python SQLAlchemy documentation (pgvector-python-sqlalchemy) to find the cosine_distance method and other examples.
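If you would rather keep the raw text() query from the question, the error message itself points at the cause: the left operand of <=> reached PostgreSQL as a record, a row literal like (0.1, 0.2, ...), which is how psycopg2 renders a Python tuple, rather than as a vector. The vector type already appears in the error, so the extension is installed and CREATE EXTENSION cannot help. One possible workaround, sketched here under the assumption that each encoding is a plain sequence of numbers, is to send it in pgvector's text format ('[x,y,...]') and CAST it to vector inside the SQL. CAST(... AS vector) is used instead of the ::vector shorthand because the double colon is easy to tangle with text()'s :name parameter syntax. SEARCH_SQL, to_pgvector_literal and build_search_params are hypothetical names, not library APIs.

```python
# Hypothetical helpers: the question's raw SQL, with each bound encoding
# cast to vector explicitly so <=> sees vector <=> vector.
SEARCH_SQL = """
    SELECT *,
        (CAST(:image_encoding AS vector) <=> image_vector) * :image_search_weight
        + (CAST(:keyword_encoding AS vector) <=> keyword_vector) * :keyword_search_weight
        AS vector_sum
    FROM project_images
    ORDER BY vector_sum
    LIMIT 50
"""


def to_pgvector_literal(values):
    """Format numbers in pgvector's text form, e.g. [1, 2.5] -> '[1.0,2.5]'."""
    return "[" + ",".join(str(float(v)) for v in values) + "]"


def build_search_params(image_encoding, image_search_weight,
                        keyword_encoding, keyword_search_weight):
    """Build the bind parameters for SEARCH_SQL (hypothetical helper)."""
    return {
        "image_encoding": to_pgvector_literal(image_encoding),
        "image_search_weight": float(image_search_weight),
        "keyword_encoding": to_pgvector_literal(keyword_encoding),
        "keyword_search_weight": float(keyword_search_weight),
    }


# Usage with the question's session factory (needs SQLAlchemy and a database):
#   from sqlalchemy import text
#   params = build_search_params(image_encoding, 1, keyword_encoding, 1)
#   with session_class() as session:
#       rows = session.execute(text(SEARCH_SQL), params).all()
```

Calling .all() inside the with block fetches the rows before the session closes, which the question's version (returning the live result) does not do.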