
Compute closest neighbor between 2 numpy arrays - KDTree


I have two numpy arrays: a smaller array a of int values and a larger array b of float values. The idea is that b contains float values that are close to some of the int values in a. The toy example below illustrates this. The arrays are not sorted as given, so I apply np.sort() to both a and b before working with them:

import numpy as np
a = np.array([35, 11, 48, 20, 13, 31, 49])
b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])
a, b = np.sort(a), np.sort(b)

Each element of a has several nearby float values in b, and the goal is to find, for each element of a, the closest value in b.

A naive way to do this is a for loop:

for e in a:
    idx = np.abs(e - b).argsort()
    print(f"{e} has nearest match = {b[idx[0]]:.4f}")
'''
11 has nearest match = 11.2890
13 has nearest match = 12.8700
20 has nearest match = 20.0500
31 has nearest match = 31.0300
35 has nearest match = 34.9900
48 has nearest match = 48.1000
49 has nearest match = 49.2000
'''
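As a sketch, assuming only the sorted toy arrays above, the same loop can use argmin instead of argsort: only the single smallest distance is needed, so fully sorting all distances is unnecessary work. The results should match the argsort version, except possibly when two values of b are exactly equidistant.

```python
import numpy as np

a = np.sort(np.array([35, 11, 48, 20, 13, 31, 49]))
b = np.sort(np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78,
                      19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87,
                      48.1, 48.5, 49.2]))

# argmin scans the distances once per query instead of sorting them all
nearest = np.array([b[np.abs(e - b).argmin()] for e in a])

for e, m in zip(a, nearest):
    print(f"{e} has nearest match = {m:.4f}")
```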

Some values in a may have no close counterpart in b, and vice versa.

In my real data, a.size = 2040 and b.size = 1041901.
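Since some values in a may have no partner in b, one option is SciPy's distance_upper_bound argument to KDTree.query, with the tree built on b. Queries with no neighbor within that distance come back with an infinite distance and an index equal to len(b). A minimal sketch: the extra value 60 and the tolerance MAX_DIST = 1.0 are hypothetical choices for illustration only.

```python
import numpy as np
from scipy.spatial import KDTree

# Toy data; 60 is a hypothetical extra value with no partner in b
a = np.array([35, 11, 48, 20, 13, 31, 49, 60])
b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1,
              20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])

MAX_DIST = 1.0  # hypothetical tolerance, pick one that fits your data

tree = KDTree(b[:, np.newaxis])
dist, idx = tree.query(a[:, np.newaxis], distance_upper_bound=MAX_DIST)

# A query with no neighbor within MAX_DIST gets dist == inf and idx == len(b)
found = np.isfinite(dist)
for ai, ok, i in zip(a, found, idx):
    if ok:
        print(f"{ai} has nearest match = {b[i]:.4f}")
    else:
        print(f"{ai} has no match within {MAX_DIST}")
```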

To build a KD-Tree and query it:

from scipy.spatial import KDTree

# Construct a KD-Tree from a and query the nearest neighbor of each value in b
kd_tree = KDTree(data=np.expand_dims(a, 1))
dist_nn, idx_nn = kd_tree.query(x=np.expand_dims(b, 1), k=[1])


dist_nn.shape, idx_nn.shape
# ((19, 1), (19, 1))
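The trailing axis of length 1 comes from passing k as a list: in SciPy's KDTree.query, a sequence k adds a last axis of length len(k), while an integer k=1 squeezes that axis away. A minimal check, assuming the sorted toy arrays:

```python
import numpy as np
from scipy.spatial import KDTree

a = np.sort(np.array([35, 11, 48, 20, 13, 31, 49]))
b = np.sort(np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78,
                      19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87,
                      48.1, 48.5, 49.2]))

tree = KDTree(a[:, np.newaxis])  # tree built on a (7 points)

d_list, i_list = tree.query(b[:, np.newaxis], k=[1])  # k given as a list
d_int, i_int = tree.query(b[:, np.newaxis], k=1)      # k given as an int

print(i_list.shape)  # (19, 1)
print(i_int.shape)   # (19,)
```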

To get the nearest neighbor in 'b' with respect to 'a', I do:

b[idx_nn]
'''
array([[10.7  ],
       [10.7  ],
       [10.7  ],
       [11.289],
       [11.289],
       [11.289],
       [11.3  ],
       [11.3  ],
       [11.3  ],
       [12.32 ],
       [12.32 ],
       [12.32 ],
       [12.87 ],
       [12.87 ],
       [12.87 ],
       [12.87 ],
       [13.5  ],
       [13.5  ],
       [18.78 ]])
'''
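Here idx_nn only takes values 0 to 6, the positions in the 7-element a, so b[idx_nn] can only ever return the 7 smallest elements of the sorted b. That is why nothing above 18.78 appears. A minimal sketch, assuming the sorted toy arrays, of what this query actually computes: the nearest value of a for each value of b.

```python
import numpy as np
from scipy.spatial import KDTree

a = np.sort(np.array([35, 11, 48, 20, 13, 31, 49]))
b = np.sort(np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78,
                      19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87,
                      48.1, 48.5, 49.2]))

kd_tree = KDTree(np.expand_dims(a, 1))
dist_nn, idx_nn = kd_tree.query(np.expand_dims(b, 1), k=1)

# idx_nn holds positions in a (0..len(a)-1), so it must index a, not b
nearest_a = a[idx_nn]
for bi, ai in zip(b, nearest_a):
    print(f"{bi:.3f} is closest to {ai}")
```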

Problems:

  • It seems that the KD-Tree doesn't go beyond the value 20 in 'a': [31, 35, 48, 49] in a are completely missed.
  • Most of the nearest neighbors it finds are wrong when compared to the output of the for loop!

What's going wrong?


Solution

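  • If you want the closest element of b for each entry in a, build your KD-Tree on b and then query it with a. KDTree.query returns indices into the data the tree was built from, so a tree built on a and queried with b returns positions in a, and indexing b with them picks unrelated elements.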

    import numpy as np
    from scipy import spatial
    
    # Build the tree on b (each value as a 1-D point) and query it with a
    kd = spatial.KDTree(b[:, np.newaxis])
    distances, indices = kd.query(a[:, np.newaxis])
    values = b[indices]  # indices point into b, the data the tree was built on
    
    for ai, bi in zip(a, values):
        print(f"{ai} has nearest match = {bi:.4f}")
    
    '''
    35 has nearest match = 34.9900
    11 has nearest match = 11.2890
    48 has nearest match = 48.1000
    20 has nearest match = 20.0500
    13 has nearest match = 12.8700
    31 has nearest match = 31.0300
    49 has nearest match = 49.2000
    '''
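Because the data here are one-dimensional, a KD-Tree is not strictly required. An alternative technique, not part of the answer above, is binary search with np.searchsorted on a sorted copy of b: for each value of a, only the two elements of b on either side of its insertion point can be the closest. A sketch, where nearest_in_sorted is a hypothetical helper name:

```python
import numpy as np

def nearest_in_sorted(a, b):
    """Hypothetical helper: for each value in a, return the closest value in b (1-D only)."""
    b_sorted = np.sort(b)
    # Insertion position of each a value in b_sorted
    pos = np.searchsorted(b_sorted, a)
    # Candidates: the element just left of pos and the element at pos (clipped to valid range)
    left = np.clip(pos - 1, 0, len(b_sorted) - 1)
    right = np.clip(pos, 0, len(b_sorted) - 1)
    # Keep whichever candidate is closer; ties go to the smaller value
    use_left = np.abs(a - b_sorted[left]) <= np.abs(a - b_sorted[right])
    return np.where(use_left, b_sorted[left], b_sorted[right])

a = np.array([35, 11, 48, 20, 13, 31, 49])
b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1,
              20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])

result = nearest_in_sorted(a, b)
for ai, bi in zip(a, result):
    print(f"{ai} has nearest match = {bi:.4f}")
```

Sorting b is done once, and each lookup is a binary search, so this should stay fast for sizes like a.size = 2040 and b.size = 1041901.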