I am trying to visualize a causal decision tree based on my model. I finally managed to plot the tree, but it does not show my variable names, only labels like X[5] on the nodes. As soon as I add feature_names=X, it no longer plots and instead raises KeyError: 7.
Can someone help out?
Many thanks
Elisa
import numpy as np
import pandas as pd
import graphviz
from econml.dml import CausalForestDML
from econml.cate_interpreter import SingleTreeCateInterpreter
#load csv
data = pd.read_csv("Basis_Entscheidungsbaum.csv", sep=";", header=0)
#Variables
feature_names=['DL', 'KE', 'AA', 'K', 'ST', 'G', 'BV', 'A']
Y = data['Z']
T = data['M']
X = data[feature_names]
#tree model
tree_model = CausalForestDML(n_estimators=1, subforest_size=1, inference=False, max_depth=4)
#causal decision tree
tree_model = tree_model.fit(Y=Y, X=X, T=T)
intrp = SingleTreeCateInterpreter(max_depth=3).interpret(tree_model, X)
#Visualization
intrp.plot(fontsize=12)
# intrp.plot(feature_names=X, fontsize=12)
I expect the variable names to appear on the nodes.
Input:
A | BV | G | ST | K | AA | KE | DL | Z | M |
---|---|---|---|---|---|---|---|---|---|
32 | 14 | 3400 | 1 | 0 | 7,49 | 2 | 3 | 0 | 1 |
29 | 5 | 2900 | 4 | 0 | 8,21 | 0 | 2 | 1 | 1 |
44 | 19 | 7400 | 5 | 2 | 9,39 | 0 | 4 | 0 | 0 |
According to SingleTreeCateInterpreter's documentation, if you want the feature names, the plot method expects a list of strings:
plot(ax=None, title=None, feature_names=None, treatment_names=None, max_depth=None, filled=True, rounded=True, precision=3, fontsize=None)
- feature_names (list of str, optional) – Names of each of the features.
Since you already generated this list earlier, you can just pass it:
intrp.plot(feature_names=feature_names, fontsize=12)
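The KeyError most likely comes from plot() looking features up by integer position (as sklearn's tree export does): a list of strings supports that, but the DataFrame X does not, because its columns are indexed by name. A minimal illustration with a toy two-column frame (column names chosen to match the question's data):

```python
import pandas as pd

# A DataFrame with string column labels, like X in the question
X = pd.DataFrame({'DL': [3, 2], 'KE': [2, 0]})

# A list of strings supports integer indexing, which is what plot() needs
feature_names = list(X.columns)
print(feature_names[1])  # 'KE'

# Indexing the DataFrame itself with an integer fails, because integers
# are not among its column labels -- this is the KeyError from the question
try:
    X[1]
except KeyError as e:
    print("KeyError:", e)  # prints: KeyError: 1
```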
This is the code I used to try to reproduce the error; it runs fine on my machine:
import pandas as pd
import numpy as np
import graphviz
from econml.dml import CausalForestDML
from econml.cate_interpreter import SingleTreeCateInterpreter
%matplotlib inline
data = {
    'A': [32, 29, 44],
    'BV': [14, 5, 19],
    'G': [3400, 2900, 7400],
    'ST': [1, 4, 5],
    'K': [0, 0, 2],
    'AA': ['7,49', '8,21', '9,39'],
    'KE': [2, 0, 0],
    'DL': [3, 2, 4],
    'Z': [0, 1, 0],
    'M': [1, 1, 0]
}
# Replace decimal commas with points and convert to float
data['AA'] = [float(x.replace(',', '.')) for x in data['AA']]
# Define the number of additional rows to generate
n = 200
# Generate random values for each column and append to the data dictionary
for i in range(n):
    data['A'].append(np.random.randint(20, 50))
    data['BV'].append(np.random.randint(5, 20))
    data['G'].append(np.random.randint(2000, 8000))
    data['ST'].append(np.random.randint(1, 6))
    data['K'].append(np.random.randint(0, 3))
    data['AA'].append(round(np.random.uniform(5, 10), 2))
    data['KE'].append(np.random.randint(0, 3))
    data['DL'].append(np.random.randint(2, 5))
    data['Z'].append(np.random.randint(0, 2))
    data['M'].append(np.random.randint(0, 2))
data = pd.DataFrame(data)
#Variables
feature_names=['DL', 'KE', 'AA', 'K', 'ST', 'G', 'BV', 'A']
Y = data['Z']
T = data['M']
X = data[feature_names]
#tree model
tree_model = CausalForestDML(n_estimators=1, subforest_size=1, inference=False, max_depth=4)
#causal decision tree
tree_model = tree_model.fit(Y=Y, X=X, T=T)
intrp = SingleTreeCateInterpreter(max_depth=3).interpret(tree_model, X)
#Visualization
intrp.plot(feature_names=feature_names, fontsize=12)
Output: a tree plot with the variable names shown on the nodes.
Note: to avoid running into InvalidParameterError: "The 'criterion' parameter of DecisionTreeRegressor must be a str among {'poisson', 'squared_error', 'friedman_mse', 'absolute_error'}. Got 'mse' instead.", use sklearn<1.2.
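A sketch of pinning the dependency, assuming a pip-based environment:

```shell
# Pin scikit-learn below 1.2 so the deprecated criterion='mse' that econml
# passes to DecisionTreeRegressor is still accepted
pip install "scikit-learn<1.2"
```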