Search code examples
Tags: machine-learning, scikit-learn, nearest-neighbor, kdtree

5 Nearest Neighbors Using KD tree


For each blue point (T-SNE1), I want to find its 5 nearest neighbors among the red points (T-SNE2). I wrote the code below to work out the right way to do this, but I am not sure whether it is correct.

# For every point of Y, find its 5 nearest neighbours among the points of X.
# NOTE(review): presumably `np` is numpy and `KDTree` is
# sklearn.neighbors.KDTree, imported elsewhere -- `leaf_size` is sklearn's
# keyword (scipy's KDTree spells it `leafsize`).
X = np.random.random((10, 2))  # 10 points in 2 dimensions (searched set)
Y = np.random.random((10, 2))  # 10 points in 2 dimensions (query points)
NNlist = []  # NNlist[i]: index into X of the single nearest neighbour of Y[i]
treex = KDTree(X, leaf_size=2)  # build the tree once over the searched set
# One batched query instead of one call per row: row i of `dist` / `ind`
# holds the 5 nearest X points for Y[i], ordered by increasing distance.
dist, ind = treex.query(Y, k=5)
for i, (nearest, nearest_dist) in enumerate(zip(ind[:, 0], dist[:, 0])):
    NNlist.append(int(nearest))  # plain int so the list prints as [9, 0, ...]
    print(ind[i:i + 1])  # indices of 5 closest neighbors (kept 2-D for display)
    print(dist[i:i + 1])
    print("the nearest index is:", nearest, "with distance:", nearest_dist, "for Y", i)
print(NNlist)

(image not available) Output:

[[9 5 4 6 0]]
[[ 0.21261486  0.32859024  0.41598597  0.42960146  0.43793039]]
the nearest index is: 9 with distance: 0.212614862956 for Y 0
[[0 3 2 6 1]]
[[ 0.10907128  0.11378059  0.13984741  0.18000197  0.27475481]]
the nearest index is: 0 with distance: 0.109071275144 for Y 1
[[8 2 3 0 1]]
[[ 0.21621245  0.30543878  0.40668179  0.4370689   0.49372232]]
the nearest index is: 8 with distance: 0.216212445449 for Y 2
[[8 3 2 6 0]]
[[ 0.16648482  0.2989508   0.40967709  0.42511931  0.46589575]]
the nearest index is: 8 with distance: 0.166484820786 for Y 3
[[1 2 5 0 4]]
[[ 0.15331281  0.25121761  0.29305736  0.30173474  0.44291615]]
the nearest index is: 1 with distance: 0.153312811422 for Y 4
[[2 3 8 0 6]]
[[ 0.20441037  0.20917797  0.25121628  0.2903253   0.33914051]]
the nearest index is: 2 with distance: 0.204410367254 for Y 5
[[2 1 0 3 5]]
[[ 0.08400022  0.1484925   0.17356156  0.32387147  0.33789602]]
the nearest index is: 2 with distance: 0.0840002184199 for Y 6
[[8 2 3 7 0]]
[[ 0.2149891   0.40584999  0.50054235  0.53307269  0.5389266 ]]
the nearest index is: 8 with distance: 0.21498909502 for Y 7
[[1 0 2 5 9]]
[[ 0.07265268  0.11687068  0.19065327  0.20004392  0.30269591]]
the nearest index is: 1 with distance: 0.0726526838766 for Y 8
[[5 9 4 1 0]]
[[ 0.21563204  0.25067242  0.29904262  0.36745386  0.39634179]]
the nearest index is: 5 with distance: 0.21563203953 for Y 9
[9, 0, 8, 8, 1, 2, 2, 8, 1, 5]

Solution

  • import numpy as np
    from scipy.spatial import KDTree
    
    X = np.random.random((10, 2))  # 10 points in 3 dimensions
    Y = np.random.random((10, 2))  # 10 points in 3 dimensions
    NNlist=[]
    
    for i in range(len(X)):
        treey = KDTree(np.concatenate([Y.tolist(), np.expand_dims(X[i], axis=0)], axis=0))
        dist, ind = treey.query([X[i]], k=6)
        print('index', ind)  # indices of 5 closest neighbors
        print('distance', dist)
        print('5 nearest neighbors')
        for j in ind[0][1:]:
            print(Y[j])
        print()
    

    Example output (the points are random, so the values differ on every run):

    index [[10  5  8  9  1  2]]
    distance [[ 0.          0.3393312   0.38565112  0.40120109  0.44200758  0.47675255]]
    5 nearest neighbors
    [ 0.6298789   0.18283264]
    [ 0.42952574  0.83918788]
    [ 0.26258905  0.4115705 ]
    [ 0.61789523  0.96261285]
    [ 0.92417172  0.13276541]
    
    index [[10  1  3  8  4  9]]
    distance [[ 0.          0.09176157  0.18219064  0.21845335  0.28876942  0.60082231]]
    5 nearest neighbors
    [ 0.61789523  0.96261285]
    [ 0.51031835  0.99761715]
    [ 0.42952574  0.83918788]
    [ 0.3744326   0.97577322]
    [ 0.26258905  0.4115705 ]
    
    index [[10  7  0  9  5  6]]
    distance [[ 0.          0.15771386  0.2751765   0.3457175   0.49918935  0.70597498]]
    5 nearest neighbors
    [ 0.19803817  0.23495888]
    [ 0.41293849  0.05585981]
    [ 0.26258905  0.4115705 ]
    [ 0.6298789   0.18283264]
    [ 0.04527532  0.78806495]
    
    index [[10  0  5  7  9  2]]
    distance [[ 0.          0.09269963  0.20597988  0.24505542  0.31104979  0.49743673]]
    5 nearest neighbors
    [ 0.41293849  0.05585981]
    [ 0.6298789   0.18283264]
    [ 0.19803817  0.23495888]
    [ 0.26258905  0.4115705 ]
    [ 0.92417172  0.13276541]
    
    index [[10  9  5  7  0  8]]
    distance [[ 0.          0.20406876  0.26125464  0.30645317  0.33369641  0.45509834]]
    5 nearest neighbors
    [ 0.26258905  0.4115705 ]
    [ 0.6298789   0.18283264]
    [ 0.19803817  0.23495888]
    [ 0.41293849  0.05585981]
    [ 0.42952574  0.83918788]
    
    index [[10  5  2  0  7  9]]
    distance [[ 0.          0.13641503  0.17524716  0.34224271  0.56393988  0.56893897]]
    5 nearest neighbors
    [ 0.6298789   0.18283264]
    [ 0.92417172  0.13276541]
    [ 0.41293849  0.05585981]
    [ 0.19803817  0.23495888]
    [ 0.26258905  0.4115705 ]
    
    index [[10  7  9  0  5  6]]
    distance [[ 0.          0.04152391  0.22807566  0.25709252  0.43421854  0.61332497]]
    5 nearest neighbors
    [ 0.19803817  0.23495888]
    [ 0.26258905  0.4115705 ]
    [ 0.41293849  0.05585981]
    [ 0.6298789   0.18283264]
    [ 0.04527532  0.78806495]
    
    index [[10  5  1  2  8  3]]
    distance [[ 0.          0.40641681  0.43652515  0.44861766  0.45186271  0.51705369]]
    5 nearest neighbors
    [ 0.6298789   0.18283264]
    [ 0.61789523  0.96261285]
    [ 0.92417172  0.13276541]
    [ 0.42952574  0.83918788]
    [ 0.51031835  0.99761715]
    
    index [[10  6  9  7  8  4]]
    distance [[ 0.          0.17568369  0.2841519   0.40184611  0.43110847  0.47835169]]
    5 nearest neighbors
    [ 0.04527532  0.78806495]
    [ 0.26258905  0.4115705 ]
    [ 0.19803817  0.23495888]
    [ 0.42952574  0.83918788]
    [ 0.3744326   0.97577322]
    
    index [[10  9  7  5  0  8]]
    distance [[ 0.          0.11723769  0.2275565   0.32111803  0.32446146  0.4643181 ]]
    5 nearest neighbors
    [ 0.26258905  0.4115705 ]
    [ 0.19803817  0.23495888]
    [ 0.6298789   0.18283264]
    [ 0.41293849  0.05585981]
    [ 0.42952574  0.83918788]