Tags: python, scikit-learn, xgboost, multiclass-classification

XGBoost: How to map prediction probabilities back to the original multiclass label names?


I am using the XGBoost multiclass classifier as shown in the example below. For each row in the X_test DataFrame, the model outputs one probability per category ('a', 'b', 'c' or 'd'), e.g. [0.44767836 0.2043365 0.15775423 0.19023092].

How can I tell which element of each row corresponds to which class/category (a, b, c or d)? My goal is to add 4 extra columns a, b, c and d to the DataFrame, each holding that row's probability for the matching class.

import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

# Create example data (np.random.seed also covers np.random.choice,
# so the labels are reproducible too; Python's `random` module is not seeded by it)
np.random.seed(312)
data = np.random.random((10000, 3))
y = np.random.choice(list('abcd'), size=data.shape[0])

features = ["x1", "x2", "x3"]
df = pd.DataFrame(data=data, columns=features)
df['y'] = y

# Encode the string target as integers 0..n_classes-1
labelencoder = preprocessing.LabelEncoder()
df['y_target'] = labelencoder.fit_transform(df['y'])

# Train/test split, stratified on the target
X_train, X_test, y_train, y_test = train_test_split(
    df[features], df['y_target'], test_size=0.2, random_state=42, stratify=df['y_target'])

# Train model
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

param = {
    'objective': 'multi:softprob',  # output one probability per class
    'num_class': 4,
    'random_state': 20,
    'tree_method': 'hist',  # originally 'gpu_hist', which requires a CUDA GPU
}

xgb_model = xgb.train(param, dtrain, 100)

# 2-D array of shape (n_test_rows, num_class)
predictions = xgb_model.predict(dtest)

print(predictions)
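The encoding step in the code above is what fixes the column order of the probabilities. LabelEncoder keeps classes_ in sorted order and encodes each label as its position in that array. A minimal sketch of the same mapping using only NumPy, where np.unique(..., return_inverse=True) stands in for LabelEncoder for illustration, and y is a small made-up label array rather than the question's data:

```python
import numpy as np

# Hypothetical stand-in for df['y']; any array of string labels works.
y = np.array(["c", "a", "d", "b", "a", "c"])

# np.unique returns the sorted distinct labels plus, for every element,
# its position in that sorted array. This matches the integer coding that
# LabelEncoder.fit_transform produces; `classes` plays the role of
# labelencoder.classes_.
classes, codes = np.unique(y, return_inverse=True)

print(classes)  # ['a' 'b' 'c' 'd']
print(codes)    # [2 0 3 1 0 2]

# Decoding is plain indexing: code i maps back to classes[i].
assert (classes[codes] == y).all()

# So column i of a (n_rows, num_class) probability array belongs to classes[i].
mapping = {i: str(c) for i, c in enumerate(classes)}
print(mapping)  # {0: 'a', 1: 'b', 2: 'c', 3: 'd'}
```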

Solution

  • The columns of predictions follow the encoded labels 0, 1, 2, 3, in that order. LabelEncoder stores the original names, sorted, in its classes_ attribute, so encoded label i is classes_[i]. Use classes_ as the column names:

    import pandas as pd

    # Column i of predictions is the probability of labelencoder.classes_[i]
    pd.DataFrame(predictions, columns=labelencoder.classes_)

    # Output:
               a           b           c           d
    0       0.133130    0.214460    0.569207    0.083203
    1       0.232991    0.275813    0.237639    0.253557
    2       0.163103    0.248531    0.114013    0.474352
    3       0.296990    0.202413    0.157542    0.343054
    4       0.199861    0.460732    0.228247    0.111159
    ... 
    1995    0.021859    0.460219    0.235214    0.282708
    1996    0.145394    0.182243    0.225992    0.446370
    1997    0.128586    0.318980    0.237229    0.315205
    1998    0.250899    0.257968    0.274477    0.216657
    1999    0.252377    0.236990    0.221835    0.288798
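To get the four extra columns onto the test rows (the goal stated in the question), it helps to build the probability frame with X_test's index and join it. A minimal sketch, assuming small made-up stand-ins for X_test, predictions and labelencoder.classes_ so it runs without training a model; it also adds a hypothetical "predicted" column holding the most likely class under its original name:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for the question's objects:
#   X_test      -> test features; after train_test_split its index is NOT 0..n-1
#   predictions -> (n_rows, num_class) array from xgb_model.predict(dtest)
#   classes     -> labelencoder.classes_
X_test = pd.DataFrame({"x1": [0.1, 0.5, 0.9], "x2": [0.3, 0.2, 0.7]},
                      index=[812, 17, 4051])
predictions = np.array([
    [0.10, 0.20, 0.60, 0.10],
    [0.40, 0.30, 0.20, 0.10],
    [0.25, 0.25, 0.10, 0.40],
])
classes = np.array(["a", "b", "c", "d"])

# Name the columns after the original classes and reuse X_test's index,
# so each probability row lines up with the feature row it was predicted for.
proba = pd.DataFrame(predictions, columns=classes, index=X_test.index)

# Join on the shared index, then use idxmax(axis=1), which returns the column
# label of the largest probability in each row, i.e. the predicted class name.
result = X_test.join(proba)
result["predicted"] = proba.idxmax(axis=1)
print(result)
```

With the question's objects this becomes pd.DataFrame(predictions, columns=labelencoder.classes_, index=X_test.index) followed by X_test.join(...). Passing the index matters: train_test_split shuffles the rows, so X_test keeps its original, non-consecutive index labels, while a frame built without index= gets 0..1999 and would not line up on a join.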