
NotFittedError: This BalancedRandomForestClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this method


I am trying to plot my model, but the code raises an error saying the model is not fitted yet. I did fit the model, though. Can someone help me understand why I'm getting this error?

My code is below:

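from io import StringIO

import pydotplus
from imblearn.ensemble import BalancedRandomForestClassifier
from IPython.display import Image
from sklearn.tree import export_graphviz

# x_train, y_train and x_test are defined earlier in the notebook
model = BalancedRandomForestClassifier(n_estimators=200, random_state=0, max_depth=6)

model.fit(x_train, y_train)
y_pred_rf = model.predict(x_test)

dot_data = StringIO()
# This call raises the NotFittedError shown below
export_graphviz(model, out_file=dot_data,
                filled=True, rounded=True,
                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
Image(graph.create_png())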

The error is below:

---------------------------------------------------------------------------
NotFittedError                            Traceback (most recent call last)
<ipython-input-57-0036434b9b2c> in <module>
     16 export_graphviz(model, out_file=dot_data,
     17             filled=True, rounded=True,
---> 18             special_characters=True)
     19 graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
     20 Image(graph.create_png())

/opt/anaconda/envs/env_python/lib/python3.6/site-packages/sklearn/tree/export.py in export_graphviz(decision_tree, out_file, max_depth, feature_names, class_names, label, filled, leaves_parallel, impurity, node_ids, proportion, rotate, rounded, special_characters, precision)
    754     """
    755
--> 756     check_is_fitted(decision_tree, 'tree_')
    757     own_file = False
    758     return_string = False

/opt/anaconda/envs/env_python/lib/python3.6/site-packages/sklearn/utils/validation.py in check_is_fitted(estimator, attributes, msg, all_or_any)
    912
    913     if not all_or_any([hasattr(estimator, attr) for attr in attributes]):
--> 914         raise NotFittedError(msg % {'name': type(estimator).__name__})
    915
    916

NotFittedError: This BalancedRandomForestClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.

Solution

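  • export_graphviz expects a single decision tree. BalancedRandomForestClassifier is an ensemble, so the forest object itself never gets a tree_ attribute, and that attribute is exactly what check_is_fitted(decision_tree, 'tree_') looks for in the traceback. The error message is therefore misleading: your model is fitted, it just is not a single tree. The fitted trees are exposed through estimators_, so, as mentioned by Parthasarathy Subburaj, you can loop over estimators_ and call export_graphviz on each tree.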

    That said, I would advise you to use sklearn.tree.plot_tree(...) (available since scikit-learn 0.21): https://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html

    It is a pure matplotlib plotting helper, which makes things easier if you only want an image representation without depending on Graphviz.
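A minimal sketch of looping over estimators_ with export_graphviz. It assumes imbalanced-learn is installed; if it is not, the sketch falls back to scikit-learn's RandomForestClassifier, which exposes the same estimators_ list of fitted trees. The iris dataset, the reduced n_estimators=5, and the name dot_sources are illustrative stand-ins for the question's data and settings. Passing out_file=None makes export_graphviz return the DOT source as a string instead of writing into a StringIO buffer.

```python
from sklearn.datasets import load_iris
from sklearn.tree import export_graphviz

try:
    # The class from the question (requires imbalanced-learn)
    from imblearn.ensemble import BalancedRandomForestClassifier as Forest
except ImportError:
    # Assumption: scikit-learn's forest exposes the same estimators_ list
    from sklearn.ensemble import RandomForestClassifier as Forest

# Iris is a stand-in for the question's x_train / y_train
iris = load_iris()
X, y = iris.data, iris.target

# Far fewer trees than the question's 200 so the loop stays small
model = Forest(n_estimators=5, random_state=0, max_depth=6)
model.fit(X, y)

# The fitted ensemble has no tree_ attribute, which is what export_graphviz checks for
print(hasattr(model, "tree_"))  # False

dot_sources = []
for i, tree in enumerate(model.estimators_):
    # Each element is a fitted DecisionTreeClassifier that does have tree_
    # out_file=None returns the DOT source as a string
    dot = export_graphviz(
        tree,
        out_file=None,
        feature_names=iris.feature_names,
        class_names=list(iris.target_names),
        filled=True,
        rounded=True,
        special_characters=True,
    )
    dot_sources.append(dot)
    # Rendering as in the question needs pydotplus and the Graphviz binaries:
    # pydotplus.graph_from_dot_data(dot).write_png(f"tree_{i}.png")

print(len(dot_sources))  # 5
```

Each DOT string can then go through pydotplus.graph_from_dot_data exactly as in the original code; with 200 trees you would usually render only a handful of them.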
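For the plot_tree route, here is a sketch under the same assumptions (iris as stand-in data, a small forest, and the same fallback to scikit-learn's RandomForestClassifier if imbalanced-learn is missing). Each tree from estimators_ is drawn on its own subplot. The matplotlib Agg backend and the in-memory PNG buffer png_buffer are only there so the example runs without a display; fig.savefig("trees.png") works the same way.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import plot_tree

try:
    # The class from the question (requires imbalanced-learn)
    from imblearn.ensemble import BalancedRandomForestClassifier as Forest
except ImportError:
    # Assumption: scikit-learn's forest exposes the same estimators_ list
    from sklearn.ensemble import RandomForestClassifier as Forest

iris = load_iris()
model = Forest(n_estimators=3, random_state=0, max_depth=6)
model.fit(iris.data, iris.target)

# One subplot per tree; plot_tree draws a single fitted tree onto the given Axes
fig, axes = plt.subplots(1, len(model.estimators_), figsize=(15, 4))
for i, (tree, ax) in enumerate(zip(model.estimators_, axes)):
    plot_tree(
        tree,
        feature_names=iris.feature_names,
        class_names=list(iris.target_names),
        filled=True,
        rounded=True,
        ax=ax,
    )
    # plot_tree clears the Axes first, so set the title afterwards
    ax.set_title(f"Tree {i}")

# Save to memory here; use fig.savefig("trees.png") to write a file instead
png_buffer = io.BytesIO()
fig.savefig(png_buffer, format="png")
```

No Graphviz installation or DOT round trip is involved, which is the main reason to prefer plot_tree when an image is all you need.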