Tags: google-cloud-spanner, vector-database, retrieval-augmented-generation

Why is this KNN vector query to a Google Spanner database taking over 30 seconds?


I have a document database with about 6,000 records, and I've successfully used vector database backends for RAG queries over these records. I'd like to move from a vector-specific database to Google's Spanner database.

My Spanner query looks like this:

SELECT
  a.id,
  a.title,
  COSINE_DISTANCE(
    a.embeddings,
    (SELECT embeddings.values
     FROM ML.PREDICT(MODEL EmbeddingsModel,
       (SELECT 'What is the meaning of life?' AS content)))
  ) AS distance
FROM `documents` AS a
ORDER BY distance

This query typically takes over 30 seconds. How can I make it faster? The Spanner query plan shows that most of the time is spent in ml.predict and in an operator called "merge distributed union".

Other info:

Embedding model:

CREATE MODEL EmbeddingsModel
INPUT(content STRING(MAX))
OUTPUT(
  embeddings
    STRUCT<
      statistics STRUCT<truncated BOOL, token_count FLOAT64>,
      values ARRAY<FLOAT64>>
)
REMOTE OPTIONS (
  endpoint = '//aiplatform.googleapis.com/projects/<myproject-id>/locations/<mylocation>/publishers/google/models/textembedding-gecko@003'
);

Documents table:

CREATE TABLE documents (
  id STRING(36) DEFAULT (GENERATE_UUID()),
  text STRING(MAX),
  title STRING(MAX),
  embeddings ARRAY<FLOAT64>,
) PRIMARY KEY (id);

Solution

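  • It appears that ml.predict() is recomputing the embedding for the question 'What is the meaning of life?' once for every row in the documents table, so the remote model call is repeated thousands of times.

    Computing the question's embedding once in a WITH clause (a common table expression, not an actual temp table) and joining that single row against the documents table gets around this.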

    WITH query_embedding AS (
      SELECT
        embeddings.values AS embedding_values
      FROM
        ml.predict(MODEL embeddingsmodel,
          (
          SELECT
            'What is the meaning of life?' AS content))
    )
    
    SELECT
      a.id,
      a.title,
      cosine_distance(a.embeddings, b.embedding_values) AS distance
    FROM
      `documents` AS a, query_embedding as b
    ORDER BY
      distance
    LIMIT 10
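
To make the reasoning concrete, here is a small Python analogy of the two query shapes. It is only a sketch: it does not call Spanner or Vertex AI and says nothing about how Spanner's executor actually schedules the subquery. `CountingEmbedder`, `rank_per_row`, and `rank_hoisted` are hypothetical names invented for this illustration, and the "embedding" is a toy three-element vector. The point is just to count model calls when the question is embedded per row versus once up front.

```python
import math


def cosine_distance(u, v):
    """Cosine distance = 1 - cosine similarity (assumes non-zero vectors)."""
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return 1.0 - dot / (norm_u * norm_v)


class CountingEmbedder:
    """Hypothetical stand-in for the remote embedding model; counts calls."""

    def __init__(self):
        self.calls = 0

    def embed(self, text):
        self.calls += 1
        # Toy deterministic vector, NOT a real embedding.
        return [float(len(text)), float(text.count(" ") + 1), 1.0]


def rank_per_row(docs, question, embedder):
    """Mirrors the slow shape: the question is embedded again for every row."""
    scored = [
        (cosine_distance(d["embeddings"], embedder.embed(question)), d["id"])
        for d in docs
    ]
    return sorted(scored)


def rank_hoisted(docs, question, embedder, limit=10):
    """Mirrors the WITH-clause shape: embed once, then score every row."""
    query_embedding = embedder.embed(question)
    scored = [
        (cosine_distance(d["embeddings"], query_embedding), d["id"])
        for d in docs
    ]
    return sorted(scored)[:limit]


# Synthetic stand-in for the ~6,000-row documents table.
docs = [
    {"id": i, "embeddings": [float(i % 7 + 1), float(i % 5 + 1), 1.0]}
    for i in range(6000)
]
question = "What is the meaning of life?"

slow, fast = CountingEmbedder(), CountingEmbedder()
slow_top10 = rank_per_row(docs, question, slow)[:10]
fast_top10 = rank_hoisted(docs, question, fast)

print(slow.calls, fast.calls)  # 6000 1
```

Both functions return the same ten nearest rows; the only difference is the number of model invocations. If each invocation were a remote round trip, that difference alone could plausibly account for a multi-second query.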