I want to simplify my tree plots. Is there a way to plot the trees without showing samples (e.g. 83) and value (e.g. [34, 53, 29, 26])? In other words, I don't want the last two lines in each node.
Here is the relevant part of my current plotting code:
X = df.iloc[:,0: -1]
y = df.iloc[:,-1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y)
clf = RandomForestClassifier()
clf.fit(X_train,y_train)
.
.
.
.
# Here, I guess I need to add some commands.
plot_tree(clf.estimators_[5],
          feature_names=X.columns,
          class_names=names,
          filled=True,
          impurity=True,
          rounded=True,
          max_depth=3)
Say we have a dataset like the one below, and we pass a matplotlib axis to plot_tree through the ax argument:
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
import matplotlib.pyplot as plt
import re
import matplotlib
fig, ax = plt.subplots(figsize=(8,5))
clf = RandomForestClassifier(random_state=0)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)
tree.plot_tree(clf.estimators_[0], ax=ax,
               feature_names=iris.feature_names, class_names=iris.target_names)
I'm not sure this is the best way, but one option is to go through the children of the axis, find the text annotations that make up the node boxes, and edit their text:
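def replace_text(obj):
    # Only the node boxes are Annotation objects; leave everything else alone.
    if isinstance(obj, matplotlib.text.Annotation):
        txt = obj.get_text()
        # Drop everything from "samples" up to "class" (this relies on class_names being set).
        txt = re.sub(r"\nsamples[^$]*class", "\nclass", txt)
        obj.set_text(txt)
    return obj

# set_text changes each annotation in place, so looping over the children is enough.
for child in ax.get_children():
    replace_text(child)
fig.show()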
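The regex above only matches when class_names is passed, because it needs the "class" line as an anchor. As a rough alternative sketch, you could instead filter each node's text line by line, which should also work without class_names. The helper name strip_node_lines and its prefixes parameter are made up for this example, and it assumes that each field sits on a single line starting with "samples = " or "value = " (a very long value array could wrap onto extra lines that this would miss). It uses the list of annotations that plot_tree returns rather than walking the axis children:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display; drop for interactive use
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import plot_tree


def strip_node_lines(annotations, prefixes=("samples = ", "value = ")):
    """Hypothetical helper: remove lines starting with any prefix from each node box."""
    for ann in annotations:
        lines = ann.get_text().split("\n")
        kept = [line for line in lines if not line.startswith(prefixes)]
        ann.set_text("\n".join(kept))


iris = load_iris()
clf = RandomForestClassifier(n_estimators=10, random_state=0)
clf.fit(iris.data, iris.target)

fig, ax = plt.subplots(figsize=(8, 5))
# plot_tree returns the annotation artists it created, one per drawn node.
annotations = plot_tree(clf.estimators_[0], ax=ax,
                        feature_names=iris.feature_names,
                        class_names=iris.target_names,
                        filled=True, max_depth=3)
strip_node_lines(annotations)
fig.savefig("tree_simplified.png")
```

To hide other fields, pass a different prefixes tuple, e.g. prefixes=("samples = ",) to keep the value line.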