
How do I include feature names in the plot_tree function from the XGBoost library?


I've been using the XGBoost library to develop a binary classification model. Having trained my model, I am interested in visualizing the individual trees to better understand my model's predictions.

To do this, XGBoost provides a plot_tree function, but it only labels each split with the integer index of the feature. Here is an example of one of my trees:

[Image: one of the trees, with split nodes labeled by feature index, e.g. f28]

How do I show the feature names in this image rather than the feature indices (such as f28)?


Solution

  • The plot_tree function in XGBoost has an fmap argument, which takes the path to a 'feature map' file: a file that maps each feature index to a feature name.

    The documentation on the feature map file is sparse, but it is a tab-delimited file in which the first column is the feature index (starting from 0 and ending at the number of features minus one), the second column is the feature name, and the final column is an indicator of the feature type (q = quantitative feature, i = binary indicator feature).

    An example of a feature_map.txt file:

    0    feature_name_0    q
    1    feature_name_1    i
    2    feature_name_2    q
    …          …           … 
    

    With this tab-delimited file, you can then plot a tree from your trained model instance:

    import matplotlib.pyplot as plt
    import xgboost

    # X, y: your training features and binary labels
    model = xgboost.XGBClassifier()
    model.fit(X, y)

    # plot the first tree (index 0), labeling splits via the feature map file
    xgboost.plot_tree(model, num_trees=0, fmap='feature_map.txt')
    plt.show()
    

    Running this displays the tree with each split labeled by its feature name rather than its index.
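Writing the feature map by hand gets tedious with many features. Below is a minimal sketch that generates it from a list of names, assuming the names are in the same order as the training columns. The helpers write_feature_map, read_feature_map and feature_label are hypothetical names for illustration, not part of XGBoost. The sketch only accepts the q and i types described above, and it rejects names containing whitespace as a precaution (an assumption: such names may not parse cleanly from the file). The reader checks that indices run from 0 to n_features - 1, and feature_label shows the index-to-name substitution that the fmap file makes possible in the plot.

```python
import os
import tempfile
from typing import Dict, Optional, Sequence, Tuple

# Only the two feature types described above; this sketch rejects anything else.
ALLOWED_TYPES = {"q", "i"}


def write_feature_map(
    feature_names: Sequence[str],
    path: str,
    feature_types: Optional[Sequence[str]] = None,
) -> None:
    """Write one 'index<TAB>name<TAB>type' line per feature, indices starting at 0."""
    if feature_types is None:
        # Assumption: treat every feature as quantitative unless told otherwise.
        feature_types = ["q"] * len(feature_names)
    if len(feature_types) != len(feature_names):
        raise ValueError("need exactly one type per feature name")
    with open(path, "w", encoding="utf-8") as fh:
        for index, (name, ftype) in enumerate(zip(feature_names, feature_types)):
            if ftype not in ALLOWED_TYPES:
                raise ValueError(f"unsupported feature type {ftype!r} for {name!r}")
            if not name or any(ch.isspace() for ch in name):
                # Precaution: empty names or names with spaces/tabs may not parse.
                raise ValueError(f"invalid feature name {name!r}")
            fh.write(f"{index}\t{name}\t{ftype}\n")


def read_feature_map(path: str) -> Dict[int, Tuple[str, str]]:
    """Parse a feature map into {index: (name, type)} and check indices run 0..n-1."""
    fmap: Dict[int, Tuple[str, str]] = {}
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            if not line.strip():
                continue  # ignore blank lines
            index_text, name, ftype = line.rstrip("\n").split("\t")
            index = int(index_text)
            if index in fmap:
                raise ValueError(f"duplicate feature index {index}")
            fmap[index] = (name, ftype)
    if sorted(fmap) != list(range(len(fmap))):
        raise ValueError("feature indices must run from 0 to n_features - 1")
    return fmap


def feature_label(node_feature: str, fmap: Dict[int, Tuple[str, str]]) -> str:
    """Translate an index label such as 'f28' into the feature name from the map."""
    if not node_feature.startswith("f"):
        raise ValueError(f"expected a label like 'f28', got {node_feature!r}")
    return fmap[int(node_feature[1:])][0]


# Example with three hypothetical feature names.
names = ["age", "is_member", "income"]
path = os.path.join(tempfile.mkdtemp(), "feature_map.txt")
write_feature_map(names, path, ["q", "i", "q"])
fmap = read_feature_map(path)
print(feature_label("f1", fmap))  # prints: is_member
```

With a pandas DataFrame X, list(X.columns) gives the names in training order; the resulting file is what you pass as fmap='feature_map.txt' in the plot_tree call.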