I have an ndarray called labels with a shape of (6000, 8). This is 6000 one-hot encoded arrays with 8 categories. I want to search for labels that look like this:
[1,0,0,0,0,0,0,0]
and tried this:
np.where(labels==[1,0,0,0,0,0,0,0,0])
but this does not produce the expected result.
You need all() along the second axis (axis 1):
np.where((labels == [1,0,0,0,0,0,0,0]).all(1))
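As an aside, the list in the question's attempt appears to have nine entries while labels has eight columns, so the two cannot be compared row by row. Here is a minimal sketch of a length check before comparing. The helper matches_columns and the small stand-in array are hypothetical, made up for illustration and not part of the original question.

```python
import numpy as np

# Hypothetical stand-in for the real (6000, 8) array: four one-hot rows.
labels = np.eye(8, dtype=int)[[0, 3, 0, 7]]

target_ok = np.array([1, 0, 0, 0, 0, 0, 0, 0])       # 8 entries, one per column
target_bad = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0])   # 9 entries, as typed in the question

def matches_columns(arr, target):
    """Hypothetical helper: True if target has exactly one entry per column of arr."""
    return target.ndim == 1 and target.shape[0] == arr.shape[1]

ok = matches_columns(labels, target_ok)
bad = matches_columns(labels, target_bad)

# Only compare once the lengths agree.
rows = np.where((labels == target_ok).all(axis=1))[0]
```

With the lengths in agreement, the comparison is broadcast across every row and all(axis=1) reduces each row to a single boolean.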
See with this smaller example:
labels = np.array([[1,0,0,1,0,0,0,0],
                   [0,0,0,0,0,1,1,0],
                   [1,0,0,0,0,0,0,0],
                   [0,0,0,0,0,0,0,1]])
(labels == [1,0,0,0,0,0,0,0])
array([[ True, True, True, False, True, True, True, True],
[False, True, True, True, True, False, False, True],
[ True, True, True, True, True, True, True, True],
[False, True, True, True, True, True, True, False]])
Note that the above comparison simply returns an array of the same shape as labels, since the list is compared elementwise against each row of labels. You need to aggregate with all() to check whether all elements in a row are True:
(labels == [1,0,0,0,0,0,0,0]).all(1)
#array([False, False, True, False])
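To tie this back to np.where, here is a small sketch that turns the boolean mask into row indices on the same example. The variable names (target, mask, row_idx and so on) are just for illustration.

```python
import numpy as np

labels = np.array([[1, 0, 0, 1, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 1, 1, 0],
                   [1, 0, 0, 0, 0, 0, 0, 0],
                   [0, 0, 0, 0, 0, 0, 0, 1]])
target = np.array([1, 0, 0, 0, 0, 0, 0, 0])

# One boolean per row: True only where every column matches the target.
mask = (labels == target).all(axis=1)

# np.where with a single argument returns a tuple of index arrays,
# one per dimension; the mask is 1-D, so take the first entry.
row_idx = np.where(mask)[0]

# np.flatnonzero gives the same indices for a 1-D mask.
same_idx = np.flatnonzero(mask)

# Boolean indexing returns the matching rows themselves.
matching_rows = labels[mask]
```

Here row_idx contains only index 2, which agrees with the mask shown above. If you only need the rows and not their positions, labels[mask] skips np.where entirely.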