
Plot a decision tree from HistGradientBoostingClassifier


I have a HistGradientBoostingClassifier model and I want to plot one or more of its decision trees, but I can't find a native function to do it. I can access the TreePredictor objects, and thus their nodes, but sklearn.tree.plot_tree only accepts a DecisionTreeClassifier or DecisionTreeRegressor object.

I tried this:

    from sklearn.tree import plot_tree

    plot_tree(RF_90._predictors[0][0])

and got this error:

InvalidParameterError: The 'decision_tree' parameter of plot_tree must be an instance of 'sklearn.tree._classes.DecisionTreeClassifier' or an instance of 'sklearn.tree._classes.DecisionTreeRegressor'. Got <sklearn.ensemble._hist_gradient_boosting.predictor.TreePredictor object at 0x7f676ebf0310> instead.

Note: RF_90 is the fitted HistGradientBoostingClassifier model.


Solution

  • To visualize the trees of a HistGradientBoostingClassifier, this function worked for me (it needs the graphviz Python package and the Graphviz binaries):

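    import graphviz


    def visualize_tree(tree, feature_names):
        """Build a graphviz Digraph for one TreePredictor of a fitted HistGradientBoosting model."""
        dot = graphviz.Digraph()
        # `nodes` is a NumPy structured array with one record per node
        nodes = tree.nodes

        def add_nodes_edges(node_id):
            node = nodes[node_id]
            if node['is_leaf']:
                # Raw leaf value: a contribution to the decision function, not a class label
                dot.node(str(node_id), f"Leaf value: {node['value']:.4f}")
            else:
                feature = feature_names[node['feature_idx']]
                # 'num_threshold' is the split value in original feature units;
                # 'bin_threshold' is only the index of the bin.
                # Categorical splits (node['is_categorical']) are not decoded here.
                threshold = node['num_threshold']
                dot.node(str(node_id), f"{feature} <= {threshold:.2f}")
                left_child = int(node['left'])
                right_child = int(node['right'])
                dot.edge(str(node_id), str(left_child), "True")
                dot.edge(str(node_id), str(right_child), "False")
                add_nodes_edges(left_child)
                add_nodes_edges(right_child)

        add_nodes_edges(0)
        return dot

    # Create and visualize the tree
    # (feature_names_in_ exists only if the model was fitted on a DataFrame with string column names)
    dot = visualize_tree(single_tree, RF_90.feature_names_in_)
    dot.render("hist_gb_tree")  # save to file
    dot  # display inline in Jupyter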
    

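    where RF_90 is the fitted model and single_tree is one of its TreePredictor objects, taken from the private _predictors attribute:

    single_tree = RF_90._predictors[iteration][class_index]

    with, for example, iteration = 0 and class_index = 0. For binary classification each iteration contains a single tree, so class_index must be 0.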