python arrays numpy cosine-similarity

Efficient way to compute cosine similarity between 1D array and all rows in a 2D array


I have a 1D array of shape (300,) and a 2D array of shape (400, 300). I want to compute the cosine similarity between the 1D array and each row of the 2D array. The result should therefore have shape (400,), with one similarity score per row.

My initial idea is to iterate through the rows of the 2D array with a for loop and compute the cosine similarity for each row. Is there a faster alternative using broadcasting?

Here is a contrived example:

In [29]: vec = np.random.randn(300,)
In [30]: arr = np.random.randn(400, 300)

Below is how I calculate the similarity between two 1D arrays (here, vec and the first row of arr):

inn = (vec * arr[0]).sum()                  # dot product of vec and the first row
vecnorm = np.sqrt((vec * vec).sum())        # Euclidean norm of vec
rownorm = np.sqrt((arr[0] * arr[0]).sum())  # Euclidean norm of the first row
similarity_score = inn / vecnorm / rownorm
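
For reference, the for-loop idea mentioned earlier could be sketched roughly as follows. This is only a baseline to compare against; the function name cosine_loop is made up for illustration.

```python
import numpy as np

def cosine_loop(arr, vec):
    """Baseline: cosine similarity of vec with each row of arr, one row at a time."""
    vecnorm = np.sqrt((vec * vec).sum())       # norm of vec, computed once
    scores = np.empty(arr.shape[0])
    for i, row in enumerate(arr):
        inn = (vec * row).sum()                # dot product with this row
        rownorm = np.sqrt((row * row).sum())   # norm of this row
        scores[i] = inn / vecnorm / rownorm
    return scores

vec = np.random.randn(300)
arr = np.random.randn(400, 300)
scores = cosine_loop(arr, vec)
print(scores.shape)  # (400,)
```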

How can I generalize this to arr[0] being replaced with a 2D array?


Solution

  • Here's one that follows the same approach as @Bi Rico's post, but uses einsum for the norm computations -

    den = np.sqrt(np.einsum('ij,ij->i',arr,arr)*np.einsum('j,j',vec,vec))
    out = arr.dot(vec) / den
    

    Also, we can use vec.dot(vec) to replace np.einsum('j,j',vec,vec) for some marginal improvement.

    Timings -

    In [45]: vec = np.random.randn(300,)
        ...: arr = np.random.randn(400, 300)
    
    # @Bi Rico's soln with norm
    In [46]: %timeit (np.linalg.norm(arr, axis=1) * np.linalg.norm(vec))
    10000 loops, best of 3: 100 µs per loop
    
    In [47]: %timeit np.sqrt(np.einsum('ij,ij->i',arr,arr)*np.einsum('j,j',vec,vec))
    10000 loops, best of 3: 77.4 µs per loop
    

    On bigger arrays -

    In [48]: vec = np.random.randn(3000,)
        ...: arr = np.random.randn(4000, 3000)
    
    In [49]: %timeit (np.linalg.norm(arr, axis=1) * np.linalg.norm(vec))
    10 loops, best of 3: 22.2 ms per loop
    
    In [50]: %timeit np.sqrt(np.einsum('ij,ij->i',arr,arr)*np.einsum('j,j',vec,vec))
    100 loops, best of 3: 8.18 ms per loop
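
Putting the pieces above together, a minimal self-contained sketch might look like the following. The function name cosine_to_rows is hypothetical (it does not appear in the original posts), and the sketch assumes that neither vec nor any row of arr is all zeros, since a zero norm would make the division produce nan. The last lines check the result against the np.linalg.norm formulation used in the timings.

```python
import numpy as np

def cosine_to_rows(arr, vec):
    """Cosine similarity between vec (shape (n,)) and each row of arr (shape (m, n))."""
    # Squared norm of every row via einsum, times the squared norm of vec
    den = np.sqrt(np.einsum('ij,ij->i', arr, arr) * vec.dot(vec))
    # arr.dot(vec) gives the dot product of each row with vec, shape (m,)
    return arr.dot(vec) / den

rng = np.random.default_rng(0)
vec = rng.standard_normal(300)
arr = rng.standard_normal((400, 300))

scores = cosine_to_rows(arr, vec)

# Reference computed with np.linalg.norm, for comparison only
reference = arr.dot(vec) / (np.linalg.norm(arr, axis=1) * np.linalg.norm(vec))
print(scores.shape)                     # (400,)
print(np.allclose(scores, reference))   # True
```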