
How to get the train and test data for each fold in k-fold cross-validation?


How can I access the train and test data for each fold in cross-validation? I would like to save these in .csv files. I tried using the split method, which generates the indices, but it returns a generator object rather than the indices themselves.

from sklearn.model_selection import StratifiedKFold, KFold
import numpy as np
X, y = np.ones((50, 1)), np.hstack(([0] * 45, [1] * 5))
skf = StratifiedKFold(n_splits=3)
x = skf.split(X, y)
x

Output:
<generator object _BaseKFold.split at 0x7ff195979580>
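The generator object above does hold the indices. It produces them lazily, one (train_indices, test_indices) pair per fold, as it is iterated. If you want all folds at once, one option is to consume it with list(). This is a minimal sketch reusing the toy X and y from the question; the name folds is just illustrative. A generator can only be iterated once, so materializing it is also handy if you need to revisit the folds later.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

X, y = np.ones((50, 1)), np.hstack(([0] * 45, [1] * 5))
skf = StratifiedKFold(n_splits=3)

# split() yields (train_indices, test_indices) pairs lazily;
# list() consumes the generator and keeps every pair in memory.
folds = list(skf.split(X, y))

for i, (train_index, test_index) in enumerate(folds):
    # Each entry is a pair of NumPy integer arrays of row indices.
    print(f"Fold {i}: {len(train_index)} train rows, {len(test_index)} test rows")
```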


Solution

  • StratifiedKFold.split returns a generator, so you need to iterate over it as follows:

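    from sklearn.model_selection import StratifiedKFold

    skf = StratifiedKFold(n_splits=3)
    # Each iteration yields the row indices of one fold's train and test sets
    for train_index, test_index in skf.split(X, y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = y[train_index], y[test_index]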
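To cover the .csv part of the question, the same loop can write each fold to disk. This is a sketch under a few assumptions: the output directory name folds, the file names foldN_train.csv / foldN_test.csv, and the header x,y are all hypothetical choices, and it uses numpy.savetxt with features and label stacked into one table per file. pandas' DataFrame.to_csv would work just as well if you prefer named columns.

```python
from pathlib import Path

import numpy as np
from sklearn.model_selection import StratifiedKFold

X, y = np.ones((50, 1)), np.hstack(([0] * 45, [1] * 5))
skf = StratifiedKFold(n_splits=3)

# Hypothetical output directory; change to wherever the CSVs should go.
out_dir = Path("folds")
out_dir.mkdir(exist_ok=True)

for i, (train_index, test_index) in enumerate(skf.split(X, y)):
    # Stack features and label side by side so each CSV row is one sample.
    train = np.column_stack((X[train_index], y[train_index]))
    test = np.column_stack((X[test_index], y[test_index]))

    # comments="" stops numpy from prefixing the header line with "# ".
    np.savetxt(str(out_dir / f"fold{i}_train.csv"), train, delimiter=",",
               fmt="%g", header="x,y", comments="")
    np.savetxt(str(out_dir / f"fold{i}_test.csv"), test, delimiter=",",
               fmt="%g", header="x,y", comments="")
```

Each fold then produces two files whose rows are the samples selected by train_index and test_index, with the label in the last column.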