Given a 2D array of distances, I use argsort to generate an index array in which the first element of each row is the column index of the lowest value in that row. I then use indexing to select only the first K columns, where K = 3 for example.
import numpy as np
position = np.random.randint(100, size=(5, 5))
array([[36, 63,  3, 78, 98],
       [75, 86, 63, 61, 79],
       [21, 12, 72, 27, 23],
       [38, 16, 17, 88, 29],
       [93, 37, 48, 88, 10]])
idx = position.argsort()
array([[2, 0, 1, 3, 4],
       [3, 2, 0, 4, 1],
       [1, 0, 4, 3, 2],
       [1, 2, 4, 0, 3],
       [4, 1, 2, 3, 0]])
idx[:,0:3]
array([[2, 0, 1],
       [3, 2, 0],
       [1, 0, 4],
       [1, 2, 4],
       [4, 1, 2]])
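The sliced index array can already be used to pull out the K smallest values themselves. Below is a minimal sketch of that. It hard-codes the example matrix printed above, because randint output differs on every run. It also uses np.take_along_axis (available since NumPy 1.15), which is my addition and not part of the original question:

```python
import numpy as np

# Example matrix from the question, hard-coded so the result is reproducible
position = np.array([[36, 63, 3, 78, 98],
                     [75, 86, 63, 61, 79],
                     [21, 12, 72, 27, 23],
                     [38, 16, 17, 88, 29],
                     [93, 37, 48, 88, 10]])
K = 3

# Column indices of each row, ordered from smallest to largest value
idx = position.argsort(axis=1)
k_idx = idx[:, :K]

# take_along_axis gathers position[i, k_idx[i, j]] for every (i, j),
# so each row holds that row's K smallest values in ascending order
k_smallest = np.take_along_axis(position, k_idx, axis=1)
print(k_smallest)
```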
What I would like to do next is create a mask which, when applied to the original position array, selects only the entries that give the k shortest distances in each row.
I based this approach on some code I found that works on a one-dimensional array:
# https://glowingpython.blogspot.co.uk/2012/04/k-nearest-neighbor-search.html
from numpy import random, argsort, sqrt
from matplotlib import pyplot as plt
def knn_search(x, D, K):
    """ find K nearest neighbours of x among the columns of D """
    ndata = D.shape[1]
    K = K if K < ndata else ndata
    # euclidean distances from the other points
    sqd = sqrt(((D - x[:, :ndata]) ** 2).sum(axis=0))
    idx = argsort(sqd) # sorting
    # return the indexes of K nearest neighbours
    return idx[:K]
# knn_search test
data = random.rand(2, 5) # random dataset
x = random.rand(2, 1) # query point
# performing the search
neig_idx = knn_search(x, data, 2)
figure = plt.figure()
plt.scatter(data[0,:], data[1,:])
plt.scatter(x[0], x[1], c='g')
plt.scatter(data[0, neig_idx], data[1, neig_idx], c='r', marker = 'o')
plt.show()
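Applied directly to the 2D case, the 1D pattern above needs a Python loop over the rows. Here is a minimal sketch of that loop. knn_mask_loop is a hypothetical helper name, and I assume the convention that the mask is False at the K nearest entries and True elsewhere. That convention matches numpy.ma, where True means "masked out":

```python
import numpy as np

def knn_mask_loop(position, K):
    """Hypothetical helper: row-by-row version of the 1D approach.

    Returns a boolean array that is False at the K smallest entries
    of each row and True everywhere else.
    """
    out = np.ones(position.shape, dtype=bool)
    for i, row in enumerate(position):
        # Same idea as knn_search: argsort one 1D array, keep the first K
        nearest = np.argsort(row)[:K]
        out[i, nearest] = False
    return out

position = np.array([[36, 63, 3, 78, 98],
                     [75, 86, 63, 61, 79],
                     [21, 12, 72, 27, 23],
                     [38, 16, 17, 88, 29],
                     [93, 37, 48, 88, 10]])
mask = knn_mask_loop(position, 3)
print(mask)
```

This works, but the loop runs in Python once per row, which is what a vectorised solution avoids.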
Here's one way -
N = 3 # number of points to be set as False per row
# Slice out the first N cols per row
k_idx = idx[:,:N]
# Initialize output array
out = np.ones(position.shape, dtype=bool)
# Index into output with k_idx as col indices to reset
out[np.arange(k_idx.shape[0])[:,None], k_idx] = 0
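For reference, here are the steps above assembled into a self-contained script. The example matrix from the question is hard-coded so the result is reproducible. Only the hard-coded input is my assumption; the steps themselves work on any 2D array:

```python
import numpy as np

position = np.array([[36, 63, 3, 78, 98],
                     [75, 86, 63, 61, 79],
                     [21, 12, 72, 27, 23],
                     [38, 16, 17, 88, 29],
                     [93, 37, 48, 88, 10]])
N = 3  # number of points to be set as False per row

# Column indices of the N smallest values in each row
idx = position.argsort(axis=1)
k_idx = idx[:, :N]

# Start with every entry True, then reset the N smallest per row to False
out = np.ones(position.shape, dtype=bool)
rows = np.arange(k_idx.shape[0])[:, None]  # shape (5, 1): one row index per row
out[rows, k_idx] = False

print(out)
```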
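The last step involves advanced indexing, which might be a big step if you are new to NumPy. Basically, we use k_idx to index into the columns and the range array np.arange(k_idx.shape[0])[:,None] to index into the rows. The row array has shape (rows, 1) and k_idx has shape (rows, N), so the two broadcast together and every column index in row i of k_idx is paired with row index i. More info on advanced indexing is in the NumPy indexing documentation.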
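To make the indexing step concrete, here is a small sketch using only the first two rows of k_idx. It shows how the (2, 1) row-index array broadcasts against the (2, 3) column-index array, and which (row, column) pairs are assigned as a result. np.broadcast_arrays is used here only to visualise the pairing; the assignment itself does not need it:

```python
import numpy as np

k_idx = np.array([[2, 0, 1],
                  [3, 2, 0]])
rows = np.arange(k_idx.shape[0])[:, None]  # [[0], [1]], shape (2, 1)

# Broadcasting stretches rows to k_idx's shape (2, 3):
# row_b is [[0, 0, 0], [1, 1, 1]]
row_b, col_b = np.broadcast_arrays(rows, k_idx)

# The fancy-index assignment touches exactly these (row, col) pairs
pairs = list(zip(row_b.ravel().tolist(), col_b.ravel().tolist()))
print(pairs)  # [(0, 2), (0, 0), (0, 1), (1, 3), (1, 2), (1, 0)]

# An explicit loop over the pairs gives the same result as the one-liner
out_loop = np.ones((2, 5), dtype=bool)
for r, c in pairs:
    out_loop[r, c] = False

out_fancy = np.ones((2, 5), dtype=bool)
out_fancy[rows, k_idx] = False
print((out_loop == out_fancy).all())  # True
```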
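We could improve performance by using np.argpartition instead of argsort. It doesn't fully sort each row. It only guarantees that the indices of the N smallest values occupy the first N positions, in no particular order, and that is all the mask needs. Like so -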
k_idx = np.argpartition(position, N)[:,:N]
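The sketch below checks, on the example matrix, that the argpartition variant selects the same columns as the argsort variant even though their order within each row may differ. Two caveats to keep in mind. First, argpartition raises an error if the kth argument is not a valid index, so N must be smaller than the number of columns. Second, if a row contains ties, the two variants may pick different columns among equal values. The hard-coded input is my assumption, used for reproducibility:

```python
import numpy as np

position = np.array([[36, 63, 3, 78, 98],
                     [75, 86, 63, 61, 79],
                     [21, 12, 72, 27, 23],
                     [38, 16, 17, 88, 29],
                     [93, 37, 48, 88, 10]])
N = 3  # must be smaller than position.shape[1] for kth=N to be valid

# argpartition: the N smallest land in the first N slots, unordered
k_part = np.argpartition(position, N, axis=1)[:, :N]
k_sort = position.argsort(axis=1)[:, :N]

# Order within a row may differ, but the sets of columns match (no ties here)
same_sets = all(set(a) == set(b) for a, b in zip(k_part.tolist(), k_sort.tolist()))
print(same_sets)

# Building the mask is unchanged: order inside k_part does not matter
out = np.ones(position.shape, dtype=bool)
out[np.arange(position.shape[0])[:, None], k_part] = False
print(out)
```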
Sample input and output for setting the lowest 3 elements per row to False -
In [227]: position
Out[227]:
array([[36, 63,  3, 78, 98],
       [75, 86, 63, 61, 79],
       [21, 12, 72, 27, 23],
       [38, 16, 17, 88, 29],
       [93, 37, 48, 88, 10]])
In [228]: out
Out[228]:
array([[False, False, False,  True,  True],
       [False,  True, False, False,  True],
       [False, False,  True,  True, False],
       [ True, False, False,  True, False],
       [ True, False, False,  True, False]], dtype=bool)
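To tie this back to the question, the sketch below applies the mask to the original array. It assumes the numpy.ma convention, where True means "masked out", so the False entries (the k shortest distances) stay visible. Plain boolean indexing with ~out gives the same values without a masked array. Note that the values come back in column order, not sorted by distance:

```python
import numpy as np

position = np.array([[36, 63, 3, 78, 98],
                     [75, 86, 63, 61, 79],
                     [21, 12, 72, 27, 23],
                     [38, 16, 17, 88, 29],
                     [93, 37, 48, 88, 10]])
N = 3

# Rebuild the mask exactly as in the answer
k_idx = position.argsort(axis=1)[:, :N]
out = np.ones(position.shape, dtype=bool)
out[np.arange(k_idx.shape[0])[:, None], k_idx] = False

# numpy.ma: True = hidden, so only the N smallest per row remain visible
masked = np.ma.masked_array(position, mask=out)
print(masked)

# Visible values per row (column order, not distance order)
per_row = [masked[i].compressed().tolist() for i in range(masked.shape[0])]
print(per_row)

# Same values via boolean indexing; every row keeps exactly N entries,
# so the flat result can be reshaped back to (rows, N)
k_vals = position[~out].reshape(position.shape[0], -1)
print(k_vals)
```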