Tags: python, numpy, multidimensional-array, nearest-neighbor, masked-array

Generate mask array with lowest N valued positions reset per row


Given a 2D array of distances, use argsort to generate an index array in which the first element of each row is the index of the lowest value in that row. Then use indexing to select only the first K columns, where K = 3 for example.

>>> import numpy as np
>>> position = np.random.randint(100, size=(5, 5))  # random values; one sample run shown
>>> position
array([[36, 63,  3, 78, 98],
       [75, 86, 63, 61, 79],
       [21, 12, 72, 27, 23],
       [38, 16, 17, 88, 29],
       [93, 37, 48, 88, 10]])
>>> idx = position.argsort()  # sorts along the last axis, i.e. within each row
>>> idx
array([[2, 0, 1, 3, 4],
       [3, 2, 0, 4, 1],
       [1, 0, 4, 3, 2],
       [1, 2, 4, 0, 3],
       [4, 1, 2, 3, 0]])
>>> idx[:, 0:3]
array([[2, 0, 1],
       [3, 2, 0],
       [1, 0, 4],
       [1, 2, 4],
       [4, 1, 2]])
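The sliced index array can be mapped back to the distances it selects with np.take_along_axis (available since NumPy 1.15). This step is not part of the original question; it is just a convenient way to check what the indices mean. A minimal sketch, using the sample matrix above as fixed data so the result is reproducible:

```python
import numpy as np

# Sample distance matrix copied from the run above (fixed, so results are reproducible)
position = np.array([[36, 63,  3, 78, 98],
                     [75, 86, 63, 61, 79],
                     [21, 12, 72, 27, 23],
                     [38, 16, 17, 88, 29],
                     [93, 37, 48, 88, 10]])

K = 3
idx = position.argsort(axis=1)  # column indices, ordered by value within each row
k_idx = idx[:, :K]              # indices of the K smallest values per row

# Gather the values at those column indices, row by row
k_values = np.take_along_axis(position, k_idx, axis=1)
print(k_values)
```

Each row of k_values holds that row's three smallest distances in ascending order, e.g. [3, 36, 63] for the first row.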

What I would like to do next is create a mask array which, when applied to the original position array, keeps only the entries at the indices that yield the K shortest distances in each row.

I based this approach on some code I found that works on a one-dimensional array of distances (a single query point):

# https://glowingpython.blogspot.co.uk/2012/04/k-nearest-neighbor-search.html

from numpy import random, argsort, sqrt
from matplotlib import pyplot as plt    

def knn_search(x, D, K):
    """ find K nearest neighbours of data among D """
    ndata = D.shape[1]
    K = K if K < ndata else ndata
    # euclidean distances from the other points
    sqd = sqrt(((D - x[:, :ndata]) ** 2).sum(axis=0))
    idx = argsort(sqd)  # sorting
    # return the indexes of K nearest neighbours
    return idx[:K]

# knn_search test
data = random.rand(2, 5)  # random dataset
x = random.rand(2, 1)  # query point

# performing the search
neig_idx = knn_search(x, data, 2)

figure = plt.figure()
plt.scatter(data[0,:], data[1,:])
plt.scatter(x[0], x[1], c='g')
plt.scatter(data[0, neig_idx], data[1, neig_idx], c='r', marker = 'o')
plt.show()
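The blog function handles one query point at a time, which is why its distance array is one-dimensional. As a sketch only, the same idea can be vectorized over several query points with broadcasting, which yields exactly the kind of 2D distance matrix the question starts from. knn_search_batch is a hypothetical name (not from the blog), and it keeps the blog's layout of one point per column.

```python
import numpy as np

def knn_search_batch(X, D, K):
    """Hypothetical batch variant of knn_search (not from the blog post).

    X: (n_dims, n_queries) query points, one per column (same layout as the blog)
    D: (n_dims, n_data) data points, one per column
    Returns (idx, dist): idx is (n_queries, K) column indices into D,
    dist is the (n_queries, n_data) Euclidean distance matrix.
    """
    n_data = D.shape[1]
    K = min(K, n_data)
    # Broadcast (n_dims, n_queries, 1) against (n_dims, 1, n_data)
    diff = X[:, :, None] - D[:, None, :]
    # Sum over the dimension axis: one row of distances per query point
    dist = np.sqrt((diff ** 2).sum(axis=0))
    # Row-wise argsort, then keep the first K columns (the step from the question)
    idx = dist.argsort(axis=1)[:, :K]
    return idx, dist

rng = np.random.default_rng(0)
data = rng.random((2, 5))     # 5 data points in 2D, one per column
queries = rng.random((2, 3))  # 3 query points in 2D, one per column
neig_idx, dist = knn_search_batch(queries, data, 2)
print(neig_idx.shape)  # (3, 2)
print(dist.shape)      # (3, 5)
```

Row q of neig_idx should match what the original knn_search returns for the single query point queries[:, [q]].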

Solution

  • Here's one way:

    N = 3 # number of points to be set as False per row
    
    # Slice out the first N cols per row
    k_idx = idx[:,:N]
    
    # Initialize output array
    out = np.ones(position.shape, dtype=bool)
    
    # Index into output with k_idx as col indices to reset
    out[np.arange(k_idx.shape[0])[:,None], k_idx] = 0
    

    The last step uses advanced indexing, which may take some getting used to if you are new to NumPy. Here k_idx supplies the column indices, while the column vector np.arange(k_idx.shape[0])[:,None], of shape (rows, 1), supplies the row indices; the two broadcast together into (row, column) index pairs, so row i is set to False at every column listed in k_idx[i]. More info on advanced indexing.

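    We could improve performance by using np.argpartition instead of argsort, since we only need the N smallest entries per row rather than a full sort. The first N columns it returns are not sorted among themselves, which does not matter for building the mask: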

    k_idx = np.argpartition(position, N)[:,:N]
    

    Sample input and output for setting the lowest 3 elements per row to False:

    In [227]: position
    Out[227]: 
    array([[36, 63,  3, 78, 98],
           [75, 86, 63, 61, 79],
           [21, 12, 72, 27, 23],
           [38, 16, 17, 88, 29],
           [93, 37, 48, 88, 10]])
    
    In [228]: out
    Out[228]: 
    array([[False, False, False,  True,  True],
           [False,  True, False, False,  True],
           [False, False,  True,  True, False],
           [ True, False, False,  True, False],
           [ True, False, False,  True, False]], dtype=bool)
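As a follow-up, here is a minimal sketch of how the finished mask might be built and used on the sample matrix above. It checks that the argpartition variant gives the same mask as the argsort one, shows np.put_along_axis (NumPy 1.15+) as an alternative that pairs rows and columns itself (my suggestion, not part of the answer), and then applies the mask. The helper mask_lowest is a hypothetical name. Note that in numpy.ma a True mask entry means "masked out", so out hides everything except the N smallest entries of each row.

```python
import numpy as np

# Sample matrix from the question
position = np.array([[36, 63,  3, 78, 98],
                     [75, 86, 63, 61, 79],
                     [21, 12, 72, 27, 23],
                     [38, 16, 17, 88, 29],
                     [93, 37, 48, 88, 10]])
N = 3

def mask_lowest(k_idx, shape):
    """Hypothetical helper: True everywhere except at the per-row column indices in k_idx."""
    out = np.ones(shape, dtype=bool)
    out[np.arange(k_idx.shape[0])[:, None], k_idx] = False
    return out

# argsort and argpartition may order the first N columns differently,
# but (with no ties) they select the same cells, so the masks agree
out = mask_lowest(position.argsort(axis=1)[:, :N], position.shape)
k_part = np.argpartition(position, N)[:, :N]
assert np.array_equal(out, mask_lowest(k_part, position.shape))

# Same mask via np.put_along_axis, which builds the row indices internally
out_pa = np.ones(position.shape, dtype=bool)
np.put_along_axis(out_pa, k_part, False, axis=1)
assert np.array_equal(out, out_pa)

# Apply the mask: in numpy.ma, True means "masked out"
masked = np.ma.masked_array(position, mask=out)
print(masked)

# Or extract the kept values directly: exactly N per row, in column order
kept = position[~out].reshape(-1, N)
print(kept)
```

Each row of kept holds that row's N smallest distances in their original column order (e.g. [36, 63, 3] for the first row); sort along axis=1 if ascending order is needed. If a row has ties at the N-th smallest value, argsort and argpartition may pick different tied cells.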