Changing colors for decision tree plot created using export_graphviz


I am using scikit-learn's regression tree and graphviz to generate wonderful, easy-to-interpret visuals of some decision trees:

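from sklearn import tree
import pydotplus

# Run.reg: the fitted regression tree, Xvar: the feature names (both defined elsewhere)
dot_data = tree.export_graphviz(Run.reg, out_file=None,
                                feature_names=Xvar,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png('CART.png')
graph.write_svg('CART.svg')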

This runs perfectly, but I'd like to change the color scheme if possible. The plot represents CO2 fluxes, so I'd like to make negative values green and positive values brown. I can export as SVG and alter everything manually, but when I do, the text doesn't quite line up with the boxes, so recoloring by hand and fixing all the text adds a very tedious step to my workflow that I would really like to avoid!

Also, I've seen some trees where the length of the lines connecting nodes is proportional to the % variance explained by the split. I'd love to be able to do that too, if possible.


Solution

    • You can get a list of all the edges via graph.get_edge_list().
    • Each source node should have two target nodes: the one with the lower index is the True branch, the one with the higher index is the False branch.
    • Colors can be assigned to a node via set_fillcolor().

    import pydotplus
    from sklearn.datasets import load_iris
    from sklearn import tree
    import collections
    
    clf = tree.DecisionTreeClassifier(random_state=42)
    iris = load_iris()
    
    clf = clf.fit(iris.data, iris.target)
    
    dot_data = tree.export_graphviz(clf,
                                    feature_names=iris.feature_names,
                                    out_file=None,
                                    filled=True,
                                    rounded=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    
    # The child with the lower index (True branch) gets colors[0], the other gets colors[1]
    colors = ('brown', 'forestgreen')
    edges = collections.defaultdict(list)
    
    # Map each source node name to the indices of its child nodes
    for edge in graph.get_edge_list():
        edges[edge.get_source()].append(int(edge.get_destination()))
    
    for src in edges:
        edges[src].sort()
        for i in range(2):
            dest = graph.get_node(str(edges[src][i]))[0]
            dest.set_fillcolor(colors[i])
    
    graph.write_png('tree.png')
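
The snippet above colors nodes by branch side (True or False child). For the regression case in the question, where the color should follow the sign of the predicted flux, one option is to rewrite the fillcolor attributes in the DOT text before parsing it. The following is a minimal sketch and not part of the original answer: recolor_by_sign is a hypothetical helper, and it assumes each node sits on one line with a "value = <number>" label and a quoted fillcolor attribute, which is how export_graphviz appears to write a single-output regressor with filled=True. The sample DOT string is hand-written for illustration.

```python
import re

# Predicted value inside a node label, e.g. "value = -1.25" (assumed label format).
VALUE_RE = re.compile(r'value = (-?\d+(?:\.\d+)?(?:e[-+]?\d+)?)')
# An existing quoted fillcolor attribute, e.g. fillcolor="#e58139".
FILL_RE = re.compile(r'fillcolor="[^"]*"')


def recolor_by_sign(dot_source, negative="forestgreen", positive="brown"):
    """Return DOT text with each node's fillcolor chosen by the sign of its value."""
    out = []
    for line in dot_source.splitlines():
        value = VALUE_RE.search(line)
        # Only touch lines that carry both a numeric value and a fillcolor.
        if value and FILL_RE.search(line):
            color = negative if float(value.group(1)) < 0 else positive
            line = FILL_RE.sub('fillcolor="%s"' % color, line)
        out.append(line)
    return "\n".join(out)


# Hand-written stand-in for export_graphviz output (illustrative only).
dot_source = "\n".join([
    'digraph Tree {',
    '0 [label="x <= 1.0\\nvalue = -1.5", fillcolor="#f2c09c"] ;',
    '1 [label="value = 2.0", fillcolor="#e58139"] ;',
    '0 -> 1 ;',
    '}',
])
recolored = recolor_by_sign(dot_source)
print(recolored)
```

The returned string can be passed to pydotplus.graph_from_dot_data() in place of dot_data. A value of exactly zero is colored as positive in this sketch.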
    

    Also, I've seen some trees where the length of the lines connecting nodes is proportional to the % variance explained by the split. I'd love to be able to do that too, if possible.

    You could play with set_weight() and set_len(), but that's a bit trickier and needs some fiddling to get right. Here is some code to get you started:

    import re

    def samples(node):
        # Read the "samples = N" count from the label, whatever the line separator
        return int(re.search(r'samples = (\d+)', node.get_attributes()['label']).group(1))

    for src_name in edges:
        edges[src_name].sort()
        src = graph.get_node(src_name)[0]
        total_weight = samples(src)
        for i in range(2):
            dest_name = str(edges[src_name][i])
            dest = graph.get_node(dest_name)[0]
            share = samples(dest) / total_weight  # fraction of the parent's samples
            # Style the edge leading to child i
            edge = graph.get_edge(src_name, dest_name)[0]
            edge.set_weight((1 - share) * 100)
            edge.set_len(share)
            edge.set_minlen(share)
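
The code above scales edges by the share of samples reaching each child, which is not the same as variance explained. For a regression tree whose impurity is the node variance (mean squared error criterion), the fraction of the root's total variance removed by a split can be computed from per-node arrays. The sketch below assumes plain lists laid out like scikit-learn's tree_ attributes children_left, children_right, impurity and weighted_n_node_samples, with -1 marking a leaf. variance_explained and to_minlen are hypothetical helper names, and the numbers in the example are made up.

```python
def variance_explained(children_left, children_right, impurity, n_samples):
    """Map each split node index to the fraction of root variance its split removes."""
    total = n_samples[0] * impurity[0]  # total sum of squares at the root
    shares = {}
    for node, (left, right) in enumerate(zip(children_left, children_right)):
        if left == -1:  # leaf: no split, nothing to report
            continue
        before = n_samples[node] * impurity[node]
        after = n_samples[left] * impurity[left] + n_samples[right] * impurity[right]
        shares[node] = (before - after) / total if total else 0.0
    return shares


def to_minlen(share, max_extra=3):
    """Turn a share in [0, 1] into an integer rank span of at least 1."""
    return 1 + round(share * max_extra)


# Made-up three-node tree: a root (node 0) with two leaves (nodes 1 and 2).
shares = variance_explained(
    children_left=[1, -1, -1],
    children_right=[2, -1, -1],
    impurity=[4.0, 1.0, 1.0],
    n_samples=[10, 5, 5],
)
print(shares)           # {0: 0.75}
print(to_minlen(0.75))  # 3
```

Each share could then drive the edges leaving that node, for example by calling set_minlen(to_minlen(shares[int(src_name)])) on them. Graphviz's dot layout reads minlen as an integer rank difference, while len is only used by the neato and fdp layouts, so with the default dot layout the set_len() call has no visible effect.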