Color of the node of tree with graphviz using class_names


Expanding on a prior question: Changing colors for decision tree plot created using export graphviz

How would I color the nodes of the tree based on the dominant class (species of iris), instead of a binary distinction? This should require combining iris.target_names (the strings describing the classes) with iris.target (the class indices).

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
edges = graph.get_edge_list()

colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)

for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

for edge in edges:
    edges[edge].sort()    
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')

Solution

  • The code from the example looks so familiar and is therefore easy to modify :)

    For each node, Graphviz tells us how many samples from each group we have, i.e. whether it is a mixed population or the tree came to a decision. We can extract this information and use it to derive a color.

    # float() also accepts non-integer values such as weighted counts
    values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
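
    The string splitting above can be made a little more forgiving. A minimal sketch, assuming the label contains a `value = [...]` segment as in the line above; the helper name `parse_class_counts` and the sample label are hypothetical, made up for illustration:

```python
import re

def parse_class_counts(label):
    """Extract the per-class numbers from a label containing 'value = [...]'."""
    match = re.search(r"value = \[([^\]]*)\]", label or "")
    if match is None:
        # e.g. the default 'node' / 'edge' entries, which carry no such label
        return None
    # split on commas and/or whitespace; float() accepts '50' as well as '50.0'
    tokens = re.split(r"[,\s]+", match.group(1).strip())
    return [float(tok) for tok in tokens if tok]

# Hypothetical label, shaped like a label written with special_characters=True
sample_label = ("<petal width (cm) &le; 1.75<br/>gini = 0.5"
                "<br/>samples = 100<br/>value = [0, 50, 50]>")
counts = parse_class_counts(sample_label)
print(counts)
```

    Returning None for labels without a `value` segment lets the caller skip those entries instead of hitting an IndexError.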
    

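    Alternatively, you can map the Graphviz nodes back to the sklearn nodes, since the name of each tree node is its index in clf.tree_ (the 'node' and 'edge' entries are not tree nodes and must be skipped):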

    values = clf.tree_.value[int(node.get_name())][0]
    

    We only have 3 classes, so each one gets its own color channel (red, green, blue); mixed populations get mixed colors according to their class distribution.

    values = [int(255 * v / sum(values)) for v in values]
    color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
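
    The two lines above can be wrapped in a small function to see what colors they produce. This is only a sketch of the same arithmetic; the name `mix_rgb` and the input checks are additions for illustration, not part of the original code:

```python
def mix_rgb(values):
    """Map three per-class counts to a hex color, one RGB channel per class."""
    if len(values) != 3:
        raise ValueError("RGB mixing needs exactly three classes")
    total = sum(values)
    if total == 0:
        raise ValueError("node has no samples")
    # each channel is the class share scaled to 0..255 (truncated, as in the snippet above)
    channels = [int(255 * v / total) for v in values]
    return '#{:02x}{:02x}{:02x}'.format(*channels)

print(mix_rgb([50, 0, 0]))    # only the 1st class: pure red
print(mix_rgb([0, 50, 50]))   # 2nd and 3rd class in equal parts
print(mix_rgb([50, 50, 50]))  # all three classes in equal parts
```

    Because the shares sum to 1, an evenly mixed node gets only 85 of 255 per channel, so balanced nodes come out dark gray rather than white.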
    


    In the resulting plot the separation is clearly visible: the greener a node, the larger its share of the 2nd class; the same holds for blue and the 3rd class.


    import pydotplus
    from sklearn.datasets import load_iris
    from sklearn import tree
    
    clf = tree.DecisionTreeClassifier(random_state=42)
    iris = load_iris()
    
    clf = clf.fit(iris.data, iris.target)
    
    dot_data = tree.export_graphviz(clf,
                                    feature_names=iris.feature_names,
                                    out_file=None,
                                    filled=True,
                                    rounded=True,
                                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    nodes = graph.get_node_list()
    
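    for node in nodes:
        # the default 'node' and 'edge' entries have no label, so they are skipped
        if node.get_label():
            # per-class sample counts from the 'value = [...]' part of the label;
            # float() also accepts non-integer values such as weighted counts
            values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
            # scale to 0..255 so the three classes map to the R, G and B channels
            values = [int(255 * v / sum(values)) for v in values]
            color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
            node.set_fillcolor(color)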
    
    graph.write_png('colored_tree.png')
    

    A general solution for more than 3 classes, which colors only the final (pure) nodes:

    import numpy

    # one color per class, plus a last entry used for mixed nodes
    # (names taken from the Graphviz X11 color scheme)
    colors = ('lightblue', 'lightyellow', 'forestgreen', 'lightcoral', 'white')
    
    for node in nodes:
        # skip the default 'node' and 'edge' entries, which are not tree nodes
        if node.get_name() not in ('node', 'edge'):
            values = clf.tree_.value[int(node.get_name())][0]
            # color only nodes where only one class is present
            if max(values) == sum(values):
                node.set_fillcolor(colors[numpy.argmax(values)])
            # mixed nodes get the default color
            else:
                node.set_fillcolor(colors[-1])
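
    The pure-versus-mixed decision in the loop above can also be tried without a fitted tree. A hedged sketch with made-up counts; the helper `leaf_color` and the four-class `palette` are hypothetical names, and non-negative values are assumed:

```python
def leaf_color(values, palette, mixed_color='white'):
    """Return the palette color of the only class present, else mixed_color."""
    values = list(values)
    # with non-negative entries, max == sum means a single class holds all samples
    if max(values) == sum(values):
        return palette[values.index(max(values))]
    return mixed_color

# hypothetical palette for a four-class problem (Graphviz X11 color names)
palette = ('lightblue', 'lightyellow', 'forestgreen', 'lightcoral')

print(leaf_color([0, 0, 43, 0], palette))  # pure node of the 3rd class
print(leaf_color([0, 3, 1, 0], palette))   # mixed node
```

    The palette needs at least as many entries as there are classes; otherwise the index lookup raises an IndexError.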
    
