python machine-learning xgboost xgbclassifier

Interpreting leaf values of XGBoost trees for multiclass classification problem


I have been using the XGBoost Python library for my multiclass classification problem, with the multi:softmax objective. I am not sure how to interpret the leaf values of the decision trees that are shown when I use xgb.plot_tree(), or written out when I dump the model into a txt file with bst.dump_model().

My problem has 6 classes, labeled 0-5, and I have set my model to perform two boosting iterations (at least for now as I try to understand the workings of XGBoost a little more). From online searches (in particular https://github.com/dmlc/xgboost/issues/1746), I have noticed that the tree at booster[x] represents the tree in the int(x/(num_classes)) + 1 'th iteration of boosting, showing the decision tree for the x%(num_classes) class. For example, booster[7] in my txt file shows the decision tree during the 2nd iteration of boosting, and for class 1. Also, I have found that using the softmax function within each tree, the softmax values of all the leaf values add to 1.

Beyond this, I am generally quite confused about how the leaf values of all these trees amount to the decision of which class XGBoost chooses. My questions are

  1. How do the trees through the boosting iterations affect the output? For example, how do booster[0] and booster[6] (which represent the first and second boosting iterations for my class 0) influence the final output or the final probability for class 0?

  2. What is the math behind going from the leaf values of all the trees to the decision of which class XGBoost chooses?

If answering by demonstration helps, I have provided the dumped txt file below, along with a sample input and its output under both the multi:softprob and multi:softmax objectives.

dump.raw.txt:

booster[0]:
0:[f0<0.5] yes=1,no=2,missing=1
    1:[f8<19.5299988] yes=3,no=4,missing=3
        3:leaf=0.244897947
        4:leaf=-0.042857144
    2:leaf=-0.0595400333
booster[1]:
0:[f2<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0594852231
    2:[f8<0.389999986] yes=3,no=4,missing=3
        3:leaf=0.272727251
        4:[f9<0.607749999] yes=5,no=6,missing=5
            5:[f9<0.290250003] yes=7,no=8,missing=7
                7:[f8<6.75] yes=11,no=12,missing=11
                    11:leaf=0.0157894716
                    12:leaf=-0.0348837189
                8:leaf=0.11249999
            6:[f8<12.6100006] yes=9,no=10,missing=9
                9:leaf=-0.0483870953
                10:[f8<15.1700001] yes=13,no=14,missing=13
                    13:leaf=0.0157894716
                    14:leaf=-0.0348837189
booster[2]:
0:[f3<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0595029891
    2:[f8<0.439999998] yes=3,no=4,missing=3
        3:[f5<0.5] yes=5,no=6,missing=5
            5:leaf=-0.042857144
            6:leaf=0.226027399
        4:[f9<-0.606250048] yes=7,no=8,missing=7
            7:leaf=0.0157894716
            8:leaf=-0.0545454584
booster[3]:
0:[f3<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0595029891
    2:[f5<0.5] yes=3,no=4,missing=3
        3:[f8<19.6599998] yes=5,no=6,missing=5
            5:leaf=0.260869563
            6:leaf=-0.0452054814
        4:leaf=-0.0524475537
booster[4]:
0:[f9<-0.477999985] yes=1,no=2,missing=1
    1:[f9<-0.622750044] yes=3,no=4,missing=3
        3:leaf=-0.0557312258
        4:[f10<0] yes=7,no=8,missing=7
            7:[f5<0.5] yes=11,no=12,missing=11
                11:leaf=0.0069767423
                12:leaf=0.0631578937
            8:leaf=-0.0483870953
    2:[f8<0.400000006] yes=5,no=6,missing=5
        5:leaf=-0.0563139915
        6:[f10<0] yes=9,no=10,missing=9
            9:[f8<19.5200005] yes=13,no=14,missing=13
                13:[f2<0.5] yes=17,no=18,missing=17
                    17:[f9<1.14275002] yes=23,no=24,missing=23
                        23:[f8<15.2000008] yes=27,no=28,missing=27
                            27:leaf=-0.0483870953
                            28:leaf=0.0157894716
                        24:leaf=0.0631578937
                    18:leaf=0.226829246
                14:leaf=0.293398529
            10:[f9<0.492500007] yes=15,no=16,missing=15
                15:[f8<17.2700005] yes=19,no=20,missing=19
                    19:leaf=0.152054787
                    20:leaf=-0.0570247956
                16:[f8<13.4099998] yes=21,no=22,missing=21
                    21:[f2<0.5] yes=25,no=26,missing=25
                        25:leaf=-0.0348837189
                        26:leaf=0.132558137
                    22:leaf=0.275871307
booster[5]:
0:[f9<-0.181999996] yes=1,no=2,missing=1
    1:[f10<0] yes=3,no=4,missing=3
        3:[f9<-0.49150002] yes=7,no=8,missing=7
            7:[f4<0.5] yes=13,no=14,missing=13
                13:leaf=0.0157894716
                14:leaf=0.226829246
            8:leaf=-0.0529411733
        4:[f8<12.9099998] yes=9,no=10,missing=9
            9:leaf=-0.0396226421
            10:leaf=0.285522789
    2:[f9<0.490750015] yes=5,no=6,missing=5
        5:[f10<0] yes=11,no=12,missing=11
            11:leaf=-0.0577405877
            12:[f8<17.2800007] yes=15,no=16,missing=15
                15:leaf=-0.0521739125
                16:[f2<0.5] yes=17,no=18,missing=17
                    17:leaf=0.274038434
                    18:leaf=0.0631578937
        6:leaf=-0.0589545034
booster[6]:
0:[f0<0.5] yes=1,no=2,missing=1
    1:[f8<19.5299988] yes=3,no=4,missing=3
        3:leaf=0.200149015
        4:leaf=-0.0419149213
    2:leaf=-0.0587796457
booster[7]:
0:[f2<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0587093942
    2:[f8<0.389999986] yes=3,no=4,missing=3
        3:leaf=0.212223038
        4:[f9<0.607749999] yes=5,no=6,missing=5
            5:[f9<0.290250003] yes=7,no=8,missing=7
                7:[f8<6.75] yes=11,no=12,missing=11
                    11:leaf=0.0150387408
                    12:leaf=-0.0345491134
                8:leaf=0.102861121
            6:[f10<0] yes=9,no=10,missing=9
                9:leaf=-0.047783535
                10:[f9<0.93175] yes=13,no=14,missing=13
                    13:leaf=0.0160113405
                    14:leaf=-0.0342122875
booster[8]:
0:[f3<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0587323084
    2:[f8<0.439999998] yes=3,no=4,missing=3
        3:[f5<0.5] yes=5,no=6,missing=5
            5:leaf=-0.0419248194
            6:leaf=0.187167063
        4:[f9<-0.606250048] yes=7,no=8,missing=7
            7:leaf=0.0154749081
            8:leaf=-0.0537380874
booster[9]:
0:[f3<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0587323084
    2:[f5<0.5] yes=3,no=4,missing=3
        3:[f8<19.6599998] yes=5,no=6,missing=5
            5:leaf=0.207475975
            6:leaf=-0.0443004556
        4:leaf=-0.0517353415
booster[10]:
0:[f9<-0.477999985] yes=1,no=2,missing=1
    1:[f9<-0.622750044] yes=3,no=4,missing=3
        3:leaf=-0.0549092069
        4:[f10<0] yes=7,no=8,missing=7
            7:[f8<19.9899998] yes=11,no=12,missing=11
                11:leaf=0.0621421933
                12:leaf=0.00554796588
            8:leaf=-0.0474151336
    2:[f8<0.400000006] yes=5,no=6,missing=5
        5:leaf=-0.0555005781
        6:[f0<0.5] yes=9,no=10,missing=9
            9:leaf=-0.0508832447
            10:[f10<0] yes=13,no=14,missing=13
                13:[f3<0.5] yes=15,no=16,missing=15
                    15:leaf=0.220791802
                    16:[f9<0.988499999] yes=19,no=20,missing=19
                        19:leaf=-0.0421211571
                        20:leaf=0.059088923
                14:[f9<0.492500007] yes=17,no=18,missing=17
                    17:[f8<17.2700005] yes=21,no=22,missing=21
                        21:leaf=0.162014976
                        22:leaf=-0.0559271388
                    18:[f3<0.5] yes=23,no=24,missing=23
                        23:leaf=0.217694834
                        24:leaf=0.0335121229
booster[11]:
0:[f9<-0.181999996] yes=1,no=2,missing=1
    1:[f8<19.3400002] yes=3,no=4,missing=3
        3:leaf=-0.0464246981
        4:[f10<0] yes=7,no=8,missing=7
            7:[f9<-0.49150002] yes=11,no=12,missing=11
                11:leaf=0.178972095
                12:leaf=-0.0509003103
            8:leaf=0.218449697
    2:[f9<0.490750015] yes=5,no=6,missing=5
        5:[f10<0] yes=9,no=10,missing=9
            9:leaf=-0.0568957441
            10:[f8<17.2800007] yes=13,no=14,missing=13
                13:leaf=-0.0513576232
                14:[f2<0.5] yes=15,no=16,missing=15
                    15:leaf=0.212948546
                    16:leaf=0.0586818419
        6:leaf=-0.0581783429

Sample input, with the expected label: [0, 1, 0, 0, 1, 0, 1, 20, 16.8799, 0.587, 0.5], label: 0
multi:softmax output: [0]
multi:softprob output (if it helps): [[0.24506968 0.13953298 0.13952732 0.13952732 0.19666144 0.13968122]]

I know that this is a loaded question, and I hope I explained it clearly. Any help would be much appreciated. Thanks in advance!


Solution

    1. The trees build on their previous iterations for each class (hence boosting!). In your example, booster[0] and booster[6] both contribute to providing the numerator of the softmax probability for class 0.

    More generally, booster[i] and booster[i+6] contribute to the numerator of the softmax probability for class i. With more than two iterations, booster[i], booster[i+6], ..., booster[i+6n] all contribute to class i, one tree per iteration over n+1 iterations.
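
The index arithmetic can be sketched in a few lines of Python. This is only an illustration of the layout described in the question (trees stored round by round, one tree per class in each round); the helper names `tree_position` and `trees_for_class` are made up for this example and are not part of XGBoost.

```python
def tree_position(booster_index, num_classes):
    """Return (boosting_round, class_id) for a tree index in the dump.

    Rounds are 1-based to match the wording in the question;
    class ids are 0-based like the labels.
    """
    zero_based_round, class_id = divmod(booster_index, num_classes)
    return zero_based_round + 1, class_id


def trees_for_class(class_id, num_classes, num_rounds):
    """List the dump indices of every tree that feeds the score of class_id."""
    return [class_id + num_classes * r for r in range(num_rounds)]


# booster[7] with 6 classes: second round, class 1
print(tree_position(7, 6))
# trees contributing to class 0 after two rounds
print(trees_for_class(0, 6, 2))
```

With 6 classes and 2 rounds, `trees_for_class(0, 6, 2)` gives the indices 0 and 6, which are the two trees discussed above.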

    2. We can demonstrate the math using your example:

    Given your input and your dumped txt file, we can find the leaf values for each booster:

    Booster 0: 0.24489
    Booster 1: -0.0594
    Booster 2: -0.0595
    Booster 3: -0.0595
    Booster 4: 0.27587
    Booster 5: -0.0589
    Booster 6: 0.2
    Booster 7: -0.0587
    Booster 8: -0.0587
    Booster 9: -0.0587
    Booster 10: -0.0508
    Booster 11: -0.0582
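
The leaf lookup can also be done programmatically. The sketch below parses the text format shown in dump.raw.txt and walks one tree for the sample input. It assumes the dump lines look exactly like the ones in the question, and it uses `None` to stand for a missing feature value, which is a convention of this example only. The functions `parse_tree` and `leaf_value` are hypothetical helpers, not XGBoost APIs. Only booster[0] and booster[6] are copied in to keep it short.

```python
import re

# The two class-0 trees, copied from the dump in the question.
BOOSTER_0 = """
0:[f0<0.5] yes=1,no=2,missing=1
    1:[f8<19.5299988] yes=3,no=4,missing=3
        3:leaf=0.244897947
        4:leaf=-0.042857144
    2:leaf=-0.0595400333
"""
BOOSTER_6 = """
0:[f0<0.5] yes=1,no=2,missing=1
    1:[f8<19.5299988] yes=3,no=4,missing=3
        3:leaf=0.200149015
        4:leaf=-0.0419149213
    2:leaf=-0.0587796457
"""

SPLIT_RE = re.compile(r"(\d+):\[f(\d+)<([^\]]+)\] yes=(\d+),no=(\d+),missing=(\d+)")
LEAF_RE = re.compile(r"(\d+):leaf=(\S+)")


def parse_tree(dump_text):
    """Turn one dumped tree into a dict: node id -> split or leaf tuple."""
    nodes = {}
    for raw_line in dump_text.strip().splitlines():
        line = raw_line.strip()
        split = SPLIT_RE.match(line)
        if split:
            node_id, feat, thr, yes, no, missing = split.groups()
            nodes[int(node_id)] = (
                "split", int(feat), float(thr), int(yes), int(no), int(missing)
            )
            continue
        leaf = LEAF_RE.match(line)
        if leaf:
            nodes[int(leaf.group(1))] = ("leaf", float(leaf.group(2)))
    return nodes


def leaf_value(nodes, x):
    """Follow the splits from the root (node 0) until a leaf is reached."""
    node = nodes[0]
    while node[0] == "split":
        _, feat, thr, yes, no, missing = node
        value = x[feat]
        if value is None:      # missing value: take the 'missing' branch
            next_id = missing
        elif value < thr:      # condition [f<thr] holds: take the 'yes' branch
            next_id = yes
        else:
            next_id = no
        node = nodes[next_id]
    return node[1]


sample = [0, 1, 0, 0, 1, 0, 1, 20, 16.8799, 0.587, 0.5]
for name, text in [("booster[0]", BOOSTER_0), ("booster[6]", BOOSTER_6)]:
    print(name, leaf_value(parse_tree(text), sample))
```

For the sample input both trees take the path f0 < 0.5, then f8 < 19.53, and end at node 3, which is where the two class-0 values in the list above come from.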
    

    Now we just need to plug these into the softmax formula to arrive at the probabilities for each of the six classes under softprob.

    Z_0 = e^{0.24489+0.2} = 1.5603
    Z_1 = e^{-0.0594-0.0587} = 0.8886
    Z_2 = e^{-0.0595-0.0587} = 0.8885
    Z_3 = e^{-0.0595-0.0587} = 0.8885
    Z_4 = e^{0.2758-0.0508} = 1.2523
    Z_5 = e^{-0.0589-0.0582} = 0.8895
    

    Summing these gives us the denominator of the softmax probability: 6.3677

    As such, we can compute softprob for each class,

    P(output=0) = 1.5603/6.3677 = 0.2450
    P(output=1) = 0.8886/6.3677 = 0.1395
    P(output=2) = 0.8885/6.3677 = 0.1395
    P(output=3) = 0.8885/6.3677 = 0.1395
    P(output=4) = 1.2523/6.3677 = 0.1967
    P(output=5) = 0.8895/6.3677 = 0.1397
    

    Picking the class with the highest probability (class 0) yields the prediction you get with multi:softmax.
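
The whole calculation can be reproduced in plain Python as a check. The sketch below takes the twelve leaf values reached by the sample input (at the full precision printed in the dump), sums them per class using index modulo 6, and applies softmax. It leaves out any global bias term: a constant added to every class score cancels in the softmax, so it would not change the probabilities.

```python
import math

NUM_CLASSES = 6

# Leaf reached by the sample input in booster[0] .. booster[11].
leaf_values = [
    0.244897947, -0.0594852231, -0.0595029891,
    -0.0595029891, 0.275871307, -0.0589545034,   # first boosting round
    0.200149015, -0.0587093942, -0.0587323084,
    -0.0587323084, -0.0508832447, -0.0581783429,  # second boosting round
]

# Sum the leaves per class: tree i contributes to class i % NUM_CLASSES.
margins = [0.0] * NUM_CLASSES
for tree_index, value in enumerate(leaf_values):
    margins[tree_index % NUM_CLASSES] += value


def softmax(scores):
    """Numerically stable softmax over a list of raw scores."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(margins)
predicted_class = max(range(NUM_CLASSES), key=lambda c: probs[c])

print([round(p, 4) for p in probs])  # compare with the multi:softprob output
print(predicted_class)               # compare with the multi:softmax output
```

The printed probabilities should agree with the multi:softprob output in the question to about four decimal places, and the argmax is class 0.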