Tags: python, classification, cross-validation, training-data, indices

How to get indices of instances during cross-validation


I am doing binary classification. How can I extract the original (real) indices of the misclassified or correctly classified instances of the training data frame while doing k-fold cross-validation? I couldn't find an answer to this question here.

I got the predictions for each fold as described here:

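    from sklearn.model_selection import StratifiedKFold, cross_val_score, cross_val_predict

    # random_state only has an effect (and recent scikit-learn only accepts it) when shuffle=True
    skf = StratifiedKFold(n_splits=10, random_state=111, shuffle=True)
    cv_results = cross_val_score(model, X_train, y_train, cv=skf, scoring='roc_auc')
    pred = cross_val_predict(model, X_train, y_train, cv=skf)  # one out-of-fold prediction per row
    fold_pred = [pred[j] for i, j in skf.split(X_train, y_train)]  # predictions grouped by test fold
    fold_pred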
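One subtlety with fold-based code like this: `skf.split` yields *positional* indices (0 to n-1), not the data frame's index labels. If `X_train` came out of something like `train_test_split`, its index is shuffled, so positions and labels differ. A minimal sketch, assuming a synthetic dataset from `make_classification` as a stand-in for the real data, maps each fold's test positions back to the real labels with `X_train.index[test_idx]`:

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, train_test_split

# Synthetic stand-in for the real data (assumption: 100 rows, 4 features, binary target)
X, y = make_classification(n_samples=100, n_features=4, random_state=0)
X = pd.DataFrame(X, columns=['col1', 'col2', 'col3', 'col4'])
y = pd.Series(y, name='target')

# train_test_split shuffles rows, so X_train's index labels are no longer 0..n-1
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=111)
fold_labels = {}
for fold, (train_idx, test_idx) in enumerate(skf.split(X_train, y_train)):
    # test_idx holds positions into X_train; translate them to the original index labels
    fold_labels[fold] = X_train.index[test_idx]
    print(fold, list(test_idx[:3]), list(fold_labels[fold][:3]))
```

`X_train.iloc[test_idx]` and `X_train.loc[X_train.index[test_idx]]` select the same rows; the difference is only whether you carry positions or the original labels around.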

Is there a way to get the indices of the misclassified (or correctly classified) instances, so that the output is a data frame containing only the misclassified (or correctly classified) rows from cross-validation?

Desired output: the misclassified instances of the data frame, with their real indices.

     col1 col2 col3 col4  target
13    0    1    0    0    0
14    0    1    0    0    0
18    0    1    0    0    1
22    0    1    0    0    0

Here the input has 100 instances, of which 4 (indices 13, 14, 18 and 22) are misclassified during CV.


Solution

  • From cross_val_predict you already have the out-of-fold predictions. It's then just a matter of selecting the rows of your data frame where the prediction differs from the true label, for example:

    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import cross_val_predict, StratifiedKFold 
    from sklearn.datasets import load_breast_cancer
    import pandas as pd
    
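    data = load_breast_cancer()
    df = pd.DataFrame(data.data[:, :5], columns=data.feature_names[:5])
    df['label'] = data.target
    
    rfc = RandomForestClassifier()
    skf = StratifiedKFold(n_splits=10, random_state=111, shuffle=True)
    
    # Out-of-fold predictions: each row is predicted by a model that did not train on it
    pred = cross_val_predict(rfc, df.iloc[:, :5], df['label'], cv=skf)
    
    # The boolean mask keeps the original index labels of the misclassified rows
    df[df['label'] != pred]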
     
         mean radius  mean texture  ...  mean smoothness  label
    3          11.42         20.38  ...          0.14250      0
    5          12.45         15.70  ...          0.12780      0
    9          12.46         24.04  ...          0.11860      0
    22         15.34         14.26  ...          0.10730      0
    31         11.84         18.70  ...          0.11090      0
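If you also want to know which fold each instance was held out in, or want the correctly classified rows as well, one option is to skip `cross_val_predict` and loop over `skf.split` yourself, which is roughly what `cross_val_predict` does for you. The sketch below reuses the breast-cancer setup from the answer; the `fold` and `correct` column names are hypothetical, and `random_state=0` on the forest is added only to make runs repeatable.

```python
import numpy as np
import pandas as pd
from sklearn.base import clone
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold

data = load_breast_cancer()
df = pd.DataFrame(data.data[:, :5], columns=data.feature_names[:5])
df['label'] = data.target
X, y = df.iloc[:, :5], df['label']

rfc = RandomForestClassifier(random_state=0)
skf = StratifiedKFold(n_splits=10, random_state=111, shuffle=True)

pred = np.empty(len(df), dtype=y.dtype)
fold = np.empty(len(df), dtype=int)
for k, (train_idx, test_idx) in enumerate(skf.split(X, y)):
    # Fit a fresh copy of the estimator on this fold's training rows only
    model = clone(rfc).fit(X.iloc[train_idx], y.iloc[train_idx])
    pred[test_idx] = model.predict(X.iloc[test_idx])  # out-of-fold predictions
    fold[test_idx] = k                                 # which fold held this row out

# Hypothetical helper columns: held-out fold and whether the prediction was right
df['fold'] = fold
df['correct'] = pred == y.to_numpy()

misclassified = df[~df['correct']]  # keeps the original index labels
correct = df[df['correct']]
print(misclassified.head())
```

Because the boolean masks are applied to `df` itself, both `misclassified` and `correct` keep the original row labels, and the `fold` column tells you in which split each row was evaluated.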