Tags: python, list, numpy, numpy-ndarray, indices

Finding the index of elements in an array/list based on another list or array


I have two lists/arrays, and I want to find the indices of the elements in one list whose values also appear in the other list. Here's an example:

 list_A = [1,7,9,7,11,1,2,3,6,4,9,0,1]
 list_B = [9,1,7] 
 #output required : [0,1,2,3,5,10,12]

Is there a way to do this, ideally using numpy?


Solution

  • Using a list-comprehension and enumerate():

    >>> list_A = [1,7,9,7,11,1,2,3,6,4,9,0,1]
    >>> list_B = [9,1,7]
    >>> [i for i, x in enumerate(list_A) if x in list_B]
    [0, 1, 2, 3, 5, 10, 12]
    

    Using numpy:

    >>> import numpy as np
    >>> np.where(np.isin(list_A, list_B))
    (array([ 0,  1,  2,  3,  5, 10, 12], dtype=int64),)
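
    Note that np.where returns a tuple holding one index array per dimension. If a flat array or a plain list of indices is preferred, a minimal sketch (assuming the same list_A and list_B as above; the names mask, indices and index_list are only illustrative) could use np.flatnonzero on the boolean mask:

```python
import numpy as np

list_A = [1, 7, 9, 7, 11, 1, 2, 3, 6, 4, 9, 0, 1]
list_B = [9, 1, 7]

# np.isin builds a boolean mask marking which elements of list_A occur in list_B
mask = np.isin(list_A, list_B)

# np.flatnonzero returns the positions of the True entries as a flat 1-D array
indices = np.flatnonzero(mask)

# .tolist() converts the index array to a plain Python list of ints
index_list = indices.tolist()
print(index_list)  # [0, 1, 2, 3, 5, 10, 12]
```

    For a 1-D input this gives the same indices as np.where(mask)[0], just without the surrounding tuple.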
    

    In addition, as @Chris_Rands points out, we could also convert list_B to a set first, since membership testing (in) is O(1) on average for sets, as opposed to O(n) for lists.

    Time comparison:

    import random
    import numpy as np
    import timeit
    
    list_A = [random.randint(0,100000) for _ in range(100000)]
    list_B = [random.randint(0,100000) for _ in range(50000)]
    
    array_A = np.array(list_A)
    array_B = np.array(list_B)
    
    def lists_enumerate(list_A, list_B):
        # Note: set(list_B) is rebuilt for every element of list_A, which is what makes this slow
        return [i for i, x in enumerate(list_A) if x in set(list_B)]
    
    def listB_to_set_enumerate(list_A, list_B):
        set_B = set(list_B)
        return [i for i, x in enumerate(list_A) if x in set_B]
    
    def numpy(array_A, array_B):
        return np.where(np.isin(array_A, array_B))
    

    Results:

    >>> %timeit lists_enumerate(list_A, list_B)
    48.8 s ± 638 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    >>> %timeit listB_to_set_enumerate(list_A, list_B)
    11.2 ms ± 856 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    >>> %timeit numpy(array_A, array_B)
    23.3 ms ± 167 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
    

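    So for larger lists, the best approach is either to convert list_B to a set once before the enumerate loop (the fastest in this test, at about 11 ms) or to use numpy (about 23 ms). lists_enumerate is slower by several orders of magnitude (about 49 s).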
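
    The %timeit lines above only work inside IPython. As a rough sketch of how the list-versus-set comparison could be reproduced with the standard timeit module alone, the script below uses much smaller inputs and a fixed seed. Both are assumptions made here so the slow variant finishes quickly, and the function names list_membership and set_membership are hypothetical:

```python
import random
import timeit

random.seed(0)  # fixed seed so the run is repeatable (an assumption, not in the original)

# Much smaller inputs than in the answer above, so the slow variant finishes quickly
list_A = [random.randint(0, 1000) for _ in range(2000)]
list_B = [random.randint(0, 1000) for _ in range(1000)]

def list_membership(list_A, list_B):
    # "x in list_B" scans the list: O(len(list_B)) per lookup
    return [i for i, x in enumerate(list_A) if x in list_B]

def set_membership(list_A, list_B):
    # build the set once; each lookup is then O(1) on average
    set_B = set(list_B)
    return [i for i, x in enumerate(list_A) if x in set_B]

# Both variants must agree before their timings are worth comparing
same_result = list_membership(list_A, list_B) == set_membership(list_A, list_B)

# Total time in seconds for 5 calls of each variant
t_list = timeit.timeit(lambda: list_membership(list_A, list_B), number=5)
t_set = timeit.timeit(lambda: set_membership(list_A, list_B), number=5)
print(f"list: {t_list:.4f} s, set: {t_set:.4f} s, same result: {same_result}")
```

    The absolute numbers depend on the machine and the input sizes, so they are not directly comparable to the results quoted above.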