python numpy where-clause

Numpy 'where' on string


I would like to use the numpy.where function on a string array. However, I am unsuccessful in doing so. Can someone please help me figure this out?

For example, when I use numpy.where on the following example I get an unexpected result:

import numpy as np

A = ['apple', 'orange', 'apple', 'banana']

arr_index = np.where(A == 'apple',1,0)

I get the following:

>>> arr_index
array(0)
>>> print A[arr_index]
apple

However, I would like to know the indices in the string array A where the string 'apple' matches. In the array above, this happens at indices 0 and 2. However, np.where only returns 0 and not 2.

So, how do I make numpy.where work on strings? Thanks in advance.


Solution

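  • The comparison needs a NumPy array, not a Python list. With a list, A == 'apple' compares the whole list to the string and evaluates to a single False, so np.where(False, 1, 0) returns array(0). With an array, the comparison is element-wise and np.where(a == 'apple') returns the matching indices.

    Also index with arr_index, not array_index: print(a[arr_index]).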

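    import numpy as np

    # Convert the list to an ndarray so that == is applied element-wise
    a = np.array(['apple', 'orange', 'apple', 'banana'])
    arr_index = np.where(a == 'apple')  # one-argument form: tuple of index arrays

    print(arr_index)     # (array([0, 2]),)
    print(a[arr_index])  # ['apple' 'apple']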
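
If you want the 0/1 flags that the three-argument call in the question was aiming for, or a flat array of indices rather than a tuple, something like the following should work. This is only a sketch: the names fruits, list_cmp, mask, flags and idx are made up for illustration and are not from the original post.

```python
import numpy as np

fruits = ['apple', 'orange', 'apple', 'banana']  # plain Python list, as in the question

# list == str compares the whole list to the string: one bool, not element-wise
list_cmp = (fruits == 'apple')

a = np.array(fruits)
mask = (a == 'apple')         # element-wise boolean array

flags = np.where(mask, 1, 0)  # three-argument form: 1 where True, 0 elsewhere
idx = np.flatnonzero(mask)    # flat integer indices of the True entries

print(list_cmp)
print(mask)
print(flags)
print(idx)
```

np.flatnonzero(mask) gives the same indices as np.where(mask)[0] for a 1-D array, which can be handier when you just want a plain index array.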