I am using the XGBoost multiclass classifier as outlined in the example below. For each row in the X_test dataframe the model outputs a list of probabilities, one per category 'a', 'b', 'c' or 'd', e.g. [0.44767836 0.2043365 0.15775423 0.19023092].

How can I tell which element in the list corresponds to which class/category (a, b, c or d)? My goal is to create 4 extra columns on the dataframe, a, b, c and d, with the matching probability as the row value in each column.
import numpy as np
import pandas as pd
import xgboost as xgb
import random
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
#Create Example Data
np.random.seed(312)
data = np.random.random((10000, 3))
y = [random.choice('abcd') for _ in range(data.shape[0])]
features = ["x1", "x2", "x3"]
df = pd.DataFrame(data=data, columns=features)
df['y']=y
#Encode target variable
labelencoder = preprocessing.LabelEncoder()
df['y_target'] = labelencoder.fit_transform(df['y'])
#Train Test Split
X_train, X_test, y_train, y_test = train_test_split(df[features], df['y_target'], test_size=0.2, random_state=42, stratify=y)
#Train Model
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)
param = {'objective': 'multi:softprob',
         'random_state': 20,
         'tree_method': 'gpu_hist',
         'num_class': 4}
xgb_model = xgb.train(param, dtrain, 100)
predictions=xgb_model.predict(dtest)
print(predictions)
The columns of the prediction array follow the same order as the encoded labels 0, 1, 2, 3. To recover the original target names, use the classes_ attribute of the LabelEncoder:
import pandas as pd
pd.DataFrame(predictions, columns=labelencoder.classes_)
>>>
a b c d
0 0.133130 0.214460 0.569207 0.083203
1 0.232991 0.275813 0.237639 0.253557
2 0.163103 0.248531 0.114013 0.474352
3 0.296990 0.202413 0.157542 0.343054
4 0.199861 0.460732 0.228247 0.111159
...
1995 0.021859 0.460219 0.235214 0.282708
1996 0.145394 0.182243 0.225992 0.446370
1997 0.128586 0.318980 0.237229 0.315205
1998 0.250899 0.257968 0.274477 0.216657
1999 0.252377 0.236990 0.221835 0.288798
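To add the four probability columns to X_test itself (the stated goal in the question), you can build the probability frame with X_test's own index and concatenate along axis=1, so rows stay aligned even after the shuffle from train_test_split. A minimal sketch with dummy data standing in for the model output (the real predictions array from xgb_model.predict(dtest) has the same (n_rows, 4) shape, and classes would be labelencoder.classes_):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)

# Stand-in for X_test: note the non-sequential index, as produced by train_test_split
X_test = pd.DataFrame(rng.random((5, 3)), columns=["x1", "x2", "x3"],
                      index=[7, 2, 9, 0, 4])

# Stand-in for the model output: each row normalised to sum to 1, like softprob
raw = rng.random((5, 4))
predictions = raw / raw.sum(axis=1, keepdims=True)

classes = ["a", "b", "c", "d"]  # labelencoder.classes_ in the original code

# Reuse X_test's index so concat aligns rows positionally with the predictions
proba_df = pd.DataFrame(predictions, columns=classes, index=X_test.index)
X_test = pd.concat([X_test, proba_df], axis=1)
print(X_test)
```

Passing index=X_test.index is the key step: without it, pd.concat would align on the default 0..n-1 index of the new frame and scatter NaNs across the shuffled rows.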