
LGBMClassifier create_tree_digraph not showing in notebook


I am using LightGBM's LGBMClassifier for a binary classification problem and want to display the actual tree diagram.

Here is how I trained/fit the model

import lightgbm as lgb

clf = lgb.LGBMClassifier()
clf.fit(x_train, y_train, categorical_feature=x_train.select_dtypes(include='category').columns.tolist())

And here is how I am trying to print the diagram

lgb.create_tree_digraph(clf, orientation='vertical')

However, the only output I am getting is

<graphviz.graphs.Digraph at 0x7f6a10a5ed00>

I also tried building the model with lightgbm.train() and passing the resulting Booster as the booster argument to create_tree_digraph, but I get similar output.

Is there an additional library or function I have to call to print out the tree, or is there another way to perhaps save it to a .png file?

I am using a Python Notebook in Databricks.


Solution

  • The output shows that the call succeeded and returned a graphviz Digraph object; the notebook is only printing its repr instead of drawing it. Assign the Digraph to a variable and render it explicitly.

    https://graphviz.readthedocs.io/en/stable/manual.html#basic-usage

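    g = lgb.create_tree_digraph( ... )

    # Render with the default format (PDF) and open the result in a viewer
    # on the machine where the Python process runs.
    g.render(view=True)

    # Save as PNG: render() appends the format extension to the name it is given,
    # so pass "output" (not "output.png") to get output.png.
    g.format = "png"
    g.render("output")

    # Shortcut for render(view=True).
    g.view()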
    

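    This relies on the Graphviz dot executable already being installed and on the PATH, because render() and view() run it as an external program. The Python graphviz package alone is not enough.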
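
Since rendering depends on the dot executable, it can help to check for it before calling render(). The following is a minimal sketch, not part of the original answer: dot_available and render_png are hypothetical helper names, and graph is assumed to be the Digraph returned by lgb.create_tree_digraph. It uses only shutil.which from the standard library plus the format attribute and render() method already used above.

```python
import shutil


def dot_available() -> bool:
    """Return True if the Graphviz `dot` executable can be found on the PATH."""
    return shutil.which("dot") is not None


def render_png(graph, stem: str = "tree") -> str:
    """Render a Digraph to '<stem>.png' and return the path reported by render().

    `graph` is assumed to be the object returned by lgb.create_tree_digraph(...).
    """
    if not dot_available():
        raise RuntimeError(
            "Graphviz 'dot' executable not found on PATH; "
            "install the Graphviz system package, not just the Python bindings."
        )
    graph.format = "png"
    # cleanup=True removes the intermediate DOT source file after rendering.
    return graph.render(stem, cleanup=True)
```

Calling render_png(g, "output") would then either write output.png or fail with a message that points at the missing executable. How the resulting file is shown inline depends on the notebook environment.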