Tags: python, numpy, one-hot-encoding

Search for a one-hot encoded label in ndarray


I have an ndarray called labels with a shape of (6000, 8). It holds 6000 one-hot encoded rows with 8 categories. I want to search for the rows that look like this:

[1,0,0,0,0,0,0,0]

and tried to do it like this:

np.where(labels == [1,0,0,0,0,0,0,0])

but this does not give me the indices of the matching rows.
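To see why the attempt misbehaves, here is a small sketch. The four-row array is an illustrative stand-in for the real (6000, 8) labels. Passing the element-wise comparison straight to np.where returns a tuple of row indices and column indices for every matching *element*, not the indices of matching rows.

```python
import numpy as np

# Small illustrative stand-in for the (6000, 8) labels array.
labels = np.array([[1, 0, 0, 1, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 1, 1, 0],
                   [1, 0, 0, 0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 1]])
target = np.array([1, 0, 0, 0, 0, 0, 0, 0])

# Element-wise comparison: target is broadcast against every row,
# producing a (4, 8) boolean array rather than one value per row.
mask = labels == target

# np.where on a 2-D mask returns one index array per dimension:
# the row and column positions of every True element.
rows, cols = np.where(mask)
print(mask.shape)  # (4, 8)
print(len(rows))   # 26: matching elements, not matching rows
```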


Solution

  • You need to reduce with all() along the second axis (axis 1):

    np.where((labels == [1,0,0,0,0,0,0,0]).all(1))
    

    Consider this smaller example:

    import numpy as np

    labels = np.array([[1,0,0,1,0,0,0,0], 
                       [0,0,0,0,0,1,1,0], 
                       [1,0,0,0,0,0,0,0], 
                       [0,0,0,0,0,0,0,1]])
    
    (labels == [1,0,0,0,0,0,0,0])
    
    array([[ True,  True,  True, False,  True,  True,  True,  True],
           [False,  True,  True,  True,  True, False, False,  True],
           [ True,  True,  True,  True,  True,  True,  True,  True],
           [False,  True,  True,  True,  True,  True,  True, False]])
    

    Note that the comparison above simply returns a boolean array with the same shape as labels: the pattern is broadcast against every row of labels and compared element-wise. You need to aggregate with all along axis 1 to check whether every element in a row is True:

    (labels == [1,0,0,0,0,0,0,0]).all(1)
    # array([False, False,  True, False])
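Putting it together, here is a self-contained sketch. Because np.where on a 1-D mask still returns a tuple, [0] extracts the array of row indices (np.flatnonzero is an equivalent shortcut). The helper name find_rows is hypothetical and only used for this illustration; its shape check guards against a pattern with the wrong number of categories. If every row is guaranteed to be strictly one-hot, comparing labels.argmax(axis=1) against the category index is a cheaper alternative. Note that this swapped-in shortcut assumes exactly one 1 per row and wrongly matches a row such as [1,0,0,1,0,0,0,0].

```python
import numpy as np


def find_rows(labels: np.ndarray, pattern) -> np.ndarray:
    """Return the indices of rows in `labels` that equal `pattern` exactly.

    `find_rows` is a hypothetical helper name used for illustration.
    """
    pattern = np.asarray(pattern)
    if pattern.shape != labels.shape[1:]:
        raise ValueError(
            f"pattern shape {pattern.shape} does not match row shape {labels.shape[1:]}"
        )
    # Compare element-wise, then require every element in a row to match.
    row_matches = (labels == pattern).all(axis=1)
    # np.where returns a tuple (one array per dimension); take the first.
    return np.where(row_matches)[0]


labels = np.array([[1, 0, 0, 1, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 1, 1, 0],
                   [1, 0, 0, 0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 1]])

print(find_rows(labels, [1, 0, 0, 0, 0, 0, 0, 0]))  # [2]

# Shortcut that assumes every row is strictly one-hot (exactly one 1):
# argmax gives the category index of each row. Row 0 above is not
# one-hot, so it is (wrongly) reported as category 0 too.
print(np.flatnonzero(labels.argmax(axis=1) == 0))  # [0 2]
```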