Tags: python, machine-learning, scikit-learn, multiclass-classification

How to get top n prediction labels from classifier.predict_proba() output?


I am trying to get the top n predicted labels for a particular record, as a list, in a text-based multi-label classification problem.

I have tried the following...

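import numpy as np

y_pred_proba = classifier.predict_proba(X_test)  # shape (n_samples, n_classes)
n = 5
# argsort sorts ascending, so the last n columns hold the indices of the n largest probabilities
top_n_pred = np.argsort(y_pred_proba, axis=1)[:, -n:]
class_labels = classifier.classes_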
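np.argsort sorts in ascending order, so the columns of top_n_pred run from the n-th most probable class up to the most probable one. A minimal sketch with a hand-made probability matrix (the numbers are made up and only stand in for predict_proba output) shows that ordering and how to flip it so the most probable class comes first:

```python
import numpy as np

# Hypothetical predict_proba output: 2 samples, 4 classes (each row sums to 1)
y_pred_proba = np.array([
    [0.10, 0.50, 0.15, 0.25],
    [0.40, 0.05, 0.35, 0.20],
])
n = 2

# Ascending argsort: the last n columns are the n largest, least likely of them first
top_n_asc = np.argsort(y_pred_proba, axis=1)[:, -n:]
print(top_n_asc)   # [[3 1]
                   #  [2 0]]

# Reverse the column order so column 0 is the most probable class
top_n_desc = top_n_asc[:, ::-1]
print(top_n_desc)  # [[1 3]
                   #  [0 2]]
```

If the order within the top n does not matter, the reversal step can be dropped.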

How can I combine top_n_pred and class_labels to get the top n labels as a list for each row of X_test?

If there is a shortcut that achieves the same result, that is also welcome.


Solution

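  • I would first try class_labels[top_n_pred]. Because classifier.classes_ is a NumPy array, fancy indexing with the 2-D index array should return an array of shape (n_samples, n) holding the labels, and .tolist() turns that into one list per row of X_test. If that fails (for example, if class_labels is a plain Python list), fall back to iterating over the rows of top_n_pred and looking up each index.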
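A minimal sketch of both routes, assuming class_labels is a NumPy array as scikit-learn exposes it in classifier.classes_. The labels and probabilities below are hypothetical stand-ins, not real model output. Fancy indexing maps every index to its label in one step; the list comprehension is the iterator fallback for labels held in a plain Python list.

```python
import numpy as np

# Hypothetical stand-ins for classifier.classes_ and classifier.predict_proba(X_test)
class_labels = np.array(["sports", "politics", "tech", "health"])
y_pred_proba = np.array([
    [0.10, 0.50, 0.15, 0.25],
    [0.40, 0.05, 0.35, 0.20],
])
n = 2

# Indices of the n most probable classes per row, most probable first
top_n_pred = np.argsort(y_pred_proba, axis=1)[:, -n:][:, ::-1]

# Route 1: fancy indexing returns an (n_samples, n) array of labels
top_n_labels = class_labels[top_n_pred].tolist()
print(top_n_labels)       # [['politics', 'health'], ['sports', 'tech']]

# Route 2: iterator fallback, works when the labels are a plain Python list
labels_list = class_labels.tolist()
top_n_labels_iter = [[labels_list[i] for i in row] for row in top_n_pred]
print(top_n_labels_iter)  # [['politics', 'health'], ['sports', 'tech']]
```

Both routes give the same nested list. The fancy-indexing version avoids a Python-level loop, which matters mainly when X_test has many rows.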