Visualizing CatBoost plot_tree in Python


I'm trying to visualize my CatBoost model in Python with the following code:

model_CBC.plot_tree(tree_idx=0, pool=pool)
plt.show()

The rest of the code produces its output, but no tree is displayed. The process finishes like this:

MetricVisualizer(layout=Layout(align_self='stretch', height='500px'))
Learning rate set to 0.084924
0:  learn: 1.0230107    total: 111ms    remaining: 1m 51s
1:  learn: 0.9612983    total: 158ms    remaining: 1m 18s
.
.
.
998:    learn: 0.2291117    total: 45s  remaining: 45.1ms
999:    learn: 0.2290360    total: 45.1s    remaining: 0us
Accuracy: 84.90%

Process finished with exit code 0

Any suggestions on how to see the tree?

My CatBoost code is:

import pandas as pd
import numpy as np
import xgboost as xgb
import catboost as ctb
import lightgbm as lgb
from xgboost import plot_tree
from lightgbm import plot_tree as lgbm_tree

import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer


# Load the data from the Excel file
data = pd.read_excel('/ileti.xlsx')

# Define the feature and target data
features = data["ileti"]
target = data["label"]

# Split the data into training and testing sets
train_features, test_features, train_target, test_target = train_test_split(features, target, test_size=0.2, random_state=42)


def catboost():
    global train_features, test_features
    vectorizer = TfidfVectorizer()
    train_features = vectorizer.fit_transform(train_features)
    test_features = vectorizer.transform(test_features)
    pool = ctb.Pool(train_features, train_target)
    model_CBC = ctb.CatBoostClassifier().fit(pool, plot=True)
    model_CBC.plot_tree(tree_idx=0, pool=pool)
    plt.show()
    #model_CBC.fit(pool, plot=True)
    #print(model_CBC)
    expected_y = test_target
    predicted_y = model_CBC.predict(test_features)
    accuracy = accuracy_score(expected_y, predicted_y)
    print("Accuracy: %.2f%%" % (accuracy * 100.0))



catboost()

I searched the internet for similar problems but couldn't find a solution. I installed Graphviz through both brew and pip, and it works fine for visualizing trees with the XGBoost and LightGBM plot_tree functions.


Solution

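  • CatBoost's plot_tree returns a graphviz Digraph object rather than drawing on a matplotlib figure, so plt.show() has nothing to display. The problem is solved by saving the returned object to a variable and calling its render method.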

    a = model_CBC.plot_tree(tree_idx=4)
    a.render()
    

    By default, this writes a PDF file containing the graph to the current working directory.
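
As a rough sketch of the same idea, the helper below wraps the two calls and passes an explicit file name and format to render, so the output location is predictable. The name save_tree and the parameters out_stem and fmt are hypothetical and not part of CatBoost. The sketch assumes that plot_tree returns a graphviz Digraph and that the Graphviz binaries are installed.

```python
def save_tree(model, tree_idx=0, pool=None, out_stem="catboost_tree", fmt="pdf"):
    """Render one tree of a fitted CatBoost model to a file.

    save_tree, out_stem and fmt are illustrative names, not CatBoost API.
    Returns the path of the rendered file as reported by graphviz.
    """
    # plot_tree is assumed to return a graphviz Digraph, not a matplotlib figure
    graph = model.plot_tree(tree_idx=tree_idx, pool=pool)
    # render writes "<out_stem>.<fmt>"; cleanup=True removes the intermediate
    # DOT source file once rendering has finished
    return graph.render(filename=out_stem, format=fmt, cleanup=True)
```

With the model from the question, save_tree(model_CBC, tree_idx=4, pool=pool) should write catboost_tree.pdf to the current working directory. Passing fmt="png" instead should produce an image file.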