Tags: python-3.x, treeview, random-forest, dtreeviz

'GridSearchCV' object has no attribute 'estimators_' using dtreeviz


After running a GridSearchCV on a RandomForestClassifier, I am trying to plot one of the forest's trees. I tried the code below, but I get this error:

AttributeError: 'GridSearchCV' object has no attribute 'estimators_'

Can you tell me how to fix this error and get a view of a tree?

Here is the code for the classifier:

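import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

model = RandomForestClassifier()

parameter_space = {
    'n_estimators': [10, 50, 100],
    'criterion': ['gini', 'entropy'],
    'max_depth': np.linspace(10, 50, 11, dtype=int),  # max_depth must be an integer
}

# Exhaustive 5-fold cross-validated search over the grid above
clf = GridSearchCV(model, parameter_space, cv=5, scoring="accuracy", verbose=True)

clf.fit(X_train, y_train)

train_pred = clf.predict(X_train)   # Train predict
test_pred = clf.predict(X_test)     # Test predict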

# Load packages
import pandas as pd
from sklearn import tree
from dtreeviz.trees import dtreeviz # will be used for tree visualization
from matplotlib import pyplot as plt
plt.rcParams.update({'figure.figsize': (12.0, 8.0)})
plt.rcParams.update({'font.size': 14})
 
plt.figure(figsize=(20,20))
# Raises AttributeError: 'GridSearchCV' object has no attribute 'estimators_'
_ = tree.plot_tree(clf.estimators_[0], feature_names=X_train.columns, filled=True)

Solution

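  • GridSearchCV only wraps the search, so it has no estimators_ attribute of its own. After fitting (with the default refit=True), the best random forest is stored in clf.best_estimator_, and that forest's individual decision trees are in its estimators_ list. Change the last line of your code to: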

    _ = tree.plot_tree(clf.best_estimator_.estimators_[0], feature_names=X_train.columns, filled=True)
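Below is a minimal end-to-end sketch of the fix, assuming a binary classification problem. The data from make_classification and the column names feature_0 to feature_3 are hypothetical stand-ins for the question's X_train and y_train, and the grid is shrunk so it runs quickly. It checks that the GridSearchCV object itself has no estimators_, takes the first tree from best_estimator_, plots it with plot_tree, and also prints a text version of the same tree with sklearn.tree.export_text.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import export_text, plot_tree

# Hypothetical stand-in data: a DataFrame with named columns, like X_train.
X, y = make_classification(n_samples=200, n_features=4, random_state=0)
X = pd.DataFrame(X, columns=[f"feature_{i}" for i in range(4)])
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Small grid (assumption) so the example is fast; max_depth values are ints.
parameter_space = {
    "n_estimators": [10, 20],
    "criterion": ["gini", "entropy"],
    "max_depth": [3, 5],
}
clf = GridSearchCV(RandomForestClassifier(random_state=0), parameter_space,
                   cv=5, scoring="accuracy")
clf.fit(X_train, y_train)

# The search object has no trees; the refitted best forest does.
print(hasattr(clf, "estimators_"))       # False
best_forest = clf.best_estimator_        # a fitted RandomForestClassifier
first_tree = best_forest.estimators_[0]  # a fitted DecisionTreeClassifier

# Graphical view of the first tree.
plt.figure(figsize=(20, 20))
annotations = plot_tree(first_tree, feature_names=list(X_train.columns), filled=True)
plt.savefig("first_tree.png")  # or plt.show() in an interactive session
plt.close()

# Plain-text view of the same tree.
tree_text = export_text(first_tree, feature_names=list(X_train.columns))
print(tree_text)
```

Any index from 0 to len(clf.best_estimator_.estimators_) - 1 can replace 0. Each tree is trained on a different bootstrap sample, so the plots will generally differ from one another.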