Tags: python, nlp, word-embedding, sentence-similarity

Combine similar elements in N*N matrix without duplicates


I have a list of sentences. For each sentence, I want to find all the other sentences similar to it and group them together in a list/tuple.

I formed sentence embeddings for them, then computed an N*N cosine similarity matrix for the N sentences. I then iterated through its entries and picked the pairs whose similarity is above a threshold.

If sentences[x] is similar to both sentences[y] and sentences[z], then once I have combined sentences[x] with sentences[y], the loop should not also combine sentences[x] with sentences[z] as it iterates further.

I went with the intuition that, since we are comparing cosine similarities, if X is similar to Y and Y is similar to Z, then X will be similar to Z as well, so I should not have to worry about it. My goal is to avoid duplicate groups, but I am stuck.

Is there a better way / what's the best way to do this?

Here is my code:

import pandas as pd
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
threshold = 0.75

def form_embeddings(sentences, embedding_model=embedding_model):
    # accept a single sentence as well as a list of sentences
    if isinstance(sentences, str):
        sentences = [sentences]
    return embedding_model.encode(sentences, convert_to_tensor=True)

df = pd.read_csv('sample_file.csv')
sentences = df['sentences'].tolist()

#form embeddings
sentence_embeddings = form_embeddings(sentences=sentences)

#form similarity matrix
sim_matrix = cos_sim(sentence_embeddings, sentence_embeddings)

#set similarity with itself as zero
sim_matrix.fill_diagonal_(0)

#iterate through the upper triangle and collect pairs above the threshold
pairs = []
for i in range(len(sentences)):
    for j in range(i + 1, len(sentences)):  # j > i avoids self-pairs and duplicate pairs
        if sim_matrix[i, j] >= threshold:
            pairs.append({
                'index': [i, j],
                'score': sim_matrix[i, j].item(),
                'original_sentence': sentences[i],
                'similar_sentence': sentences[j],
            })
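
For completeness, here is one way the pairs collected above could be merged into disjoint groups with no sentence appearing twice. This is a minimal sketch using a simple union-find over the pair indices; the find and union helpers are illustrative additions, not part of the original code:

# merge pairs into disjoint groups with a simple union-find
parent = list(range(len(sentences)))

def find(i):
    # follow parent links to the representative of i's group
    while parent[i] != i:
        parent[i] = parent[parent[i]]  # path compression
        i = parent[i]
    return i

def union(i, j):
    # attach j's representative under i's representative
    parent[find(j)] = find(i)

for pair in pairs:
    i, j = pair['index']
    union(i, j)

# collect sentences by their group representative
groups = {}
for i in range(len(sentences)):
    groups.setdefault(find(i), []).append(sentences[i])

# each value is one group of mutually similar sentences, with no duplicates
grouped_sentences = list(groups.values())

Because union-find merges transitively, a chain like X~Y and Y~Z ends up in one group, which matches the intuition described above.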

Solution

  • I figured out a better way to do this.

    This is solved by a fast clustering implementation, which forms the groups directly instead of deduplicating pairwise matches by hand.
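
    For reference, sentence-transformers ships such a fast clustering routine as util.community_detection. A minimal sketch of how it could replace the pairwise loop, reusing the 0.75 threshold from above; min_community_size=1 is an assumption so that singleton sentences are kept too:

    from sentence_transformers.util import community_detection

    # clusters is a list of communities; each community is a list of
    # sentence indices, with the most central sentence listed first
    clusters = community_detection(
        sentence_embeddings,
        threshold=0.75,        # same cosine-similarity cutoff as above
        min_community_size=1,  # assumption: keep single-sentence groups as well
    )

    # map index clusters back to the original sentences
    similar_groups = [[sentences[idx] for idx in cluster] for cluster in clusters]

    Each sentence is assigned to at most one community, so the duplicate problem from the pairwise loop does not arise.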