Quickly performing cosine similarity with a list of embeddings

I have a list of phrases, and for each one I want to find the top match among a set of 25k embedding vectors (emb2_list). I am using cosine similarity for this. Here is the code:

from sentence_transformers import SentenceTransformer, util
import numpy as np
import torch

model = SentenceTransformer('bert-base-nli-stsb-mean-tokens')

emb2_list = np.load("emb2_list.npy") #already encoded, len = 25K

phrases = ['phrase 1','phrase 2','phrase 3','phrase 4',]

for phrase in phrases:
    emb1 = model.encode(phrase)

    cos_sim = []

    for emb2 in emb2_list:
        cos_sim.append(util.pytorch_cos_sim(emb1, emb2)[0][0].item())

    v, i = torch.Tensor(cos_sim).topk(1)

    print(f'phrase:{phrase} match index:{i}')

The issue is that each iteration takes ~1 sec (~4 sec in total for this example). This becomes a real problem as the number of phrases grows, because this code is part of an online API.

Is there a better way to compute cosine similarity, in terms of data structures, batching techniques, or some kind of approximate nearest-neighbour algorithm, that would speed up this process?


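  • You need to batch-compute (1) the sentence encodings and (2) the cosine similarities. Calling util.pytorch_cos_sim once per stored vector means 25k small calls for every phrase; a few large tensor operations avoid that per-call overhead.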


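    The sentence_transformers documentation states that you can call encode on a list of sentences. Passing convert_to_tensor=True returns a torch tensor instead of a NumPy array:

    emb1 = model.encode(phrases, convert_to_tensor=True)   # one batched call for all phrases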


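    Once every vector is normalized to unit length, the cosine similarity of all (phrase, stored embedding) pairs is a single matrix-matrix multiplication: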

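    emb1 = torch.as_tensor(emb1)                        # no-op if encode already returned a tensor
    emb2 = torch.as_tensor(emb2_list, dtype=emb1.dtype, device=emb1.device)  # same dtype/device as emb1
    emb1 = emb1 / emb1.norm(dim=-1, p=2, keepdim=True)  # normalize each row to unit length
    emb2 = emb2 / emb2.norm(dim=-1, p=2, keepdim=True)  # ditto
    sims = emb1 @ emb2.t()                              # shape (len(phrases), len(emb2_list))
    values, indices = sims.topk(1, dim=1)               # best match (score, index) per phrase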

    Now sims[a, b] contains the cosine similarity of phrases[a] to emb2_list[b], so the best match for phrases[a] is sims[a].argmax().

    Note that the similarity matrix takes O(mn) memory for m phrases and n precomputed embeddings. Depending on your use case, you may need to process the phrases in chunks and keep only the top match from each chunk.
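If PyTorch is not a hard requirement, the same normalize-then-multiply idea, together with the chunking mentioned above, can be sketched with NumPy alone. This is a minimal sketch, assuming emb1 and emb2_list are 2-D float arrays with one embedding per row; normalize_rows, top1_cosine and the default chunk_size=1024 are hypothetical names and values chosen for illustration, not part of sentence_transformers. It is still an exact search, not an approximate nearest-neighbour index.

```python
import numpy as np


def normalize_rows(x: np.ndarray) -> np.ndarray:
    """Scale every row of x to unit L2 norm."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    # Guard against all-zero rows to avoid division by zero.
    return x / np.maximum(norms, 1e-12)


def top1_cosine(queries: np.ndarray, corpus: np.ndarray, chunk_size: int = 1024):
    """Return (best_index, best_score) arrays, one entry per query row.

    Hypothetical helper: only chunk_size query rows are compared at a time,
    so each similarity block is at most chunk_size x len(corpus).
    """
    corpus_n = normalize_rows(corpus)
    best_idx = np.empty(len(queries), dtype=np.int64)
    best_score = np.empty(len(queries), dtype=corpus_n.dtype)
    for start in range(0, len(queries), chunk_size):
        block = normalize_rows(queries[start:start + chunk_size])
        sims = block @ corpus_n.T                # cosine similarities for this chunk
        idx = sims.argmax(axis=1)                # best corpus row for each query
        stop = start + len(block)
        best_idx[start:stop] = idx
        best_score[start:stop] = sims[np.arange(len(block)), idx]
    return best_idx, best_score


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Random stand-ins for emb2_list and the encoded phrases (768 dims assumed,
    # corpus kept smaller than the question's 25k to run quickly).
    corpus = rng.standard_normal((5_000, 768)).astype(np.float32)
    queries = rng.standard_normal((4, 768)).astype(np.float32)
    idx, score = top1_cosine(queries, corpus)
    for i, (j, s) in enumerate(zip(idx, score)):
        print(f"query {i}: match index {j}, cosine {s:.3f}")
```

Only chunk_size rows of the similarity matrix exist at once, so peak memory grows with chunk_size * n rather than m * n. With the question's data you would call top1_cosine(emb1, emb2_list), first converting emb1 to NumPy (for example emb1.cpu().numpy()) if it is a torch tensor.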