python machine-learning classification one-hot-encoding multilabel-classification

One-Hot Encoding of label not needed?


I am trying to understand a code block from a guided tutorial for the classic Iris Classification problem.

The code block for the final model is given as follows:

from sklearn.svm import SVC

# Fit a support vector classifier and predict on the validation set
chosen_model = SVC(gamma='auto')
chosen_model.fit(X_train, Y_train)
predictions = chosen_model.predict(X_valid)

In this image you can see the data types present in X_train and Y_train. Both are NumPy arrays, and Y_train contains the Iris species as strings.

My question is simple: how come the model works even though I haven't One-Hot Encoded Y_train into different binary columns? My understanding from other tutorials is that for multi-class classification I need to first do one-hot encoding.

The code is working fine, but I want to grasp when I need to One-Hot Encode and when it's not needed. Thank you!


Solution

  • I think you might be confusing multiclass classification (your case) with multioutput classification.

    In a multiclass classification problem, the output is a single target column, and you train the model to choose among the classes in that column. You would only need to split the target into separate columns if you had to predict several labels per sample. That is not the case here: you want exactly one class per sample.

    So for multiclass classification there is no need to one-hot encode the target, since you only want a single target column (which SVC also accepts as categorical values). What you do have to encode, using OneHotEncoder or some other encoder, are the categorical input features, which have to be numeric.

    Also, SVC can deal with categorical targets, since it label-encodes them internally:

    from sklearn.datasets import load_iris
    from sklearn.svm import SVC
    from sklearn.model_selection import train_test_split
    
    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    y_train_categorical = load_iris()['target_names'][y_train]
    # array(['setosa', 'setosa', 'versicolor',...
    
    sv = SVC()
    sv.fit(X_train, y_train_categorical)
    sv.classes_
    # array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
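
To illustrate the other half of the point above (encode categorical input features, leave the target alone), here is a minimal sketch. The `petal_size` feature is made up for illustration and is not part of the Iris dataset, and the 3.0 cm threshold is an arbitrary assumption. Only that input column goes through OneHotEncoder, while the target stays a single column of strings.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import OneHotEncoder
from sklearn.svm import SVC

iris = load_iris()

# Hypothetical categorical INPUT feature, derived from petal length
# purely for illustration (not part of the real dataset).
petal_size = np.where(iris.data[:, 2] > 3.0, "large", "small").reshape(-1, 1)

# Input features must be numeric, so this column is one-hot encoded.
encoder = OneHotEncoder()
petal_size_encoded = encoder.fit_transform(petal_size).toarray()

# Four original numeric columns plus two one-hot columns.
X = np.hstack([iris.data, petal_size_encoded])

# The TARGET stays as one column of string labels, with no encoding.
y = iris.target_names[iris.target]

model = SVC(gamma="auto")
model.fit(X, y)

# Predictions come back as the original string labels.
predictions = model.predict(X[:3])
print(predictions)
print(model.classes_)
```

Note that the predictions are returned as species names rather than integer codes, so there is nothing to decode afterwards.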