
Scikit-Learn Decision Tree: Probability of prediction being a or b?


I have a basic decision tree classifier with Scikit-Learn:

#Used to tell men from women based on height and shoe size

from sklearn import tree

#height and shoe size
X = [[65,9],[67,7],[70,11],[62,6],[60,7],[72,13],[66,10],[67,7.5]]

Y=["male","female","male","female","female","male","male","female"]

#creating a decision tree
clf = tree.DecisionTreeClassifier()

#fitting the data to the tree
clf.fit(X, Y)

#predicting the gender of a new person (height 68, shoe size 9)
#predict() expects a 2D input: a list of samples, each a list of features
prediction = clf.predict([[68,9]])

#print the predicted gender
print(prediction)

When I run the program, it always outputs either "male" or "female". How can I see the probability of the prediction being male or female? For example, the prediction above returns "male", but how would I get it to print the probability that the prediction is male?

Thanks!


Solution

  • You can do something like the following:

    from sklearn import tree
    
    #load data
    X = [[65,9],[67,7],[70,11],[62,6],[60,7],[72,13],[66,10],[67,7.5]]
    Y=["male","female","male","female","female","male","male","female"]
    
    #build model
    clf = tree.DecisionTreeClassifier()
    
    #fit
    clf.fit(X, Y)
    
    #predict
    prediction = clf.predict([[68,9],[66,9]])
    
    #probabilities
    probs = clf.predict_proba([[68,9],[66,9]])
    
    #print the predicted genders and the class probabilities
    print(prediction)
    print(probs)
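As a quick consistency check: for a single-output classifier tree, predict() returns the class with the highest predicted probability. The sketch below recovers the same labels from predict_proba by taking the argmax of each row and mapping it through clf.classes_. The random_state=0 argument is an addition of this sketch, used only to make runs repeatable.

```python
import numpy as np
from sklearn import tree

# Same toy data as above: [height, shoe size]
X = [[65, 9], [67, 7], [70, 11], [62, 6], [60, 7], [72, 13], [66, 10], [67, 7.5]]
Y = ["male", "female", "male", "female", "female", "male", "male", "female"]

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, Y)

queries = [[68, 9], [66, 9]]
probs = clf.predict_proba(queries)

# Pick the most probable column in each row and map it back to a label.
# The columns of predict_proba follow the order of clf.classes_.
from_probs = clf.classes_[np.argmax(probs, axis=1)]

print(from_probs)
print(clf.predict(queries))  # should match the line above
```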
    

    Theory

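    clf.predict_proba(X) returns, for each sample, the predicted class probabilities: the fraction of training samples of each class in the leaf that the sample ends up in. With the default settings (max_depth=None), the tree keeps splitting until all leaves are pure (or too small to split), so on this data every probability comes out as exactly 0 or 1.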
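To see the "fraction of samples in a leaf" definition with probabilities other than 0 and 1, the sketch below trains on height only with max_depth=1, so the leaves are not pure. Both choices are illustrative assumptions, not part of the original answer. It then recomputes the probabilities by hand: clf.apply() gives the leaf each sample lands in, and the class fractions among the training samples in that leaf should match predict_proba.

```python
import numpy as np
from sklearn import tree

X = [[65, 9], [67, 7], [70, 11], [62, 6], [60, 7], [72, 13], [66, 10], [67, 7.5]]
Y = ["male", "female", "male", "female", "female", "male", "male", "female"]

# Illustrative setup: height column only, depth-1 tree -> impure leaves.
X_height = [[row[0]] for row in X]
clf = tree.DecisionTreeClassifier(max_depth=1, random_state=0)
clf.fit(X_height, Y)

queries = [[61], [68], [71]]
proba = clf.predict_proba(queries)

# Recompute by hand: leaf index of every training sample and every query.
train_leaves = clf.apply(X_height)
query_leaves = clf.apply(queries)
Y_arr = np.array(Y)

# For each query, the share of training samples of each class in its leaf,
# with columns in the same order as clf.classes_.
manual = np.array([
    [np.mean(Y_arr[train_leaves == leaf] == c) for c in clf.classes_]
    for leaf in query_leaves
])

print(clf.classes_)
print(proba)
print(manual)
print(np.allclose(proba, manual))  # should print True
```

This equality holds here because no sample weights or class weights are used. With weighting, the leaf values are weighted fractions rather than plain counts.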

    Interpretation of the results:

    The first print returns ['male' 'male'], so both samples, [68,9] and [66,9], are predicted as male.

    The second print returns:

    [[0. 1.]
     [0. 1.]]

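    Each row corresponds to one input sample and each column to one class. The 1 in the second column of both rows means that both samples were predicted as male with probability 1 (and as female with probability 0).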

    To see the order of the classes use: clf.classes_

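    This returns ['female', 'male']. The classes are sorted, so column 0 of predict_proba holds the probability of "female" and column 1 the probability of "male".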
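To print just the probability of "male", as asked in the question, one option is to pair each class name with its column, so the code does not depend on remembering the column order. This is a minimal sketch; the random_state=0 argument and the prob_by_class name are choices of this example.

```python
from sklearn import tree

X = [[65, 9], [67, 7], [70, 11], [62, 6], [60, 7], [72, 13], [66, 10], [67, 7.5]]
Y = ["male", "female", "male", "female", "female", "male", "male", "female"]

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(X, Y)

# predict_proba returns one row per sample; take the row for our single sample.
probs = clf.predict_proba([[68, 9]])[0]

# Map each class name (in clf.classes_ order) to its probability.
prob_by_class = dict(zip(clf.classes_, probs))
p_male = prob_by_class["male"]

print(f"P(male) = {p_male:.2f}")
```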