
How to interpret predictions from LightGBM

I am trying to obtain predictions from my LightGBM model; a simple minimal example is provided in the first answer here. When I run the provided code (copied below) and call model.predict, I expect predictions for the binary target, 0 or 1, but I get a continuous variable instead:

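import numpy as np
import pandas as pd
import lightgbm

df = pd.DataFrame({
    "query_id":[i for i in range(100) for j in range(10)],
    "var1":np.random.random(size=(1000,)),
    "var2":np.random.random(size=(1000,)),
    "var3":np.random.random(size=(1000,)),
    "relevance":list(np.random.permutation([0,0,0,0,0, 0,0,0,1,1]))*100
})

train_df = df[:800]  # first 80%
validation_df = df[800:]  # remaining 20%

# Group sizes: number of rows belonging to each query_id
qids_train = train_df.groupby("query_id")["query_id"].count().to_numpy()
X_train = train_df.drop(["query_id", "relevance"], axis=1)
y_train = train_df["relevance"]

qids_validation = validation_df.groupby("query_id")["query_id"].count().to_numpy()
X_validation = validation_df.drop(["query_id", "relevance"], axis=1)
y_validation = validation_df["relevance"]

model = lightgbm.LGBMRanker(objective="lambdarank", metric="ndcg")

# eval_set and eval_group are arguments of fit(), not of the constructor
model.fit(
    X_train,
    y_train,
    group=qids_train,
    eval_set=[(X_validation, y_validation)],
    eval_group=[qids_validation],
    eval_at=[10],
)

y_pred = model.predict(X_validation)  # continuous values, not 0/1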


  • I would expect to get the predictions for the binary target, 0 or 1 but I get a continuous variable instead

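    You are using LGBMRanker, so the model learns to rank the rows within each query group; it does not classify whether a row is relevant (1) or not (0). model.predict therefore returns a continuous score for each row, and only the relative order of the scores within the same query_id is meaningful.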
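If you genuinely need a 0/1 label per row, a classifier such as lightgbm.LGBMClassifier is the more natural tool. If you want to keep the ranker, one option (an assumption on my part, not something LGBMRanker does for you) is to mark the top-k scored rows of each query as relevant. The sketch below uses a hypothetical helper `top_k_labels` and made-up scores; k=2 matches the toy data, where every query has exactly two relevant rows.

```python
import numpy as np
import pandas as pd


def top_k_labels(query_ids, scores, k=2):
    """Hypothetical helper: label the k highest-scoring rows of each query 1, the rest 0."""
    frame = pd.DataFrame({"query_id": query_ids, "score": scores})
    # method="first" breaks ties by order of appearance, so at most k rows per query get rank <= k
    ranks = frame.groupby("query_id")["score"].rank(method="first", ascending=False)
    return (ranks <= k).astype(int).to_numpy()


# Made-up ranker scores for two queries with four rows each
query_ids = [0, 0, 0, 0, 1, 1, 1, 1]
scores = np.array([0.3, -1.2, 2.5, 0.9, -0.4, 1.1, 0.2, -2.0])
labels = top_k_labels(query_ids, scores, k=2)
print(labels)
```

Choosing k per query this way only makes sense when you know roughly how many relevant rows each query has; otherwise a fixed cut-off on the raw scores is not meaningful, because ranker scores are not calibrated probabilities.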

    To turn those scores into a rank for each row within its query_id, you can use:

    # Copy the slices to avoid SettingWithCopyWarning when adding a column
    train_df = df[:800].copy()  # first 80%
    validation_df = df[800:].copy()  # remaining 20%
    # Predict: one continuous relevance score per row
    y_pred = model.predict(X_train)
    # Rank within each query_id: 1 = highest score, tied scores share a rank
    train_df['rank'] = (train_df.assign(pred=y_pred).groupby('query_id')['pred']
                                .rank(method='dense', ascending=False).astype(int))
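To see what `method='dense'` does, here is a small self-contained sketch on a toy frame with made-up scores (the column names `dense` and `min` are only for illustration). With `dense`, tied scores share a rank and the next rank follows without a gap; with `min`, ties also share a rank but the following rank skips ahead.

```python
import pandas as pd

# Toy predictions; query 0 contains a tie (0.8, 0.8)
toy = pd.DataFrame({
    "query_id": [0, 0, 0, 1, 1, 1],
    "pred": [0.8, 0.8, -0.3, 1.5, 0.1, 0.7],
})
grouped = toy.groupby("query_id")["pred"]
# Highest score gets rank 1 within each query
toy["dense"] = grouped.rank(method="dense", ascending=False).astype(int)
toy["min"] = grouped.rank(method="min", ascending=False).astype(int)
print(toy)
```

Since real model scores are floats, exact ties are rare in practice, so the choice of `method` usually changes little.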


    >>> train_df.head(20)
        query_id      var1      var2      var3  relevance  rank
    0          0  0.310088  0.855698  0.812061          0     4
    1          0  0.450598  0.178649  0.734282          1     1
    2          0  0.767785  0.948723  0.554276          0     7
    3          0  0.832502  0.768063  0.659054          0    10
    4          0  0.184687  0.708391  0.205448          0     6
    5          0  0.633676  0.886995  0.397744          0     9
    6          0  0.323786  0.921241  0.228752          0     3
    7          0  0.381342  0.712662  0.142029          0     8
    8          0  0.993355  0.182123  0.104416          1     2
    9          0  0.118358  0.032669  0.298469          0     5
    10         1  0.508594  0.362492  0.600220          0     7
    11         1  0.156516  0.624849  0.737191          1     1
    12         1  0.420040  0.010797  0.522225          0     4
    13         1  0.508450  0.691576  0.395909          0    10
    14         1  0.840053  0.364863  0.439742          0     9
    15         1  0.941446  0.640084  0.614170          0     8
    16         1  0.804287  0.577919  0.590988          0     5
    17         1  0.090188  0.166563  0.364991          0     3
    18         1  0.124671  0.326929  0.577342          1     2
    19         1  0.037713  0.383510  0.500487          0     6
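A quick way to read such a table is to ask how many of the truly relevant rows land at the top of each query. The sketch below uses a hypothetical helper `precision_at_k` and the `relevance` and `rank` values copied from the 20 rows shown above. Note that these are training rows (the ranks come from `model.predict(X_train)`), so the same check on `validation_df` would say more about how well the model generalizes.

```python
import pandas as pd


def precision_at_k(frame, k=2):
    """Hypothetical helper: mean share of relevant rows among each query's k best-ranked rows."""
    top_k = frame.sort_values(["query_id", "rank"]).groupby("query_id").head(k)
    return top_k.groupby("query_id")["relevance"].mean().mean()


# relevance and rank columns from the train_df.head(20) output above
table = pd.DataFrame({
    "query_id": [0] * 10 + [1] * 10,
    "relevance": [0, 1, 0, 0, 0, 0, 0, 0, 1, 0,
                  0, 1, 0, 0, 0, 0, 0, 0, 1, 0],
    "rank": [4, 1, 7, 10, 6, 9, 3, 8, 2, 5,
             7, 1, 4, 10, 9, 8, 5, 3, 2, 6],
})
print(precision_at_k(table, k=2))
```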