Tags: python-3.x, numpy, machine-learning, mapping, one-hot-encoding

Mapping one-hot encoded target values to proper label names


I have a list of label names, which I enumerated to create a dictionary:

my_list = [b'airplane',
 b'automobile',
 b'bird',
 b'cat',
 b'deer',
 b'dog',
 b'frog',
 b'horse',
 b'ship',
 b'truck']

label_dict = dict(enumerate(my_list))


{0: b'airplane',
 1: b'automobile',
 2: b'bird',
 3: b'cat',
 4: b'deer',
 5: b'dog',
 6: b'frog',
 7: b'horse',
 8: b'ship',
 9: b'truck'}

Now I'm trying to cleanly map the dict values onto my target, which is in one-hot-encoded form.

y_test[0]

array([ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.])


y_test[0].map(label_dict) should return: 
'cat'

I was playing around with

(lambda key,value: value for y_test[0] == 1)

but couldn't come up with anything concrete.

Thank you.


Solution

  • Since we are working with a one-hot encoded array, argmax can be used to get the index of the single 1 in each row. Thus, using the list as input -

    [my_list[i] for i in y_test.argmax(1)]
    

    Or with np.take to have array output -

    np.take(my_list,y_test.argmax(1))
    

    To work with the dict, assuming sequential keys 0, 1, ..., we could use the following (in Python 3, dict.values() returns a view, so it is wrapped in list() first) -

    np.take(list(label_dict.values()), y_test.argmax(1))
    

    If the keys are not necessarily sequential but are sorted -

    np.take(list(label_dict.values()), np.searchsorted(list(label_dict.keys()), y_test.argmax(1)))
    

    Sample run -

    In [79]: my_list
    Out[79]: 
    ['airplane',
     'automobile',
     'bird',
     'cat',
     'deer',
     'dog',
     'frog',
     'horse',
     'ship',
     'truck']
    
    In [80]: y_test
    Out[80]: 
    array([[ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.],
           [ 0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
           [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.]])
    
    In [81]: [my_list[i] for i in y_test.argmax(1)]
    Out[81]: ['cat', 'automobile', 'ship']
    
    In [82]: np.take(my_list,y_test.argmax(1))
    Out[82]: 
    array(['cat', 'automobile', 'ship'], 
          dtype='|S10')
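
For Python 3, where the labels in the question are bytes objects, the argmax approach above might look like the following minimal sketch. The y_test array here is made-up sample data matching the sample run, and decoding the labels as UTF-8 is an assumption about how they were stored.

```python
import numpy as np

# Label names as bytes, as in the question.
my_list = [b'airplane', b'automobile', b'bird', b'cat', b'deer',
           b'dog', b'frog', b'horse', b'ship', b'truck']
label_dict = dict(enumerate(my_list))

# Made-up one-hot targets: the rows encode classes 3, 1 and 8.
y_test = np.zeros((3, 10))
y_test[[0, 1, 2], [3, 1, 8]] = 1.0

# Decode once so results are str ('cat') rather than bytes (b'cat').
# Assumption: the labels are UTF-8 encoded.
names = np.array([name.decode('utf-8') for name in my_list])

# argmax along axis 1 gives the column of the 1 in every row.
idx = y_test.argmax(axis=1)

# Indexing with the integer array is equivalent to np.take(names, idx).
labels = names[idx]
print(labels.tolist())

# Single row via the dict: convert the NumPy integer to a plain int key.
single = label_dict[int(y_test[0].argmax())].decode('utf-8')
print(single)
```

The dict lookup is handy for a single row, while indexing the array handles the whole batch in one vectorized step.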