python arrays numpy vector binary-search

Numpy searchsorted on array of vectors in one dimension


I've been Googling this for a while now. Even though I think it's a common problem, I can't find a solution anywhere on SO.

Say I have an array of 3D vectors (x, y, z), like this:

import numpy as np

arr = np.array(
    [(1, 2, 3), (3, 1, 2.5), (5, 3, 1), (0, -1, 2)],
    dtype=[('x', np.float64), ('y', np.float64), ('z', np.float64)]
)
print(np.sort(arr, order='z'))

This prints:

[(5.,  3., 1. ) (0., -1., 2. ) (3.,  1., 2.5) (1.,  2., 3. )]

I would now like to search this sorted array by the 'z' field only. A binary search would be extremely efficient, but searchsorted only works on 1D arrays, and there is no key function you can apply to each element (basically np.dot with a (0, 0, 1) vector).

Is there a way to do this in NumPy, or do I need to implement binary search myself? (That is still an option, since it's very fast even in vanilla Python.)

For example, for a search value of 2.5 I'd expect index 2. For 2.4 I'd still expect 2, and for 2.6 I'd expect 3. Either the index or the vector itself (like (3, 1, 2.5)) would be fine.


Solution

  • If you use a plain 2D array instead of an array of tuples, you can slice out the column you want to search:

    import numpy as np
    
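    arr = np.random.rand(10, 3)
    print("unsorted:")
    print(arr)
    
    # Sort the rows by the third column (z)
    sort_indices = np.argsort(arr[:, 2])
    arr_sorted = arr[sort_indices]
    print("sorted:")
    print(arr_sorted)
    
    # Binary-search the sorted z column for the z value of original row 5
    search_result = np.searchsorted(arr_sorted[:, 2], arr[5, 2])
    print(search_result)  # 2 for the sample output below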
    

    Output:

    unsorted:
    [[0.71815835 0.89099775 0.51398111]
     [0.56393906 0.26684628 0.33065586]
     [0.38920018 0.0485013  0.70958811]
     [0.3771277  0.95567051 0.18514701]
     [0.59715961 0.19092995 0.09340359]
     [0.09575273 0.56697649 0.10120321]
     [0.63226061 0.95258914 0.59669295]
     [0.1714133  0.7406211  0.23079041]
     [0.33512727 0.23244954 0.08735154]
     [0.50582011 0.97186928 0.15525005]]
    
    sorted:
    [[0.33512727 0.23244954 0.08735154]
     [0.59715961 0.19092995 0.09340359]
     [0.09575273 0.56697649 0.10120321]
     [0.50582011 0.97186928 0.15525005]
     [0.3771277  0.95567051 0.18514701]
     [0.1714133  0.7406211  0.23079041]
     [0.56393906 0.26684628 0.33065586]
     [0.71815835 0.89099775 0.51398111]
     [0.63226061 0.95258914 0.59669295]
     [0.38920018 0.0485013  0.70958811]]
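
The same idea should also carry over to the structured array from the question, because indexing a structured array by field name (arr['z']) yields an ordinary 1D array that searchsorted accepts. The following is a minimal sketch under that assumption, not a definitive implementation. The names z_sorted, queries, order and indices_via_sorter are illustrative and do not appear in the original post, and np.float64 is used because the np.float alias was removed in recent NumPy releases.

```python
import numpy as np

# Structured array from the question (np.float64 instead of the removed np.float alias).
arr = np.array(
    [(1, 2, 3), (3, 1, 2.5), (5, 3, 1), (0, -1, 2)],
    dtype=[('x', np.float64), ('y', np.float64), ('z', np.float64)],
)

# Sort once by 'z'; arr_sorted['z'] is then a plain 1D float array.
arr_sorted = np.sort(arr, order='z')
z_sorted = arr_sorted['z']

# side='left' (the default) returns the first index whose z is >= the query.
queries = np.array([2.5, 2.4, 2.6])
indices = np.searchsorted(z_sorted, queries)
print(indices)  # [2 2 3]

# Look up the matching vector, guarding against a query larger than every z.
i = indices[0]
if i < len(arr_sorted):
    print(arr_sorted[i])

# Alternative without a sorted copy: pass the argsort permutation as `sorter`.
# The returned indices still refer to positions in the sorted order.
order = np.argsort(arr['z'])
indices_via_sorter = np.searchsorted(arr['z'], queries, sorter=order)
```

With the sorter variant, order[k] maps a returned position k back to a row of the unsorted array, so arr[order[indices_via_sorter[0]]] is the vector with z = 2.5. A returned index equal to len(arr) means the query is larger than every z value, hence the bounds check before indexing.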