python, scikit-learn, nearest-neighbor, kdtree

Nearest neighbour search with a KDTree


Given a list of N points [(x_1,y_1), (x_2,y_2), ... ], I am trying to find the nearest neighbour of each point based on distance. My dataset is too large for a brute-force approach, so a KD-tree seems best.

Rather than implement one from scratch, I see that sklearn.neighbors.KDTree can find the nearest neighbours. Can this be used to find the nearest neighbour of each particle, i.e. return a list of length N?


Solution

  • This question is very broad and is missing details. It's unclear what you have tried, what your data looks like, and what counts as a nearest neighbor (does a point count as its own neighbor?).

    Assuming you are not interested in the point itself (which matches itself at distance 0), you can query the two nearest neighbors of each point and drop the first column. This is probably the easiest approach here.

    Code:

     import numpy as np
     from sklearn.neighbors import KDTree
     np.random.seed(0)
     X = np.random.random((5, 2))  # 5 points in 2 dimensions
     tree = KDTree(X)
     nearest_dist, nearest_ind = tree.query(X, k=2)  # k=2: column 0 is the point itself, column 1 its nearest neighbor
     print(X)
     print(nearest_dist[:, 1])    # drop the self-match; assumes results are sorted by distance (sort_results=True, the default)
     print(nearest_ind[:, 1])     # drop the self-match
    

    Output

     [[ 0.5488135   0.71518937]
      [ 0.60276338  0.54488318]
      [ 0.4236548   0.64589411]
      [ 0.43758721  0.891773  ]
      [ 0.96366276  0.38344152]]
     [ 0.14306129  0.1786471   0.14306129  0.20869372  0.39536284]
     [2 0 0 0 1]
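
Dropping the first column relies on each point being returned as its own closest match. If the data may contain exact duplicates, a duplicate can occupy column 0 instead, so column 1 would be the point itself. The following is a minimal sketch of one way to guard against that; the helper name `nearest_other` is hypothetical and not part of scikit-learn, and it assumes the default Euclidean metric and sorted results.

```python
import numpy as np
from sklearn.neighbors import KDTree


def nearest_other(X):
    """Hypothetical helper: for each row of X, return the distance to and
    index of the closest row that is not that row itself."""
    X = np.asarray(X, dtype=float)
    tree = KDTree(X)
    # Shape (N, 2); each row is sorted by distance (sort_results=True by default).
    dist, ind = tree.query(X, k=2)
    rows = np.arange(len(X))
    # Normally column 1 is the nearest other point. If column 1 holds the
    # point's own index, an exact duplicate sits in column 0, so use that.
    col = np.where(ind[:, 1] == rows, 0, 1)
    return dist[rows, col], ind[rows, col]


np.random.seed(0)
X = np.random.random((5, 2))  # same 5 points as in the answer above
dist, ind = nearest_other(X)
print(dist)
print(ind)
```

For data without duplicates this should return the same values as `nearest_dist[:, 1]` and `nearest_ind[:, 1]`; the extra check only matters when two or more rows are identical.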