Tags: python, pandas, spacy, similarity, bert-language-model

How to find the similarity of sentences in 2 columns of a dataframe using spacy


I pulled this code from https://spacy.io/universe/project/spacy-sentence-bert

    import spacy_sentence_bert
    # load one of the models listed at https://github.com/MartinoMensio/spacy-sentence-bert/
    nlp = spacy_sentence_bert.load_model('en_roberta_large_nli_stsb_mean_tokens')
    # get two documents
    doc_1 = nlp('Hi there, how are you?')
    doc_2 = nlp('Hello there, how are you doing today?')
    # use the similarity method that is based on the vectors, on Doc, Span or Token
    print(doc_1.similarity(doc_2[0:7]))
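For context, `similarity` on a `Doc` compares the two documents' vectors using cosine similarity, and as far as I understand spacy-sentence-bert, it supplies the chosen model's sentence embedding as that vector. The computation itself is small. Here is a numpy sketch of it, with toy vectors standing in for real embeddings (the function name `cosine_similarity` is my own, not part of spaCy or spacy-sentence-bert):

```python
import numpy as np

def cosine_similarity(u: np.ndarray, v: np.ndarray) -> float:
    """Cosine of the angle between two vectors: dot(u, v) / (|u| * |v|)."""
    norm = np.linalg.norm(u) * np.linalg.norm(v)
    if norm == 0:
        return 0.0  # avoid dividing by zero when either vector is all zeros
    return float(np.dot(u, v) / norm)

# Toy vectors standing in for sentence embeddings (real ones have hundreds of dimensions)
a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 4.0, 6.0])   # same direction as a
c = np.array([-3.0, 0.0, 1.0])  # orthogonal to a

print(cosine_similarity(a, b))  # close to 1.0
print(cosine_similarity(a, c))  # 0.0
```

Scores near 1 mean the embeddings point the same way; the magnitude of the vectors does not matter.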

I have a DataFrame with 2 columns containing sentences, like the one below. I'm trying to find the similarity between the two sentences in each row. I've tried a few different methods but haven't had much luck, so I figured I would ask here. Thank you all.

Current df:

    Sentence1             | Sentence2
    Another-Sentence1     | Another-Sentence2
    Yet-Another-Sentence1 | Yet-Another-Sentence2

Goal output:

    Sentence1             | Sentence2             | Similarity-Score-Sentence1-Sentence2
    Another-Sentence1     | Another-Sentence2     | Similarity-Score-Another-Sentence1-Another-Sentence2
    Yet-Another-Sentence1 | Yet-Another-Sentence2 | Similarity-Score-Yet-Another-Sentence1-Yet-Another-Sentence2
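One way to get that output shape is to compute one score per row from the two columns and assign the resulting list as a new column. This is a sketch under assumptions: the column names `sentence_1`/`sentence_2` and the helper `add_similarity_column` are hypothetical, and the word-overlap `jaccard` scorer is only a stand-in so the example runs without downloading a model. It is not the BERT similarity.

```python
import pandas as pd

def add_similarity_column(df, col_a, col_b, score_fn, out_col="Similarity"):
    """Return a copy of df with one score per row, computed from col_a and col_b."""
    result = df.copy()
    # zip walks the two columns in lockstep, so each pair comes from the same row
    result[out_col] = [score_fn(a, b) for a, b in zip(df[col_a], df[col_b])]
    return result

def jaccard(a: str, b: str) -> float:
    """Stand-in scorer: word overlap, NOT the BERT similarity from the question."""
    wa, wb = set(a.lower().split()), set(b.lower().split())
    union = wa | wb
    return len(wa & wb) / len(union) if union else 0.0

df = pd.DataFrame({
    "sentence_1": ["the cat sat", "hello there"],
    "sentence_2": ["the cat ran", "goodbye now"],
})
print(add_similarity_column(df, "sentence_1", "sentence_2", jaccard))
```

With the model loaded as in the snippet above, you would pass something like `score_fn=lambda a, b: nlp(a).similarity(nlp(b))` instead of `jaccard`.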

Solution

  • I assume that the first row of your CSV holds the column headers and the data starts on the row after it, and that you are using pandas to load the CSV into a DataFrame. With those assumptions, the code below works in my environment.

    import pandas as pd
    import spacy_sentence_bert

    # Load the same sentence-BERT model as in the question
    nlp = spacy_sentence_bert.load_model('en_roberta_large_nli_stsb_mean_tokens')
    df = pd.read_csv('testing.csv')
    similarity_values = []

    for i in range(len(df)):
        # iloc[i, 0] / iloc[i, 1]: first and second column of row i, by position
        sentence_1 = nlp(df.iloc[i, 0])
        sentence_2 = nlp(df.iloc[i, 1])
        # Compute the score once and reuse it for the list and the printout
        score = sentence_1.similarity(sentence_2)
        similarity_values.append(score)
        print(sentence_1, '|', sentence_2, '|', score)

    df['Similarity'] = similarity_values
    print(df)

    Input CSV: [screenshot not included]

    Output: [screenshot not included]
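If the CSV does not have a header row, the first assumption in the answer fails: pandas treats the first line as column names, and that sentence pair silently drops out of the results. Passing `header=None` (plus `names=` for your own labels) to `pd.read_csv` keeps it. A small sketch with a hypothetical inline CSV and hypothetical column names:

```python
import io
import pandas as pd

# Hypothetical CSV text with NO header row; the first line is already data
csv_text = "Hi there,Hello there\nHow are you?,How are you doing today?\n"

# Default behaviour: the first line is consumed as column names
with_header = pd.read_csv(io.StringIO(csv_text))
print(len(with_header))  # only one data row is left

# header=None keeps every line as data; names= supplies the column labels
df = pd.read_csv(io.StringIO(csv_text), header=None, names=["sentence_1", "sentence_2"])
print(df)
```

With named columns, the loop can read `df['sentence_1']` and `df['sentence_2']` instead of relying on column positions.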