scikit-learn, decision-tree

sklearn plot_tree function does not show the class when the tree has only one node


I use the following code to plot a decision tree:

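    import matplotlib.pyplot as plt
    from sklearn.tree import plot_tree

    # estimator, feature_names, effect_name and num_action are assumed
    # to be defined elsewhere
    plt.figure(figsize=(12, 12))
    plot_tree(
        estimator,
        feature_names=feature_names,
        label="all",
        class_names=[f"Class {k}" for k in range(2)],
        filled=True,
        rounded=True,
        impurity=True,
    )
    plt.title(f"Decision Tree for Effect {effect_name}", fontsize=40)

    # The DTs_action_<num_action> directory must already exist
    file_name = f"DTs_action_{num_action}/decision_tree_effect_{effect_name}.png"
    plt.savefig(file_name, format="png", dpi=300)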

When the tree has more than one node, every node in the plot shows its class (screenshot omitted).

But when the tree has only one node, that node is drawn without a class (screenshot omitted).

Why don't I see the class of the node in the single-node case?

Thanks


Solution

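  • The answer is that the tree was fitted on data containing only one class, so no class name is displayed. The number of nodes is not the cause: plot_tree adds the class line only for single-output classifiers with more than one class.

    You can verify this in the source code: https://github.com/scikit-learn/scikit-learn/blob/160fe6719a1f44608159b0999dea0e52a83e0963/sklearn/tree/_export.py#L377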
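To check this on your own estimator, a minimal sketch along these lines may help. The helper `root_label` and the toy data are made up for illustration (they are not from the question); the sketch relies on `plot_tree` returning the matplotlib annotations it draws, and on a tree with constant features being unable to split.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, so no display is required
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree


def root_label(X, y, class_names):
    """Fit a tree and return (n_classes, node_count, text drawn in the root box).

    Hypothetical helper, written only for this illustration.
    """
    clf = DecisionTreeClassifier(random_state=0).fit(X, y)
    fig, ax = plt.subplots()
    # plot_tree returns the list of Annotation objects, root node first
    annotations = plot_tree(clf, class_names=class_names, label="all", ax=ax)
    text = annotations[0].get_text()
    plt.close(fig)
    return clf.n_classes_, clf.tree_.node_count, text


names = [f"Class {k}" for k in range(2)]

# Case 1: every label is the same, so the fitted tree knows a single class.
single_class = root_label([[0], [1], [2]], [0, 0, 0], names)

# Case 2: two classes, but the feature is constant, so the tree cannot split
# and also ends up with a single node.
two_class = root_label([[0], [0], [0]], [0, 1, 1], names)

print(single_class)
print(two_class)
```

Both trees have one node, but only the second one should show a `class = ...` line in its box. If `estimator.n_classes_` is 1 for your tree, the training labels passed to `fit` for that effect contained a single value, and the missing class line is expected.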