Tags: python, visualization, decision-tree

How do I get my variable names on the nodes of a decision tree? I get KeyError: 7


I am trying to visualize a causal decision tree based on my model. I finally managed to plot the tree, but the nodes show X[5] instead of my variable names. As soon as I add feature_names=X, it no longer plots and raises KeyError: 7. Can someone help? Many thanks, Elisa

import numpy as np
import pandas as pd
import graphviz
from econml.dml import CausalForestDML
from econml.cate_interpreter import SingleTreeCateInterpreter


# load CSV
data = pd.read_csv("Basis_Entscheidungsbaum.csv", sep=";", header=0)

#Variables 
feature_names=['DL', 'KE', 'AA', 'K', 'ST', 'G', 'BV', 'A']

Y = data['Z']
T = data['M']
X = data[feature_names]

#tree model
tree_model = CausalForestDML(n_estimators=1, subforest_size=1, inference=False, max_depth=4)

#causal decision tree
tree_model = tree_model.fit(Y=Y, X=X , T=T)
intrp = SingleTreeCateInterpreter(max_depth=3).interpret(tree_model, X)

#Visualization
intrp.plot(fontsize=12)

# This variant raises KeyError: 7
# intrp.plot(feature_names=X, fontsize=12)

I expect the variable names to appear on the nodes.

Input (first rows of the data):

A   BV  G     ST  K  AA    KE  DL  Z  M
32  14  3400  1   0  7,49  2   3   0  1
29  5   2900  4   0  8,21  0   2   1  1
44  19  7400  5   2  9,39  0   4   0  0
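The AA column in this sample uses a decimal comma (7,49). If the CSV file stores the values the same way, pd.read_csv only parses AA as numbers when it is told decimal=","; otherwise AA stays a text column, which the model most likely cannot use as a numeric feature. A minimal sketch, assuming the file is semicolon-separated with decimal commas; the in-memory csv_text stands in for Basis_Entscheidungsbaum.csv and holds only the three rows shown above.

```python
import io

import pandas as pd

# Stand-in for Basis_Entscheidungsbaum.csv (assumption: ';' separator, ',' decimal mark).
csv_text = """A;BV;G;ST;K;AA;KE;DL;Z;M
32;14;3400;1;0;7,49;2;3;0;1
29;5;2900;4;0;8,21;0;2;1;1
44;19;7400;5;2;9,39;0;4;0;0
"""

# Default decimal='.': "7,49" is not a number, so AA is read as text.
raw = pd.read_csv(io.StringIO(csv_text), sep=";", header=0)
print(raw["AA"].dtype)   # a text dtype, not float64

# decimal=',' converts "7,49" to 7.49, so AA becomes a float column.
data = pd.read_csv(io.StringIO(csv_text), sep=";", header=0, decimal=",")
print(data["AA"].dtype)  # float64
```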

Solution

  • According to SingleTreeCateInterpreter's documentation, the plot method expects feature_names to be a list of strings:

    plot(ax=None, title=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, rounded=True, precision=3, fontsize=None)

    • feature_names (list of str, optional) – Names of each of the features.

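    Passing the DataFrame X does not work: the node labels are most likely looked up by integer position (for example feature_names[7]), which on a DataFrame is a column-label lookup, and there is no column named 7, hence KeyError: 7. Since you already built the list of names earlier, you can just pass it: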

    intrp.plot(feature_names=feature_names, fontsize=12)
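The difference between a list and a DataFrame can be reproduced with pandas alone. This sketch assumes, as scikit-learn's tree exporters do, that the plotting code labels each split with feature_names[feature_index]; node_label is a hypothetical helper that mimics that lookup, and the single row of values is the first sample row from the question.

```python
import pandas as pd

feature_names = ['DL', 'KE', 'AA', 'K', 'ST', 'G', 'BV', 'A']
# One row with the same columns as X in the question (first sample row).
X = pd.DataFrame([[3, 2, 7.49, 0, 1, 3400, 14, 32]], columns=feature_names)


def node_label(names, feature_index):
    """Hypothetical: mimic how a tree exporter turns a split's feature index into a label."""
    return names[feature_index]


# A plain list is indexed by position, so index 7 is the eighth name.
print(node_label(feature_names, 7))       # A
print(node_label(X.columns.tolist(), 7))  # A

# A DataFrame is indexed by column label, and there is no column called 7.
try:
    node_label(X, 7)
except KeyError as err:
    print("KeyError:", err)               # KeyError: 7
```

So either the original list or X.columns.tolist() works as feature_names; the DataFrame itself does not.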
    

    Edit

    This is the code I used to try to reproduce the error; it runs fine on my machine:

    import pandas as pd
    import numpy as np
    import graphviz
    from econml.dml import CausalForestDML
    from econml.cate_interpreter import SingleTreeCateInterpreter
    
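    # Jupyter/IPython magic only; in a plain script, call matplotlib.pyplot.show() after plotting instead
    %matplotlib inline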
    
    data = {
        'A': [32, 29, 44],
        'BV': [14, 5, 19],
        'G': [3400, 2900, 7400],
        'ST': [1, 4, 5],
        'K': [0, 0, 2],
        'AA': ['7,49', '8,21', '9,39'],
        'KE': [2, 0, 0],
        'DL': [3, 2, 4],
        'Z': [0, 1, 0],
        'M': [1, 1, 0]
    }
    
    # Replace decimal
    data['AA'] = [float(x.replace(',', '.')) for x in data['AA']]
    
    # Define the number of additional rows to generate
    n = 200
    
    # Generate random values for each column and append to the data dictionary
    for i in range(n):
        data['A'].append(np.random.randint(20, 50))
        data['BV'].append(np.random.randint(5, 20))
        data['G'].append(np.random.randint(2000, 8000))
        data['ST'].append(np.random.randint(1, 6))
        data['K'].append(np.random.randint(0, 3))
        data['AA'].append(round(np.random.uniform(5, 10),2))
        data['KE'].append(np.random.randint(0, 3))
        data['DL'].append(np.random.randint(2, 5))
        data['Z'].append(np.random.randint(0, 2))
        data['M'].append(np.random.randint(0, 2))
    
    data = pd.DataFrame(data)
    
    #Variables 
    feature_names=['DL', 'KE', 'AA', 'K', 'ST', 'G', 'BV', 'A']
    
    Y = data['Z']
    T = data['M']
    X = data[feature_names]
    
    #tree model
    tree_model = CausalForestDML(n_estimators=1, subforest_size=1, inference=False, max_depth=4)
    
    #causal decision tree
    tree_model = tree_model.fit(Y=Y, X=X , T=T)
    intrp = SingleTreeCateInterpreter(max_depth=3).interpret(tree_model, X)
    
    #Visualization
    intrp.plot(feature_names=feature_names, fontsize=12)
    

    Output:

    https://imgur.com/a/aaa8Mh3
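The loop in the reproduction draws random rows from NumPy's global generator without a seed, so each run produces different data and possibly a different tree. A minimal sketch of a repeatable variant, assuming the same column ranges as that loop; make_synthetic_data is a hypothetical helper, not part of the original answer or of econml, and the sketch only builds the data (fitting and plotting stay as above).

```python
import numpy as np
import pandas as pd


def make_synthetic_data(n: int = 200, seed: int = 0) -> pd.DataFrame:
    """Hypothetical helper: n random rows in the loop's ranges, reproducible via seed."""
    rng = np.random.default_rng(seed)
    # Generator.integers excludes the upper bound, like np.random.randint.
    return pd.DataFrame({
        'A': rng.integers(20, 50, size=n),
        'BV': rng.integers(5, 20, size=n),
        'G': rng.integers(2000, 8000, size=n),
        'ST': rng.integers(1, 6, size=n),
        'K': rng.integers(0, 3, size=n),
        'AA': np.round(rng.uniform(5, 10, size=n), 2),
        'KE': rng.integers(0, 3, size=n),
        'DL': rng.integers(2, 5, size=n),
        'Z': rng.integers(0, 2, size=n),
        'M': rng.integers(0, 2, size=n),
    })


# The three sample rows from the question, with AA already converted to floats.
sample = pd.DataFrame({
    'A': [32, 29, 44], 'BV': [14, 5, 19], 'G': [3400, 2900, 7400],
    'ST': [1, 4, 5], 'K': [0, 0, 2], 'AA': [7.49, 8.21, 9.39],
    'KE': [2, 0, 0], 'DL': [3, 2, 4], 'Z': [0, 1, 0], 'M': [1, 1, 0],
})

data = pd.concat([sample, make_synthetic_data(200, seed=0)], ignore_index=True)
print(data.shape)  # (203, 10)
```

Fixing the seed means that anyone running the reproduction fits the model on exactly the same rows.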

    Edit 2

    If fitting fails with the error below, use scikit-learn < 1.2 (the 'mse' criterion name was removed in scikit-learn 1.2), for example with pip install "scikit-learn<1.2":

    InvalidParameterError: The 'criterion' parameter of DecisionTreeRegressor must be a str among {'poisson', 'squared_error', 'friedman_mse', 'absolute_error'}. Got 'mse' instead.
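The version boundary from Edit 2 can be checked before fitting. A minimal sketch, assuming the package is installed under its PyPI name scikit-learn; supports_mse_criterion is a hypothetical helper that only compares the major and minor version against 1.2 and says nothing about other compatibility issues.

```python
import re
from importlib.metadata import PackageNotFoundError, version


def supports_mse_criterion(sklearn_version: str) -> bool:
    """Hypothetical: True for scikit-learn releases older than 1.2 (criterion='mse' still accepted)."""
    match = re.match(r"(\d+)\.(\d+)", sklearn_version)
    if match is None:
        raise ValueError(f"unrecognised version string: {sklearn_version!r}")
    major, minor = int(match.group(1)), int(match.group(2))
    return (major, minor) < (1, 2)


try:
    installed = version("scikit-learn")
except PackageNotFoundError:
    print("scikit-learn is not installed")
else:
    if supports_mse_criterion(installed):
        print(f"scikit-learn {installed}: criterion='mse' is still accepted")
    else:
        print(f"scikit-learn {installed}: expect InvalidParameterError; "
              'try pip install "scikit-learn<1.2"')
```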