Tags: pandas, scikit-learn, decision-tree

Verifying the decision tree graph


I created the decision tree model in the following manner.

# first create the model
from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
from IPython.display import Image  
from sklearn import tree
import pydotplus
import pandas as pd
iris = datasets.load_iris()

X = iris.data
y = iris.target
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X, y)

# convert to a data frame for easier analysis
X = pd.DataFrame(X, columns=["sepal_length", "sepal_width", "petal_length", "petal_width"])

Then I plotted the graph as follows.

# Create DOT data (names are passed as plain lists, which every scikit-learn version accepts)
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=list(X.columns),
                                class_names=list(iris.target_names))
# Draw graph
graph = pydotplus.graph_from_dot_data(dot_data)
# Show graph
Image(graph.create_png())

This gave the following result:

[Image: rendered decision tree graph]
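
For readers who cannot see the image, the same tree can be printed as plain text. This is a minimal sketch, assuming a scikit-learn version that provides `sklearn.tree.export_text`; the variable names `names` and `text` are only illustrative, and the model is refit here so the snippet runs on its own.

```python
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier, export_text

# Refit the same model as in the question
iris = datasets.load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)

# Column names as used in the question
names = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

# Each line of the output is one split rule or one leaf with its predicted class
text = export_text(clf, feature_names=names)
print(text)
```

Each level of indentation in the printed output corresponds to one level of depth in the graph, so the split used at any stage can be read off directly.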

I took the subset of the data at stage 3a.

X3a = X.query("petal_width >.8 and petal_width <=1.75")

I then created a function to compute the Gini index of each column.

import numpy as np

def gini2(x):
    # (Warning: This is a concise implementation, but it is O(n**2)
    # in time and memory, where n = len(x).  *Don't* pass in huge
    # samples!)
    # Mean absolute difference
    mad = np.abs(np.subtract.outer(x, x)).mean()
    # Relative mean absolute difference
    rmad = mad / np.mean(x)
    # Gini coefficient
    g = 0.5 * rmad
    return g

Finally, I computed the Gini index of each column of the data at stage 3a.

gini2( X3a["sepal_length"] ) # returns 0.051
gini2( X3a["sepal_width"] ) # returns 0.063
gini2( X3a["petal_length"] ) # returns 0.0686
gini2( X3a["petal_width"] ) # returns 0.08, highest among all the columns

The highest Gini index is for petal_width (0.08), so I expected the split at this stage to be on petal_width. But the picture shows that the split is on petal_length. Can someone explain why petal_length is chosen for the split rather than petal_width?


Solution

  • I finally found the answer to the question.

    "The split is chosen to maximize information gain. At stage 3a (as shown above), the split on petal_length yields the maximum information gain, even though petal_length does not have the highest Gini value among the columns."
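
The point can be checked numerically. Note that gini2 above computes the Gini coefficient of a column's values (a measure of inequality) and never looks at the class labels, whereas DecisionTreeClassifier's default criterion is the Gini impurity of the labels in each node. The following is a minimal sketch, not scikit-learn's actual implementation: the helper names gini_impurity and best_split are hypothetical, and candidate thresholds are assumed to be the midpoints between consecutive distinct feature values.

```python
import numpy as np
from sklearn import datasets

def gini_impurity(labels):
    """Gini impurity 1 - sum(p_k**2) of an array of class labels."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(values, labels):
    """Return (threshold, impurity decrease) of the best binary split on one feature."""
    parent = gini_impurity(labels)
    n = len(labels)
    uniq = np.unique(values)
    best_thr, best_gain = None, 0.0
    # Candidate thresholds: midpoints between consecutive distinct values (assumption)
    for thr in (uniq[:-1] + uniq[1:]) / 2.0:
        left = values <= thr
        n_left = left.sum()
        # Weighted average impurity of the two children
        child = (n_left * gini_impurity(labels[left])
                 + (n - n_left) * gini_impurity(labels[~left])) / n
        gain = parent - child
        if gain > best_gain:
            best_thr, best_gain = thr, gain
    return best_thr, best_gain

iris = datasets.load_iris()
X, y = iris.data, iris.target
names = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

# Stage 3a: samples with 0.8 < petal_width <= 1.75 (petal_width is column 3)
mask = (X[:, 3] > 0.8) & (X[:, 3] <= 1.75)
X3a, y3a = X[mask], y[mask]

# Best achievable impurity decrease for each feature at this node
results = {name: best_split(X3a[:, j], y3a) for j, name in enumerate(names)}
for name, (thr, gain) in results.items():
    print(f"{name}: threshold={thr}, impurity decrease={gain:.4f}")

best_feature = max(results, key=lambda k: results[k][1])
print("best feature:", best_feature)
```

The feature with the largest impurity decrease is the one the tree splits on at this node, which is why the per-column Gini coefficient from gini2 does not predict the split.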