
Retrieving all NumPy array indices where a condition holds


Say I have a NumPy array a of non-repeating elements, np.array([1,3,5,2,4]), and I would like to retrieve the indices at which a contains the values [4,2]. Desired output: np.array([3,4]), since these are the indices of the requested elements.

So far, I've tried

    >>> np.all(np.array([[1,2,3,4]]).transpose(), axis=1, where=lambda x: x in [1,2])
    array([ True,  True,  True,  True])

But this result does not make sense to me: the elements at indices 2 and 3 should be False.

Perhaps I need to search for one element at a time, but I'd prefer this operation to be vectorized and fast.


Solution

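  • I'd say the function you're looking for is numpy.isin(). It returns a boolean mask that is True wherever an element of the array appears in the given values, and np.where() converts that mask into indices.

    Your attempt returns all True because np.all() only checks whether elements are nonzero. Its where argument expects a boolean mask, not a function, so the lambda is never applied to the elements.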

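    import numpy as np

    a = np.array([1, 3, 5, 2, 4])
    # np.isin marks the requested values; np.where(...)[0] extracts their indices
    print(np.where(np.isin(a, [4, 2]))[0])

    This prints [3 4], the indices of the requested elements. The [0] is needed because np.where returns a tuple with one index array per dimension; for a 2-D input such as np.array([[1,2,3,4]]) you would get separate row and column index arrays.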
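Because np.isin only produces a mask, the indices always come back in ascending position order, so querying [4, 2] and [2, 4] gives the same [3 4]. If you need one index per query value, in query order, one vectorized alternative is to sort a once with np.argsort and look the values up with np.searchsorted. This is a minimal sketch, assuming (as in the question) that a is 1-D with no repeated elements; the helper name indices_of is hypothetical, and it raises ValueError if a requested value is missing.

```python
import numpy as np


def indices_of(a, values):
    """Hypothetical helper: index in `a` of each entry of `values`, in query order.

    Assumes `a` is 1-D with no repeated elements.
    Raises ValueError if any requested value is not in `a`.
    """
    a = np.asarray(a)
    values = np.asarray(values)
    order = np.argsort(a)  # order[k] = original index of the k-th smallest element
    # Position of each value within the sorted view a[order]
    pos = np.searchsorted(a, values, sorter=order)
    # Values larger than max(a) get pos == len(a); clip so indexing stays valid
    pos = np.clip(pos, 0, len(a) - 1)
    idx = order[pos]  # map sorted positions back to original indices
    # searchsorted only gives insertion points, so confirm each value was found
    if not np.array_equal(a[idx], values):
        raise ValueError("some requested values are not present in a")
    return idx


a = np.array([1, 3, 5, 2, 4])
print(indices_of(a, [4, 2]))  # [4 3]
print(indices_of(a, [2, 4]))  # [3 4]
```

A dict built with {v: i for i, v in enumerate(a)} would also work, but it loops in Python; this version stays in NumPy, costing one O(n log n) sort plus O(m log n) for the lookups.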