
scikit-learn cross_validate: reveal test set indices


In sklearn.model_selection.cross_validate, is there a way to output the samples (or their indices) that the CV splitter used as the test set in each fold?


Solution

  • You can pass the cross-validation generator through the cv parameter of cross_validate. From the documentation:

    cv : int, cross-validation generator or an iterable, default=None

    Determines the cross-validation splitting strategy. Possible inputs for cv are:

    - None, to use the default 5-fold cross validation,
    - int, to specify the number of folds in a (Stratified)KFold,
    - CV splitter,
    - An iterable yielding (train, test) splits as arrays of indices.

    For int/None inputs, if the estimator is a classifier and y is either binary or multiclass, StratifiedKFold is used. In all other cases, KFold is used. These splitters are instantiated with shuffle=False so the splits will be the same across calls.
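Because integer and None inputs produce unshuffled splitters, you can rebuild the splits without changing the cv argument. The sketch below assumes that cross_validate resolves an integer cv through sklearn.model_selection.check_cv, which matches the scikit-learn source as far as I know. Lasso is a regressor, so this should yield a plain KFold(n_splits=5). For a classifier with a binary or multiclass y, the same call returns a StratifiedKFold, and y must be passed to split.

```python
from sklearn import datasets, linear_model
from sklearn.base import is_classifier
from sklearn.model_selection import check_cv, cross_validate

# Same data as in the answer: first 150 samples of the diabetes dataset
diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()

# Integer cv: cross_validate creates an unshuffled (Stratified)KFold internally
cv_results = cross_validate(lasso, X, y, cv=5)

# Rebuild the splitter an integer cv resolves to (assumption: via check_cv).
# Lasso is not a classifier, so this gives KFold(n_splits=5) with shuffle=False.
splitter = check_cv(5, y, classifier=is_classifier(lasso))
test_indices = [test for _, test in splitter.split(X, y)]

# Fold i of test_indices lines up with cv_results["test_score"][i]
for fold, (test, score) in enumerate(zip(test_indices, cv_results["test_score"])):
    print(f"fold {fold}: {len(test)} test samples, first index {test[0]}, R^2 = {score:.3f}")
```

Without shuffling, KFold assigns contiguous blocks, so with 150 samples the first fold tests on indices 0 to 29.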

    If you pass a splitter object, such as a KFold instance, to cross_validate:

    from sklearn import datasets, linear_model
    from sklearn.model_selection import KFold, cross_validate

    # First 150 samples of the diabetes regression dataset
    diabetes = datasets.load_diabetes()
    X = diabetes.data[:150]
    y = diabetes.target[:150]
    lasso = linear_model.Lasso()

    # A fixed random_state makes the shuffled splits reproducible
    kf = KFold(n_splits=5, shuffle=True, random_state=99)
    cv_results = cross_validate(lasso, X, y, cv=kf)
    

    You can then recover each fold's test indices by calling kf.split again:

    idx = [test_index for train_index, test_index in kf.split(X)]
    

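    idx[0] holds the test indices of the first fold, idx[1] those of the second, and so on, in the same order as cv_results['test_score']. This works because a KFold with shuffle=True and a fixed integer random_state produces the same splits on every call to split. With random_state=None the shuffle would differ between calls, and the recovered indices would not match the ones cross_validate used.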
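An alternative that does not depend on the splitter being deterministic is to generate the splits once and pass the resulting list as cv. This relies on the "iterable yielding (train, test) splits" input described in the documentation above. The following is a minimal sketch, and the variable names are illustrative:

```python
from sklearn import datasets, linear_model
from sklearn.model_selection import KFold, cross_validate

diabetes = datasets.load_diabetes()
X = diabetes.data[:150]
y = diabetes.target[:150]
lasso = linear_model.Lasso()

# random_state=None on purpose: the splits are generated exactly once
# and reused, so the indices stay consistent with what cross_validate used.
kf = KFold(n_splits=5, shuffle=True)
splits = list(kf.split(X))

cv_results = cross_validate(lasso, X, y, cv=splits)
test_indices = [test for _, test in splits]

# Scores and index arrays are aligned fold by fold
for fold, (test, score) in enumerate(zip(test_indices, cv_results["test_score"])):
    print(f"fold {fold}: first test indices {test[:5]}, R^2 = {score:.3f}")
```

Every sample lands in exactly one test fold, so concatenating the test index arrays gives each of the 150 row indices exactly once.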