
How to make Python decision tree more understandable?


I have a data file. The last column holds the class label, which is either +1 or -1. I also have the name (ID) of each column in a separate file.

e.g.

1 2 3 4 1
5 6 7 8 1
9 1 2 3 -1
4 5 6 7 -1
8 9 1 2 -1

and the columns are named Q1, Q2, Q3, Q4 and Q5, respectively.

I want to train a decision tree classifier, so I wrote the following code:

```python
import numpy
from sklearn import tree

# fileName, idFile and depth are assumed to be defined earlier in the script
print('Reading data from', fileName)
data = numpy.loadtxt(fileName)
print('Getting ids from', idFile)
idArray = numpy.genfromtxt(idFile, dtype='str')  # one name per column

print('data dimensions:', data.shape)
print('idArray dimensions:', idArray.shape)

# The last column holds the class label (+1 or -1); every other column is a feature
y = data[:, -1]
x = data[:, :-1]

classifier = tree.DecisionTreeClassifier(max_depth=depth)
classifier = classifier.fit(x, y)

# Write the fitted tree in Graphviz .dot format; the with block closes the file
with open('graph.dot', 'w') as file:
    tree.export_graphviz(classifier, out_file=file)
```

I used Graphviz to convert the .dot file to a .png file.

The problem is that the resulting decision tree looks something like this:

[image: the rendered decision tree]

I don't understand what X[number] means. I searched and found that value = [5 0] means class 5 has 0 objects and class 0 has 5 objects, but my only class labels are +1 and -1. Is there any way I can tweak this decision tree so that the column names (Q1, Q2, Q3, ...) appear in the picture, so I can understand what it means?

Thanks


Solution

  • value = [5 0] means that, among the training samples reaching that node, the first class has 5 members and the second class has 0. The class order is that of classifier.classes_, which scikit-learn sorts, so for you it is [-1 1].

    As for column names: as yangjie pointed out, X[158] means the 159th column (zero-based indexing). The split rule is already spelled out: X[168] <= 1.5 means that, for a given row, the tree compares the value in the 169th column (X[168]) with 1.5 and sends the row to the left child if it is less than or equal to 1.5, and to the right child otherwise.

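    You can show column names instead of X[i] by passing them to the optional feature_names argument of export_graphviz. The list must contain one name per column of x, in the same order, so leave out the name of the label column.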
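To see where these numbers come from, here is a minimal sketch that fits a tree on the five example rows from the question. It assumes, as in the question, that the last column is the label and that Q1 to Q4 name the feature columns; the variable names are illustrative. It prints the class order used by `value` and translates each split's feature index `X[i]` into a column name using the fitted tree's `tree_.feature`, `tree_.threshold` and `tree_.children_left` arrays.

```python
import numpy as np
from sklearn import tree

# The five example rows from the question; the last column is the label
data = np.array([
    [1, 2, 3, 4, 1],
    [5, 6, 7, 8, 1],
    [9, 1, 2, 3, -1],
    [4, 5, 6, 7, -1],
    [8, 9, 1, 2, -1],
], dtype=float)
feature_names = ["Q1", "Q2", "Q3", "Q4"]  # names of the feature columns only

x, y = data[:, :-1], data[:, -1]
classifier = tree.DecisionTreeClassifier(random_state=0).fit(x, y)

# value = [a b] lists sample counts in the order of classes_, which is sorted
print("class order:", classifier.classes_)

# Each internal node tests X[feature[node]] <= threshold[node];
# leaf nodes have children_left == -1 and are skipped
t = classifier.tree_
for node in range(t.node_count):
    if t.children_left[node] != -1:  # internal (split) node
        i = t.feature[node]
        print(f"node {node}: X[{i}] is {feature_names[i]}, "
              f"go left if {feature_names[i]} <= {t.threshold[node]:.2f}")
```

The same lookup works on the real data: `feature_names[i]` is simply the i-th entry of whatever name list matches the columns of `x`.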
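A minimal sketch of the feature_names fix, again on the example rows from the question. It assumes the column names are ordered like the columns of `x` (so Q5, the label column's name, is left out). It also passes the optional `class_names` argument, which adds a `class = ...` line naming the majority class at each node, and `out_file=None`, which makes `export_graphviz` return the DOT source as a string.

```python
import numpy as np
from sklearn import tree

# The five example rows from the question; the last column is the label
data = np.array([
    [1, 2, 3, 4, 1],
    [5, 6, 7, 8, 1],
    [9, 1, 2, 3, -1],
    [4, 5, 6, 7, -1],
    [8, 9, 1, 2, -1],
], dtype=float)
x, y = data[:, :-1], data[:, -1]

# Feature names must follow the column order of x; Q5 is the label, so it is excluded
feature_names = ["Q1", "Q2", "Q3", "Q4"]

classifier = tree.DecisionTreeClassifier(max_depth=3, random_state=0).fit(x, y)

# class_names must be in ascending label order, which matches classifier.classes_
class_names = [str(int(c)) for c in classifier.classes_]  # ['-1', '1']

# out_file=None returns the DOT source as a string instead of writing a file
dot_source = tree.export_graphviz(
    classifier,
    out_file=None,
    feature_names=feature_names,
    class_names=class_names,
    filled=True,
)

with open("graph.dot", "w") as f:
    f.write(dot_source)
```

The resulting graph.dot can be rendered as before, for example with `dot -Tpng graph.dot -o graph.png`; the split nodes then read like `Q3 <= 4.5` instead of `X[2] <= 4.5`.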