Tags: python, classification, random-forest, decision-tree

How to plot tree without showing "samples" and "value" in random forest?


I want to make my trees simpler. Is there a way to plot them without showing samples (e.g. 83) and value (e.g. [34, 53, 29, 26]) in each node? (I don't want the last two lines.)

Here is part of my current code for plotting the trees.

X = df.iloc[:,0: -1] 
y = df.iloc[:,-1]    
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)
clf = RandomForestClassifier()
clf.fit(X_train,y_train)
.
.
.
.
# Here, I guess I need to add some commands.
plot_tree(clf.estimators_[5], 
          feature_names=X.columns,
          class_names=names, 
          filled=True, 
          impurity=True, 
          rounded=True,
          max_depth = 3)

Solution

  • Let's say we have a dataset like this, and we pass a matplotlib axis to plot_tree through the ax argument:

    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    from sklearn import tree
    import matplotlib.pyplot as plt
    import re
    import matplotlib
    
    fig, ax = plt.subplots(figsize=(8,5))
    
    clf = RandomForestClassifier(random_state=0)
    iris = load_iris()
    clf = clf.fit(iris.data, iris.target)
    tree.plot_tree(clf.estimators_[0], ax=ax,
                   feature_names=iris.feature_names,
                   class_names=iris.target_names)
    

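    As far as I know, plot_tree has no option to hide these lines (impurity=False only removes the impurity line). I am not sure it is the best way, but one option is to go through the children of the axis and edit the text of each node: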

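    def replace_text(obj):
        # Tree nodes are drawn as Annotation objects; skip every other artist.
        if isinstance(obj, matplotlib.text.Annotation):
            txt = obj.get_text()
            # Drop everything from the "samples" line up to the "class" line.
            txt = re.sub(r"\nsamples.*\nclass", "\nclass", txt, flags=re.DOTALL)
            obj.set_text(txt)
        return obj

    # set_text() edits each artist in place, so nothing has to be reassigned.
    for child in ax.get_children():
        replace_text(child)
    plt.show()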
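
A variation on the same idea, as a minimal sketch rather than a definitive solution: instead of a regular expression, drop whole lines by their prefix. The helper names strip_node_lines and simplify_tree_labels are hypothetical (they are not part of scikit-learn or matplotlib), and the sketch assumes the default label="all" setting of plot_tree, where each line starts with "samples = " or "value = ".

```python
def strip_node_lines(text, prefixes=("samples = ", "value = ")):
    """Return `text` without the lines that start with any of `prefixes`."""
    prefixes = tuple(prefixes)  # str.startswith needs a tuple, not a list
    kept = [line for line in text.split("\n") if not line.startswith(prefixes)]
    return "\n".join(kept)


def simplify_tree_labels(annotations, prefixes=("samples = ", "value = ")):
    """Rewrite, in place, the text of each annotation-like object.

    Works on anything that has get_text() and set_text(), such as the
    matplotlib annotations that make up a plotted tree.
    """
    for annotation in annotations:
        annotation.set_text(strip_node_lines(annotation.get_text(), prefixes))
    return annotations


# Quick check on a label shaped like the ones plot_tree writes by default.
label = (
    "petal width (cm) <= 0.8\n"
    "gini = 0.667\n"
    "samples = 150\n"
    "value = [50, 50, 50]\n"
    "class = setosa"
)
print(strip_node_lines(label))
```

Since plot_tree returns the list of annotation artists it draws, this could be applied to the figure above as simplify_tree_labels(tree.plot_tree(clf.estimators_[0], ax=ax, feature_names=iris.feature_names, class_names=iris.target_names)). Unlike the regular expression, it does not depend on a class line being present, so it also works when class_names is not passed.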
    
