python, scikit-learn, classification, random-forest, one-hot-encoding

Random Forest predicting no class when the target is one-hot encoded


I know that trees are sensitive to one-hot-encoded (OHE) targets; however, I want to understand why the model returns predictions like this:

array([[0, 0, 0, 0],
       [0, 0, 0, 0],
       ...,
       [0, 0, 0, 0],
       [0, 0, 0, 0],
       [0, 0, 1, 0],
       [0, 0, 0, 0]])

For most of the samples, it predicts no class at all. I will encode my targets as ordinal (since that is applicable here), but what if it were not? What should I do then? This is how the target looks before OHE:

array(['4 -8 weeks', '13 - 16 weeks', '17 - 20 weeks', ..., '9 - 12 weeks',
       '13 - 16 weeks'], dtype=object)

Full code:

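from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# `data` (features) and `Class` (string labels) are defined elsewhere
# One-hot encode the labels: one 0/1 column per class
mlb = LabelBinarizer()
b = mlb.fit_transform(Class)
list(mlb.classes_)

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(data, b, test_size=0.2, random_state=42)

# With a 2-D 0/1 target, the forest is fit as a multi-label (multi-output) classifier
classifier = RandomForestClassifier()

# Train the classifier
classifier.fit(X_train, y_train)

# Make predictions on the test set
y_pred = classifier.predict(X_test)

# For 0/1 indicator targets this is subset accuracy: a row counts only if every column matches
accuracy = accuracy_score(y_test, y_pred)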
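A self-contained reproduction of the symptom, assuming synthetic features from `make_classification` and four made-up week-range labels as a hypothetical stand-in for the question's `data` and `Class`. With a 2-D 0/1 target, `predict_proba` returns one two-column array per output column, and each column's prediction is just its own positive-class probability compared against 0.5, so a row comes out as all zeros whenever no class clears that bar.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelBinarizer

# Hypothetical stand-in for the question's `data` and `Class`:
# noisy synthetic features and four made-up week-range labels
X, y_int = make_classification(
    n_samples=1000, n_features=10, n_informative=4, n_classes=4,
    n_clusters_per_class=1, flip_y=0.2, random_state=0,
)
labels = np.array(["4 - 8 weeks", "9 - 12 weeks", "13 - 16 weeks", "17 - 20 weeks"])
Class = labels[y_int]

# One-hot encode as in the question: shape (n_samples, 4), one 0/1 column per class
lb = LabelBinarizer()
Y = lb.fit_transform(Class)

X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
clf = RandomForestClassifier(random_state=0).fit(X_train, y_train)
y_pred = clf.predict(X_test)

# With a 2-D target, predict_proba returns a list: one (n_samples, 2) array per column
proba = clf.predict_proba(X_test)
p_positive = np.column_stack([p[:, 1] for p in proba])

# Each column is decided on its own: 1 only if P(column == 1) > 0.5
assert np.array_equal(y_pred, (p_positive > 0.5).astype(int))

# Rows where no column clears 0.5 come out as all zeros
n_no_class = int((y_pred.sum(axis=1) == 0).sum())
print(f"{n_no_class} of {len(y_pred)} test rows have no predicted class")
```

Taking the column-wise `np.argmax(p_positive, axis=1)` instead of thresholding would always assign exactly one class per row.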

Solution

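  • When the targets are one-hot encoded, sklearn treats the problem as a multi-label one (each row could have any number of labels). As such, you get a predicted probability for each label, and those are independently thresholded at 0.5 to make the class predictions. With several classes, it is common for no class to reach a probability above 0.5, in which case every column is predicted as 0 and the row gets no class at all.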

    When the targets are ordinally encoded, sklearn treats the problem as a multiclass one (each row has exactly one class). Despite the numerical ordering, sklearn doesn't care (well, except in tiebreaking) and treats the classes as unordered. The predicted probabilities sum to 1, and the predicted class is the one with the largest probability.

    You don't need to encode the labels at all. sklearn encodes them internally for computational efficiency, so leaving the labels as strings is fine: the problem is treated as multiclass, and the class predictions come back as strings too (no decoding needed).
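A minimal sketch of the string-label approach, again on hypothetical synthetic data standing in for the question's `data` and `Class`. It checks that the multiclass probabilities sum to 1 per row and that `predict` returns the string label of the highest-probability class, so every test row gets exactly one class.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for the question's `data` and `Class`
X, y_int = make_classification(
    n_samples=1000, n_features=10, n_informative=4, n_classes=4,
    n_clusters_per_class=1, flip_y=0.2, random_state=0,
)
labels = np.array(["4 - 8 weeks", "9 - 12 weeks", "13 - 16 weeks", "17 - 20 weeks"])
Class = labels[y_int]

# No LabelBinarizer: pass the 1-D array of string labels directly
X_train, X_test, y_train, y_test = train_test_split(X, Class, test_size=0.2, random_state=42)
clf = RandomForestClassifier(random_state=0).fit(X_train, y_train)
y_pred = clf.predict(X_test)  # 1-D array of strings, exactly one class per row

# Multiclass: one column per class in clf.classes_, and each row sums to 1
proba = clf.predict_proba(X_test)
assert np.allclose(proba.sum(axis=1), 1.0)

# The predicted class is the highest-probability one (ties go to the first in classes_)
assert np.array_equal(y_pred, clf.classes_[np.argmax(proba, axis=1)])

print(clf.classes_)
print(y_pred[:5])
print("accuracy:", accuracy_score(y_test, y_pred))
```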