Tags: python, tensorflow, machine-learning, random-forest, tensorflow-decision-forests

TensorFlow random forest plotting errors


I am running Jupyter on Anaconda (Mac, M2).

After fitting the model to the training data:

import tensorflow_decision_forests as tfdf

rf = tfdf.keras.RandomForestModel(task=tfdf.keras.Task.REGRESSION)
rf.compile(metrics=["mse"])

rf.fit(x=train_ds)  # train_ds: the training tf.data.Dataset

I want to visualize the model with the following code, but nothing is displayed:

tfdf.model_plotter.plot_model_in_colab(rf, tree_idx=0, max_depth=3)

Can anyone suggest what I should do?

Yes, I tried ChatGPT. It suggested the same code, or variations of it, several times, and still nothing was displayed.

According to ChatGPT, I have all the dependencies installed.


Solution

  • TF-DF author here.

    Unfortunately, interactive plotting with TF-DF only works in Colab, not in IPython/Jupyter, since the two have slightly different JavaScript integrations. Currently, you have three options:

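    1. Use non-interactive text plots. The output below comes from a classification model (`model_1`) trained on the Palmer Penguins dataset: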
    > print(model_1.make_inspector().extract_tree(1))
    (bill_depth_mm >= 16.350000381469727; miss=True, score=0.4877108931541443)
        ├─(pos)─ (bill_length_mm >= 43.05000305175781; miss=True, score=0.4372641444206238)
        │        ├─(pos)─ (body_mass_g >= 4125.0; miss=True, score=0.52157062292099)
        │        │        ├─(pos)─ (flipper_length_mm >= 199.01458740234375; miss=True, score=0.5047621130943298)
        │        │        │    ...
        │        │        └─(neg)─ ProbabilityValue([0.0, 0.0, 1.0],n=38.0) (idx=5)
        │        └─(neg)─ (bill_depth_mm >= 17.450000762939453; miss=False, score=0.015847451984882355)
        │                 ├─(pos)─ ProbabilityValue([1.0, 0.0, 0.0],n=68.0) (idx=4)
        │                 └─(neg)─ (bill_length_mm >= 38.900001525878906; miss=True, score=0.0711795762181282)
        │                      ...
        └─(neg)─ (body_mass_g >= 3750.0; miss=True, score=0.20150887966156006)
                 ├─(pos)─ ProbabilityValue([0.0, 1.0, 0.0],n=93.0) (idx=1)
                 └─(neg)─ ProbabilityValue([1.0, 0.0, 0.0],n=5.0) (idx=0)
    
    2. If you want richer visualizations with many options and more information, you can use dtreeviz. There is a tutorial on the TensorFlow website that explains in detail how to use it with TF-DF.

    3. Extract the HTML that TF-DF produces yourself and display it in a compatible viewer:

    html = tfdf.model_plotter.plot_model(rf, tree_idx=0, max_depth=3)
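For option 3, one way to get a "compatible viewer" outside Colab is simply a regular web browser. The sketch below assumes `html` is the string returned by `tfdf.model_plotter.plot_model`; the helper name `save_tree_html` and the default file name `tree_plot.html` are hypothetical, chosen only for illustration. It writes the HTML to disk and, optionally, opens it with Python's standard `webbrowser` module.

```python
import pathlib
import webbrowser


def save_tree_html(html: str, path: str = "tree_plot.html",
                   open_browser: bool = False) -> pathlib.Path:
    """Hypothetical helper: write an HTML tree plot to disk, optionally open it.

    `html` is assumed to be the string returned by
    tfdf.model_plotter.plot_model(...).
    """
    out = pathlib.Path(path).resolve()  # absolute path, needed for a file:// URI
    out.write_text(html, encoding="utf-8")
    if open_browser:
        # A standalone browser runs the page's JavaScript itself,
        # sidestepping the notebook's JavaScript integration.
        webbrowser.open(out.as_uri())
    return out
```

Usage, assuming `rf` is the trained model from the question: `save_tree_html(tfdf.model_plotter.plot_model(rf, tree_idx=0, max_depth=3), open_browser=True)`.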
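Option 1 can also be applied to the question's regression model and to more than one tree. A minimal sketch, assuming `model` is a trained TF-DF Keras model such as `rf`. It uses the inspector methods `make_inspector()`, `num_trees()` and `extract_tree(tree_idx=...)`. The helper name `text_plot_trees` and its `max_trees` parameter are hypothetical.

```python
def text_plot_trees(model, max_trees=1):
    """Hypothetical helper: return text diagrams of the first `max_trees` trees.

    Assumes `model` is a trained TF-DF Keras model (e.g. the question's `rf`),
    which provides make_inspector().
    """
    inspector = model.make_inspector()
    # Do not request more trees than the forest actually contains.
    count = min(max_trees, inspector.num_trees())
    plots = []
    for idx in range(count):
        # Converting the extracted tree to a string yields the text diagram
        # that print() shows in option 1.
        tree = inspector.extract_tree(tree_idx=idx)
        plots.append(f"--- tree {idx} ---\n{tree}")
    return plots
```

Usage: `for plot in text_plot_trees(rf, max_trees=2): print(plot)`. Unlike the interactive plot, this output is plain text, so it displays in Jupyter, IPython or a terminal alike.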