Tags: scikit-learn, random-forest, missing-data, decision-tree

Why does RandomForestClassifier not support missing values while DecisionTreeClassifier does in scikit-learn 1.3?


In the latest scikit-learn release (1.3), it was announced that DecisionTreeClassifier now supports missing values. The implementation evaluates splits with missing values going either to the left or right nodes (see release highlights).
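To make the "missing values go either left or right" idea concrete, here is a rough pure-Python sketch of how such a split search could look. This is not scikit-learn's actual implementation; the helper names `gini` and `best_split_with_missing` are hypothetical, and the sketch assumes a single numeric feature with Gini impurity as the criterion.

```python
import math


def gini(labels):
    """Gini impurity of a list of class labels (0.0 for an empty list)."""
    n = len(labels)
    if n == 0:
        return 0.0
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())


def best_split_with_missing(values, labels):
    """Hypothetical helper: try every threshold twice, once with the
    missing samples sent left and once with them sent right, and keep
    the combination with the lowest weighted impurity."""
    present = [(v, l) for v, l in zip(values, labels) if not math.isnan(v)]
    missing = [l for v, l in zip(values, labels) if math.isnan(v)]

    # Candidate thresholds: midpoints between consecutive distinct values.
    distinct = sorted({v for v, _ in present})
    thresholds = [(a + b) / 2 for a, b in zip(distinct, distinct[1:])]

    n = len(labels)
    best = None
    for t in thresholds:
        left = [l for v, l in present if v <= t]
        right = [l for v, l in present if v > t]
        for missing_go_left in (True, False):
            l_side = left + missing if missing_go_left else left
            r_side = right if missing_go_left else right + missing
            score = (len(l_side) * gini(l_side) + len(r_side) * gini(r_side)) / n
            if best is None or score < best[0]:
                best = (score, t, missing_go_left)
    return best


# Same toy data as in the question below: one feature, one NaN.
score, threshold, missing_go_left = best_split_with_missing(
    [0.0, 1.0, 6.0, float("nan")], [0, 0, 1, 1]
)
print(score, threshold, missing_go_left)
```

On this toy data the sketch picks the threshold 3.5 and sends the NaN sample to the right, which separates the two classes perfectly.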

However, when I tried using RandomForestClassifier, which is an ensemble of DecisionTreeClassifiers, it appears that it doesn't support missing values in the same way. I assumed that since RandomForestClassifier is built on DecisionTreeClassifier, it would also support missing values.

Here's a simple snippet I used for testing:

import numpy as np
from sklearn.ensemble import RandomForestClassifier

X = np.array([0, 1, 6, np.nan]).reshape(-1, 1)
y = [0, 0, 1, 1]

forest = RandomForestClassifier(random_state=0).fit(X, y)
predictions = forest.predict(X)

This throws the following error related to the presence of missing values:

ValueError: Input X contains NaN. RandomForestClassifier does not accept missing values encoded as NaN natively. For supervised learning, you might want to consider sklearn.ensemble.HistGradientBoostingClassifier and Regressor which accept missing values encoded as NaNs natively. Alternatively, it is possible to preprocess the data, for instance by using an imputer transformer in a pipeline or drop samples with missing values. See https://scikit-learn.org/stable/modules/impute.html You can find a list of all estimators that handle NaN values at the following page: https://scikit-learn.org/stable/modules/impute.html#estimators-that-handle-nan-values

The same code with DecisionTreeClassifier works just fine. Can anyone help explain why the RandomForestClassifier doesn't support missing values, despite being an ensemble of DecisionTreeClassifiers?


Solution

  • As pointed out by Ben Reiniger in the comments, people are actually working on this feature. According to scikit-learn's release notes, RandomForestClassifier and RandomForestRegressor will support missing values from version 1.4 onward: https://scikit-learn.org/dev/whats_new/v1.4.html#sklearn-ensemble
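
Until you can upgrade, one possible workaround is the imputer-in-a-pipeline approach that the error message itself suggests. The sketch below is only one way to do it: the median strategy and the simple version check are my own assumptions, not something prescribed by the release notes.

```python
import numpy as np
import sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline

X = np.array([0, 1, 6, np.nan]).reshape(-1, 1)
y = [0, 0, 1, 1]

# Parse "major.minor" from the installed version (e.g. "1.3.2" -> (1, 3)).
major, minor = (int(part) for part in sklearn.__version__.split(".")[:2])

if (major, minor) >= (1, 4):
    # From 1.4 on, the forest accepts NaN directly.
    model = RandomForestClassifier(random_state=0)
else:
    # Assumption: median imputation is acceptable for this data.
    model = make_pipeline(
        SimpleImputer(strategy="median"),
        RandomForestClassifier(random_state=0),
    )

model.fit(X, y)
predictions = model.predict(X)
print(predictions)
```

Note that imputing is not equivalent to native missing-value handling: the imputer replaces NaN with a fixed value before the trees see the data, whereas native support lets each split decide which side the missing samples go to.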