Search code examples
Tags: python-3.x, scikit-learn, decision-tree, grid-search

"GridSearchCV is not fitted yet" error when using export_graphviz despite having fitted it


I trained a decision tree classifier and I am using the GridSearchCV output to plot the tree. Here is my code for the decision tree model:

from sklearn.tree import DecisionTreeClassifier

# function for decision tree using grid search cross validation to find optimal hyperparameters
def decisionTree(X_train, X_test, y_train, y_test, cv=10):
    """Tune a DecisionTreeClassifier with grid-search CV and score it on a test set.

    Searches ``max_depth`` and ``min_samples_leaf`` with ``GridSearchCV``,
    ranking candidates by F1. It prints the best parameters and their mean
    cross-validated F1, then evaluates the search's best model on the test set.

    The ``'f1'`` scorer and the precision/recall/F1 calls below use
    scikit-learn's default binary averaging, so the labels are expected to be
    binary.

    Parameters
    ----------
    X_train, y_train
        Features and labels used for the grid search.
    X_test, y_test
        Held-out features and labels used for the final scores.
    cv : int or cross-validation splitter, default 10
        Forwarded unchanged to ``GridSearchCV(cv=...)``.

    Returns
    -------
    scores : dict
        Test-set ``'Accuracy'``, ``'Precision'``, ``'Recall'`` and ``'F1'``,
        each rounded to 3 decimals.
    dt_clf : GridSearchCV
        The fitted search object. It is not itself a tree. Pass
        ``dt_clf.best_estimator_`` to tree utilities such as ``export_graphviz``.
    """
    dt = DecisionTreeClassifier(random_state=0)
    params = {
        'max_depth': [3, 4, 5, 6],
        # Float values are fractions of the training samples required per leaf.
        'min_samples_leaf': [0.02, 0.03, 0.04, 0.06, 0.08, 0.3, 0.4],
    }
    dt_clf = GridSearchCV(estimator=dt, param_grid=params, scoring='f1', cv=cv,
                          return_train_score=True)
    dt_clf.fit(X_train, y_train)
    # Builtin round() works for both NumPy scalars and plain Python floats.
    # The .round() method exists only on NumPy scalars, and newer scikit-learn
    # versions may return plain Python floats from metric functions.
    print("Best parameters set found on Cross Validation:\n\n", dt_clf.best_params_)
    print("\nCross Validation F1 score:\n\n", round(dt_clf.best_score_, 3))

    # Predict the test set with the best model (GridSearchCV refits it by default).
    y_pred = dt_clf.predict(X_test)
    print('\nTest set scores:')
    scores = {
        'Accuracy': round(accuracy_score(y_test, y_pred), 3),
        'Precision': round(precision_score(y_test, y_pred), 3),
        'Recall': round(recall_score(y_test, y_pred), 3),
        'F1': round(f1_score(y_test, y_pred), 3),
    }
    return scores, dt_clf

# Run the grid search; keep both the score dict and the fitted search object.
scores, dt_clf = decisionTree(
    X_train, X_test, y_train, y_test, cv=10
)
# Bare expression: a notebook cell displays the test-set score dict.
scores

Results:

Best parameters set found on Cross Validation:

 {'max_depth': 5, 'min_samples_leaf': 0.03}

Cross Validation F1 score:

 0.753

Test set scores:
{'Accuracy': 0.833, 'Precision': 0.793, 'Recall': 0.852, 'F1': 0.821}

However, when I use export_graphviz, it says that the GridSearchCV instance is not fitted yet:

from sklearn.tree import export_graphviz 

# NOTE: this is the failing call from the question. export_graphviz expects the
# fitted decision tree itself, not the GridSearchCV wrapper around it. It
# therefore reports the NotFittedError shown below, even though dt_clf.fit()
# was called. Passing dt_clf.best_estimator_ instead fixes it (see the solution).
dot_data = export_graphviz(dt_clf,
feature_names=list(X_train.columns),
class_names=['No Heart Disease', 'Heart Disease'],
out_file=None,
filled=True,
rounded=True,
special_characters=True)

NotFittedError: This GridSearchCV instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.

I cannot figure out where the problem is!


Solution

  • 1. The GridSearchCV object is indeed fitted, so the problem lies elsewhere. You can confirm this with check_is_fitted().

    2. You need to pass the best estimator to export_graphviz(), not the GridSearchCV object, i.e. export_graphviz(dt_clf.best_estimator_).


    Example:

    from sklearn.tree import DecisionTreeClassifier
    from sklearn.datasets import load_iris
    from sklearn.model_selection import train_test_split
    import pandas as pd, numpy as np
    from sklearn.model_selection import GridSearchCV
    from sklearn.metrics import accuracy_score,precision_score, recall_score, f1_score
    from sklearn.tree import export_graphviz 
    from sklearn.utils.validation import check_is_fitted
    
    data = load_iris()

    # Keep only classes 0 and 1 so the target is binary, as the 'f1' scorer
    # used by decisionTree() expects.
    tmp = data.target != 2
    X, y = data.data[tmp, :], data.target[tmp]

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.33, random_state=42
    )
    
    # function for decision tree using grid search cross validation to find optimal hyperparameters
    def decisionTree(X_train, X_test, y_train, y_test, cv=10):
        """Tune a DecisionTreeClassifier with grid-search CV and score it on a test set.

        Searches ``max_depth`` and ``min_samples_leaf`` with ``GridSearchCV``,
        ranking candidates by F1. It prints the best parameters and their mean
        cross-validated F1, then evaluates the search's best model on the test set.

        The ``'f1'`` scorer and the precision/recall/F1 calls below use
        scikit-learn's default binary averaging, so the labels are expected to
        be binary.

        Parameters
        ----------
        X_train, y_train
            Features and labels used for the grid search.
        X_test, y_test
            Held-out features and labels used for the final scores.
        cv : int or cross-validation splitter, default 10
            Forwarded unchanged to ``GridSearchCV(cv=...)``.

        Returns
        -------
        scores : dict
            Test-set ``'Accuracy'``, ``'Precision'``, ``'Recall'`` and ``'F1'``,
            each rounded to 3 decimals.
        dt_clf : GridSearchCV
            The fitted search object. It is not itself a tree. Pass
            ``dt_clf.best_estimator_`` to tree utilities such as ``export_graphviz``.
        """
        dt = DecisionTreeClassifier(random_state=0)
        params = {
            'max_depth': [3, 4, 5, 6],
            # Float values are fractions of the training samples required per leaf.
            'min_samples_leaf': [0.02, 0.03, 0.04, 0.06, 0.08, 0.3, 0.4],
        }
        dt_clf = GridSearchCV(estimator=dt, param_grid=params, scoring='f1', cv=cv,
                              return_train_score=True)
        dt_clf.fit(X_train, y_train)
        # Builtin round() works for both NumPy scalars and plain Python floats.
        # The .round() method exists only on NumPy scalars, and newer
        # scikit-learn versions may return plain Python floats from metric functions.
        print("Best parameters set found on Cross Validation:\n\n", dt_clf.best_params_)
        print("\nCross Validation F1 score:\n\n", round(dt_clf.best_score_, 3))

        # Predict the test set with the best model (GridSearchCV refits it by default).
        y_pred = dt_clf.predict(X_test)
        print('\nTest set scores:')
        scores = {
            'Accuracy': round(accuracy_score(y_test, y_pred), 3),
            'Precision': round(precision_score(y_test, y_pred), 3),
            'Recall': round(recall_score(y_test, y_pred), 3),
            'F1': round(f1_score(y_test, y_pred), 3),
        }
        return scores, dt_clf
    
    scores, dt_clf = decisionTree(X_train, X_test, y_train, y_test, cv=10)

    # check_is_fitted raises NotFittedError on an unfitted estimator. It passes
    # here, so the search object returned by decisionTree() is fitted.
    check_is_fitted(dt_clf)

    # export_graphviz needs the fitted tree itself, not the GridSearchCV
    # wrapper. Pass the best model, which GridSearchCV refits on the training
    # set by default (refit=True).
    dot_data = export_graphviz(
        dt_clf.best_estimator_,
        feature_names=list(data.feature_names),
        # Name the two iris classes kept above (labels 0 and 1, in ascending
        # order) instead of reusing the question's heart-disease labels.
        class_names=data.target_names[:2].tolist(),
        out_file=None,
        filled=True,
        rounded=True,
        special_characters=True,
    )