scikit-learn, random-forest, sklearn-pandas

sklearn random forest plot interpretation


Can you please help me understand the plot below? What is Gini? What does it mean that the value at the Glucose node is [66, 72]? What is the difference between the colors (blue, white, pink)?

The data is based on diabetes.csv (google it).

    from matplotlib import pyplot as plt
    import pandas as pd
    from sklearn import tree  # needed for tree.plot_tree below
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier

    # Assumes diabetes.csv is in the working directory
    df = pd.read_csv('diabetes.csv')

    diab_cols = ['Pregnancies', 'Insulin', 'BMI', 'Age', 'Glucose',
                 'BloodPressure', 'DiabetesPedigreeFunction']
    X = df[diab_cols]  # Features
    y = df.Outcome     # Target variable

    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.25, random_state=0)

    model = RandomForestClassifier(n_estimators=100, max_depth=5).fit(X_train, y_train)

    # Plot only the first of the 100 trees in the forest
    plt.figure(figsize=(20, 20))
    _ = tree.plot_tree(model.estimators_[0], feature_names=list(X.columns), filled=True)
    plt.show()

[Figure: plot_tree output for the first tree in the forest, model.estimators_[0]]


Solution

  • From this example:

    For each pair of features, the decision tree learns decision boundaries made of combinations of simple thresholding rules inferred from the training samples.

    See also this example.


    edit:

    • The Gini index (or Gini impurity) is a value the decision tree computes at every node. It is one minus the sum of the squared class proportions at that node: gini = 1 - Σ p_k². It is the probability that a sample drawn at random from the node would be mislabeled if it were labeled at random according to the node's class distribution. With two classes it ranges from 0 (all samples belong to one class) to 0.5 (an even split). The smaller the value, the purer the node and the more confident the prediction.

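    • The values are the number of training samples of each class that reach the node. For example, value = [66, 72] means 66 samples of class 0 and 72 samples of class 1. Because each tree in a random forest is fit on a bootstrap sample by default, these counts refer to that resampled set rather than to the full training set.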

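    • The color encodes the predicted (majority) class at each node. The more one class dominates, the more intense the color; a nearly white node is close to an even split between the classes.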
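
To make the points above concrete, here is a minimal sketch that recomputes the Gini value and the majority class from the counts [66, 72] mentioned in the question. The helper names `gini_from_counts` and `majority_class` are made up for this illustration and are not part of scikit-learn; the sketch also assumes the node's value really holds raw class counts.

```python
def gini_from_counts(counts):
    """Gini impurity 1 - sum(p_k^2) for a list of per-class sample counts."""
    total = sum(counts)
    if total == 0:
        raise ValueError("a node must contain at least one sample")
    return 1.0 - sum((c / total) ** 2 for c in counts)


def majority_class(counts):
    """Index of the class with the most samples (ties go to the lowest index)."""
    return max(range(len(counts)), key=lambda k: counts[k])


node_value = [66, 72]  # counts from the question: class 0, class 1
print(round(gini_from_counts(node_value), 3))  # 0.499
print(majority_class(node_value))              # 1
```

A Gini value this close to 0.5 means the node is almost evenly split, which is why such a node would be drawn in a very pale color even though class 1 is technically the majority.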