This is probably a really stupid question, but why do the following give different results?
X == array([ 7.84682988e-01, 3.80109225e-17, 8.06386582e-01,
1.00000000e+00, 5.71428571e-01, 4.44189342e+00])
model.predict_proba(X)[1] # gives array([ 0.35483244, 0.64516756])
model.predict_proba(X[1]) # gives an error
model.predict_proba(list(X[1])) # gives array([[ 0.65059327, 0.34940673]])
model is an LGBMClassifier from the lightgbm library.
Let's break it into simple steps to analyse:
1) model.predict_proba(X)[1]
This is equivalent to
probas = model.predict_proba(X)
probas[1]
This first computes the probabilities of all classes for all samples. Let's say your X contains 5 rows and 4 features, with two different classes.
Then probas will be something like this:
Prob of class 0, prob of class 1
For sample1 [[0.1, 0.9],
For sample2 [0.8, 0.2],
For sample3 [0.85, 0.15],
For sample4 [0.4, 0.6],
For sample5 [0.01, 0.99]]
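probas[1]
selects the second row of probas (indexing starts at 0), i.e. the probabilities of both classes for sample 2. That matches the two-element array shown in the question.
Output [0.8, 0.2]
To get the second column instead (the probability of class 1 for every sample), use probas[:, 1], which here gives [0.9, 0.2, 0.15, 0.6, 0.99].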
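To check what indexing does to a 2-d array like probas, here is a minimal NumPy sketch. It uses the illustrative numbers from the example above rather than real model output, and the names second_row and second_column are only for illustration.

```python
import numpy as np

# Illustrative probabilities from the example above (5 samples, 2 classes).
# These are made-up numbers, not real output from a model.
probas = np.array([
    [0.1, 0.9],
    [0.8, 0.2],
    [0.85, 0.15],
    [0.4, 0.6],
    [0.01, 0.99],
])

second_row = probas[1]        # both class probabilities for sample 2
second_column = probas[:, 1]  # probability of class 1 for every sample

print(second_row.shape)     # (2,)
print(second_column.shape)  # (5,)
```

So indexing with a single integer picks a sample, while the [:, 1] form picks a class.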
The other two lines depend on how the implementation and its version handle a one-dimensional array. For example, scikit-learn 0.18 only shows a warning and treats it as a single row, but 0.19 (master branch) throws an error.
EDIT: Updated for LGBMClassifier
2) model.predict_proba(X[1])
This is equivalent to:
X_new = X[1]
model.predict_proba(X_new)
Here you are selecting only the second row, which results in a 1-d array of shape [n_features, ]. But LGBMClassifier requires 2-d data of shape [n_samples, n_features]. This is the likely source of the error, as mentioned above. You can reshape the array so that it has 1 in place of n_samples:
model.predict_proba(X[1].reshape(1, -1))
# Will work correctly
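The shape change can be checked without a model. A small sketch, assuming a hypothetical X with 5 samples and 4 features (the values are arbitrary); it also shows two other ways to keep the row 2-d:

```python
import numpy as np

# Hypothetical feature matrix: 5 samples, 4 features (arbitrary values).
X = np.arange(20, dtype=float).reshape(5, 4)

row_1d = X[1]                        # integer index drops a dimension
row_reshaped = X[1].reshape(1, -1)   # back to one row, n_features columns
row_sliced = X[1:2]                  # slicing keeps the array 2-d
row_fancy = X[[1]]                   # indexing with a list also keeps it 2-d

print(row_1d.shape)        # (4,)
print(row_reshaped.shape)  # (1, 4)
print(row_sliced.shape)    # (1, 4)
print(row_fancy.shape)     # (1, 4)
```

Any of the 2-d forms has the [n_samples, n_features] layout that predict_proba expects.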
3) model.predict_proba(list(X[1]))
This can be broken into:
X_new = list(X[1])
model.predict_proba(X_new)
This is mostly the same as the 2nd case, except that X_new is now a list instead of a numpy array, and it is automatically handled as a single row (same as X[1].reshape(1, -1) in the 2nd case) instead of throwing an error.
So, considering the example above, the output will be only the row for sample 2 (still as a 2-d array with one row):
For sample2 [[0.8, 0.2]]
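Since how a flat list is interpreted can differ between libraries and versions, one way to avoid relying on that behaviour is to pass a nested list holding a single row. A minimal sketch with a hypothetical X (5 samples, 4 features), showing only the shapes NumPy infers, not what any particular model does with them:

```python
import numpy as np

# Hypothetical feature matrix: 5 samples, 4 features (arbitrary values).
X = np.arange(20, dtype=float).reshape(5, 4)

as_flat_list = list(X[1])      # flat list of 4 numbers: ambiguous as input
as_nested_list = [list(X[1])]  # list holding one row: unambiguous

print(np.asarray(as_flat_list).shape)    # (4,)
print(np.asarray(as_nested_list).shape)  # (1, 4)
```

With the nested form, model.predict_proba(as_nested_list) receives one sample with n_features values, the same layout as X[1].reshape(1, -1).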