python, arrays, numpy, distance

extract the N closest pairs from a numpy distance array


I have a large, symmetric 2D distance matrix. I want to get the N closest pairs of observations.

The matrix is stored as a numpy condensed distance array, with on the order of 100 million entries.

Here's an example that gets the 100 closest distances on a smaller problem (1,000 observations, ~500k condensed distances), but it's a lot slower than I would like.

import numpy as np
import random
import sklearn.metrics.pairwise
import scipy.spatial.distance

N = 100
r = np.array([random.randrange(1, 1000) for _ in range(0, 1000)])
c = r[:, None]

dists = scipy.spatial.distance.pdist(c, 'cityblock')

# these are the indices of the closest N observations
closest = dists.argsort()[:N]

# but it's really slow to get out the pairs of observations
def condensed_to_square_index(n, c):
    # converts an index in a condensed array to the
    # pair of observations it represents
    # modified from here: http://stackoverflow.com/questions/5323818/condensed-matrix-function-to-find-pairs
    ti = np.triu_indices(n, 1)
    return ti[0][c] + 1, ti[1][c] + 1

pairs = []
n = int(np.ceil(np.sqrt(2 * len(dists))))  # recover the square matrix size
for i in closest:
    pair = condensed_to_square_index(n, i)
    pairs.append(pair)

It seems to me like there must be quicker ways to do this with standard numpy or scipy functions, but I'm stumped.

NB If lots of pairs are equidistant, that's OK and I don't care about their ordering in that case.


Solution

  • You can speed up locating the minimum values considerably if you are using numpy 1.8, via np.partition:

    def smallest_n(a, n):
        return np.sort(np.partition(a, n)[:n])
    
    def argsmallest_n(a, n):
        ret = np.argpartition(a, n)[:n]
        b = np.take(a, ret)
        return np.take(ret, np.argsort(b))
    
    dists = np.random.rand(1000*999//2) # a pdist array
    
    In [3]: np.all(argsmallest_n(dists, 100) == np.argsort(dists)[:100])
    Out[3]: True
    
    In [4]: %timeit np.argsort(dists)[:100]
    10 loops, best of 3: 73.5 ms per loop
    
    In [5]: %timeit argsmallest_n(dists, 100)
    100 loops, best of 3: 5.44 ms per loop
    

    And once you have the indices of the smallest distances, you don't need a loop to extract the pairs; do it in a single shot:

    closest = argsmallest_n(dists, 100)
    tu = np.triu_indices(1000, 1)
    pairs = np.column_stack((np.take(tu[0], closest),
                             np.take(tu[1], closest))) + 1