One-hot encoding of multi-label images in Keras


I am using the PASCAL VOC 2012 dataset for image classification. Some images have multiple labels, whereas others have a single label, as shown below.

    0  2007_000027.jpg               {'person'}
    1  2007_000032.jpg  {'aeroplane', 'person'}
    2  2007_000033.jpg            {'aeroplane'}
    3  2007_000039.jpg            {'tvmonitor'}
    4  2007_000042.jpg                {'train'}

I want to one-hot encode these labels to train the model. However, I can't use keras.utils.to_categorical because the labels are not integers, and pandas.get_dummies does not give the result I expected: it treats each unique combination of labels as a separate category, as shown below.

 {'aeroplane', 'bus', 'car'}  {'aeroplane', 'bus'}  {'tvmonitor', 'sofa'}  {'tvmonitor'} ...

What is the best way to one-hot encode these labels, given that the number of labels per image is not fixed?


Solution

  • The MultiLabelBinarizer class from scikit-learn one-hot encodes multilabel sets such as those in column b:

    print (df)
                     a                        b
    0  2007_000027.jpg               {'person'}
    1  2007_000032.jpg  {'aeroplane', 'person'}
    2  2007_000033.jpg            {'aeroplane'}
    3  2007_000039.jpg            {'tvmonitor'}
    4  2007_000042.jpg                {'train'}
    

    import pandas as pd
    from sklearn.preprocessing import MultiLabelBinarizer
    
    mlb = MultiLabelBinarizer()
    # One indicator column per distinct label; column names come from mlb.classes_
    df = pd.DataFrame(mlb.fit_transform(df['b']),columns=mlb.classes_)
    print (df)
       aeroplane  person  train  tvmonitor
    0          0       1      0          0
    1          1       1      0          0
    2          1       0      0          0
    3          0       0      0          1
    4          0       0      1          0
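
The snippet above replaces df, so the file names in column a are lost. As a minimal, self-contained sketch (the sample frame is rebuilt by hand from the question, and the names onehot, encoded and decoded are only illustrative), the indicator columns can be kept next to the file names and mapped back to labels as a sanity check:

```python
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer

# Sample data copied from the question; columns 'a' and 'b' follow the answer.
df = pd.DataFrame({
    "a": ["2007_000027.jpg", "2007_000032.jpg", "2007_000033.jpg",
          "2007_000039.jpg", "2007_000042.jpg"],
    "b": [{"person"}, {"aeroplane", "person"}, {"aeroplane"},
          {"tvmonitor"}, {"train"}],
})

mlb = MultiLabelBinarizer()

# One 0/1 column per distinct label, aligned with the original row index.
onehot = pd.DataFrame(mlb.fit_transform(df["b"]),
                      columns=mlb.classes_, index=df.index)

# Keep the file names alongside their indicator columns.
encoded = pd.concat([df[["a"]], onehot], axis=1)

# Map each indicator row back to a tuple of label names.
decoded = mlb.inverse_transform(onehot.to_numpy())
```

Here onehot.to_numpy() is the kind of multi-hot target array that could be passed to a model, and decoded should contain the same labels as column b.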
    

    Alternatively, use Series.str.join with Series.str.get_dummies, although this is likely to be slower on a large DataFrame:

    df = df['b'].str.join('|').str.get_dummies()
    print (df)
    
       aeroplane  person  train  tvmonitor
    0          0       1      0          0
    1          1       1      0          0
    2          1       0      0          0
    3          0       0      0          1
    4          0       0      1          0
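
Both approaches above derive the columns from whatever labels happen to appear in the data, so a small subset may yield fewer columns than the full dataset. A possible sketch, assuming you want a fixed column order across training and validation splits, passes the class list explicitly through the classes parameter of MultiLabelBinarizer. The VOC_CLASSES list and the tiny train_labels and val_labels samples are illustrative, not taken from the question:

```python
from sklearn.preprocessing import MultiLabelBinarizer

# The 20 PASCAL VOC object classes, listed here by hand (illustrative constant).
VOC_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
    "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]

# Fixing the classes keeps the column order identical for every split.
mlb = MultiLabelBinarizer(classes=VOC_CLASSES)

# Hypothetical label sets standing in for a train and a validation split.
train_labels = [{"person"}, {"aeroplane", "person"}]
val_labels = [{"train"}, {"tvmonitor"}]

y_train = mlb.fit_transform(train_labels)  # shape (2, 20)
y_val = mlb.transform(val_labels)          # shape (2, 20), same column order
```

With this setup every encoded row has 20 entries regardless of which labels occur in a given split, which keeps the target width stable for a model with 20 output units.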