
How can I walk a specific path in an sklearn DecisionTreeClassifier?


I'm a bit of a rookie here and am having trouble working out the solution.

Here's my decision tree: [image not available]

Here's the output of print(dt.tree_.__getstate__()):

    {'max_depth': 3, 'node_count': 7,
     'nodes': array([( 1,  6, 19,  0.5, 0.41848786, 1382, 1382.),
                     ( 2,  5, 12,  0.5, 0.49534472,  912,  912.),
                     ( 3,  4,  3,  0.5, 0.43366519,  604,  604.),
                     (-1, -1, -2, -2. , 0.33618881,  449,  449.),
                     (-1, -1, -2, -2. , 0.47150884,  155,  155.),
                     (-1, -1, -2, -2. , 0.        ,  308,  308.),
                     (-1, -1, -2, -2. , 0.        ,  470,  470.)],
                    dtype=[('left_child', '<i8'), ('right_child', '<i8'), ('feature', '<i8'),
                           ('threshold', '<f8'), ('impurity', '<f8'),
                           ('n_node_samples', '<i8'), ('weighted_n_node_samples', '<f8')]),
     'values': array([[[970., 412.]],
                      [[500., 412.]],
                      [[192., 412.]],
                      [[ 96., 353.]],
                      [[ 96.,  59.]],
                      [[308.,   0.]],
                      [[470.,   0.]]])}
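The printed state is already enough to walk the tree by hand: start at node 0 and, at each internal node, go to left_child when the sample's value for that node's feature is <= threshold (scikit-learn's convention), otherwise to right_child, until you reach a leaf (left_child == -1). The predicted class is the index of the largest entry in that leaf's values row. The sketch below hard-codes the arrays from the dump, and walk_tree is a hypothetical helper. The mapping of names to column indices is an assumption: if the columns came from pd.get_dummies on the UCI Car Evaluation data, alphabetical ordering would put buying_vhigh at 3, persons_2 at 12 and safety_low at 19, but check it against your own column list.

```python
# Arrays transcribed from the tree_.__getstate__() dump (node i is index i).
# A leaf has left_child == right_child == -1 (node 6 is a leaf too).
children_left = [1, 2, 3, -1, -1, -1, -1]
children_right = [6, 5, 4, -1, -1, -1, -1]
feature = [19, 12, 3, -2, -2, -2, -2]
threshold = [0.5, 0.5, 0.5, -2.0, -2.0, -2.0, -2.0]
value = [[970, 412], [500, 412], [192, 412], [96, 353],
         [96, 59], [308, 0], [470, 0]]


def walk_tree(x, node=0):
    """Follow one sample from the root to a leaf (hypothetical helper).

    x maps feature index -> value; features the path never tests may be omitted.
    Returns the visited node ids and the predicted class index.
    """
    path = [node]
    while children_left[node] != -1:  # stop once we reach a leaf
        # scikit-learn sends a sample left when x[feature] <= threshold
        if x[feature[node]] <= threshold[node]:
            node = children_left[node]
        else:
            node = children_right[node]
        path.append(node)
    counts = value[node]
    # argmax over the leaf's class counts (first index wins ties, like numpy)
    predicted = max(range(len(counts)), key=counts.__getitem__)
    return path, predicted


# ASSUMED mapping: buying_vhigh -> 3, persons_2 -> 12, safety_low -> 19
sample = {3: 1, 12: 0, 19: 0}
path, cls = walk_tree(sample)
print("path:", path, "class index:", cls)
```

Under that assumed mapping the sample goes 0 → 1 → 2 → 4 and ends in the leaf with values [96, 59], so the class index is 0, i.e. dt.classes_[0].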

(I've taken the liberty of reformatting a bit, because SO kept complaining).

Here's what I have been tasked to do:

  • Identify the final outcome (class) of the decision tree for the sample with:
      buying_vhigh = 1
      persons_2 = 0
      safety_low = 0

My question is: is there a way to do this in code (Python)? Obviously it's trivial by inspection, but I want to be able to get the machine to do it.


Solution

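  • One way to do it is to put your values into a dictionary, which you then turn into a one-row DataFrame, so each value is attached to its column name rather than to a position. The DataFrame still has to contain every feature the tree was trained on (predict raises an error otherwise), and recent scikit-learn versions also require the columns to be in the training order. Features you don't care about can simply be 0: this tree only tests features 19, 12 and 3 (presumably safety_low, persons_2 and buying_vhigh), so the other columns can't change the outcome. The crucial part for getting the class is dt.predict():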

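    import pandas as pd
    
    # Start from a row of zeros covering every feature the tree was trained on
    # (dt.feature_names_in_ exists when dt was fitted on a DataFrame)
    sample = dict.fromkeys(dt.feature_names_in_, 0)
    
    # Set the given sample's values; presumably these are the only
    # features the tree tests, so the zeros elsewhere don't matter
    sample.update({
        'buying_vhigh': 1,
        'persons_2': 0,
        'safety_low': 0
    })
    
    # One-row DataFrame with the columns in the same order as during fit
    sample_df = pd.DataFrame([sample], columns=dt.feature_names_in_)
    
    # Use the predict method to make a prediction for the given sample
    prediction = dt.predict(sample_df)
    
    # Print the predicted class
    print("The predicted class is:", prediction[0])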