I am trying to obtain predictions from my LightGBM model; a simple minimal example is provided in the first answer here. When I run the code from that answer (copied below) and call model.predict, I expect to get predictions for the binary target, 0 or 1, but I get a continuous variable instead:
import numpy as np
import pandas as pd
import lightgbm

# Synthetic data: 100 queries with 10 rows each, 2 relevant rows per query
df = pd.DataFrame({
    "query_id": [i for i in range(100) for j in range(10)],
    "var1": np.random.random(size=(1000,)),
    "var2": np.random.random(size=(1000,)),
    "var3": np.random.random(size=(1000,)),
    "relevance": list(np.random.permutation([0,0,0,0,0, 0,0,0,1,1]))*100
})

train_df = df[:800]  # first 80%
validation_df = df[800:]  # remaining 20%

# Group sizes (rows per query), in the order the queries appear
qids_train = train_df.groupby("query_id")["query_id"].count().to_numpy()
X_train = train_df.drop(["query_id", "relevance"], axis=1)
y_train = train_df["relevance"]

qids_validation = validation_df.groupby("query_id")["query_id"].count().to_numpy()
X_validation = validation_df.drop(["query_id", "relevance"], axis=1)
y_validation = validation_df["relevance"]

model = lightgbm.LGBMRanker(
    objective="lambdarank",
    metric="ndcg",
)

model.fit(
    X=X_train,
    y=y_train,
    group=qids_train,
    eval_set=[(X_validation, y_validation)],
    eval_group=[qids_validation],
    eval_at=10,
    verbose=10,  # older LightGBM releases; newer ones use callbacks=[lightgbm.log_evaluation(10)]
)

model.predict(X_train)
I would expect to get the predictions for the binary target, 0 or 1 but I get a continuous variable instead
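You are using LGBMRanker, so you are ranking your data within each query group, not classifying whether each row is relevant (1) or not (0). The values returned by model.predict are relevance scores: a higher score means a higher position within the same query group, and the scores are neither class labels nor probabilities.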
To get the rank of each row within its query group, you can use:
# Avoid SettingWithCopyWarning
train_df = df[:800].copy()  # first 80%
validation_df = df[800:].copy()  # remaining 20%

# Predict
y_pred = model.predict(X_train)

# Rank
train_df['rank'] = (
    train_df.assign(pred=y_pred)
    .groupby('query_id')['pred']
    .rank(method='dense', ascending=False)
    .astype(int)
)
Output:
>>> train_df.head(20)
    query_id      var1      var2      var3  relevance  rank
0          0  0.310088  0.855698  0.812061          0     4
1          0  0.450598  0.178649  0.734282          1     1
2          0  0.767785  0.948723  0.554276          0     7
3          0  0.832502  0.768063  0.659054          0    10
4          0  0.184687  0.708391  0.205448          0     6
5          0  0.633676  0.886995  0.397744          0     9
6          0  0.323786  0.921241  0.228752          0     3
7          0  0.381342  0.712662  0.142029          0     8
8          0  0.993355  0.182123  0.104416          1     2
9          0  0.118358  0.032669  0.298469          0     5
10         1  0.508594  0.362492  0.600220          0     7
11         1  0.156516  0.624849  0.737191          1     1
12         1  0.420040  0.010797  0.522225          0     4
13         1  0.508450  0.691576  0.395909          0    10
14         1  0.840053  0.364863  0.439742          0     9
15         1  0.941446  0.640084  0.614170          0     8
16         1  0.804287  0.577919  0.590988          0     5
17         1  0.090188  0.166563  0.364991          0     3
18         1  0.124671  0.326929  0.577342          1     2
19         1  0.037713  0.383510  0.500487          0     6
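If you still need a 0/1 label per row, one option is to threshold the within-query rank. This is a minimal sketch under assumptions: the scores below are made up (not output from the model above), and the cutoff k is a hypothetical choice you would have to pick yourself, since the ranker does not provide one.

```python
import pandas as pd

# Made-up ranker scores for two query groups (illustrative values only)
scores = pd.DataFrame({
    "query_id": [0, 0, 0, 1, 1, 1],
    "pred": [-1.2, 0.4, -0.3, 2.5, 3.1, 1.9],
})

# Rank within each query: 1 = highest score in that query.
# method="first" breaks ties by row order so exactly k rows pass the cutoff.
scores["rank"] = (
    scores.groupby("query_id")["pred"]
    .rank(method="first", ascending=False)
    .astype(int)
)

# Assumption: label the top k rows of each query as relevant
k = 1
scores["label"] = (scores["rank"] <= k).astype(int)
print(scores)
```

Note that the score 1.9 in query 1 is larger than the score 0.4 in query 0, yet it ranks last in its own group while 0.4 ranks first. Scores are only comparable within a query, which is why a single global threshold on the raw predictions is usually not meaningful.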