
Scikit-Learn cross validation function not allowing custom folds when indices are not sequential


I'm trying to pass custom cross-validation folds to sklearn's cross_validate function.

The cross_validate function seems to raise an error because it insists on position-based indexing rather than label-based indexing. The indices I pass in my cv_folds argument are consistent with the original dataframe's index. This matters because I want to use a hash function value to select subsets for my train-test split as well as for my cv folds. I get the following error: IndexError: indices are out-of-bounds

import numpy as np
import pandas as pd
from sklearn.model_selection import cross_validate

# model: any scikit-learn regressor (its definition is omitted here)
df2 = pd.DataFrame(np.random.rand(8, 3), columns=['feature_1', 'feature_2', 'feature_3'])
train_index_list = [0, 1, 2, 5, 6, 7]
test_index_list = [3, 4]
X_train = df2.loc[train_index_list].drop(columns='feature_3').copy()
y_train = df2.loc[train_index_list, 'feature_3'].copy()
# 2-fold cross validation, with the folds given as index labels
cv_folds = [([0, 1, 2], [5, 6, 7]), ([5, 6, 7], [0, 1, 2])]
cv_output = cross_validate(model, X_train, y_train, scoring=['neg_mean_squared_error'], cv=cv_folds)

This triggers the error. What puzzles me is that the following lines run just fine:

X_train.loc[train_index_list]
y_train.loc[train_index_list]

How do I resolve this so I can pass in my custom-defined cv folds into Scikit-Learn?


Solution

  • cross_validate treats the fold indices as row positions, not index labels. You can work around this by using Index.get_indexer to convert the labels to positions:

    def cv_folds(df, labels):
        for i, j in labels:
            i = df.index.get_indexer(i)
            j = df.index.get_indexer(j)
            yield (i.tolist(), j.tolist())
    
    labels = [([0, 1, 2], [5, 6, 7]), ([5, 6, 7], [0, 1, 2])]
    cv = cv_folds(X_train, labels)
    cv_output = cross_validate(model, X_train, y_train, cv=cv,
                               scoring=['neg_mean_squared_error'])
    

    Test:

    >>> list(cv_folds(X_train, labels))
        [([0, 1, 2], [3, 4, 5]), ([3, 4, 5], [0, 1, 2])]  # <- positions
    #   [([0, 1, 2], [5, 6, 7]), ([5, 6, 7], [0, 1, 2])]  # <- labels
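
One caveat with the workaround above: Index.get_indexer returns -1 for any label it cannot find, and a position of -1 would silently select the last row instead of raising. The following is a minimal sketch of a stricter variant, not a definitive implementation. The helper name label_folds_to_positions is hypothetical, LinearRegression is only an assumed stand-in for the question's undefined model, and the index is assumed to contain unique labels (get_indexer raises on duplicates).

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_validate


def label_folds_to_positions(index, label_folds):
    """Convert (train_labels, test_labels) pairs into positional index pairs.

    Hypothetical helper, not part of scikit-learn or pandas.
    """
    positional_folds = []
    for train_labels, test_labels in label_folds:
        train_pos = index.get_indexer(train_labels)
        test_pos = index.get_indexer(test_labels)
        # get_indexer marks labels it cannot find with -1, which positional
        # indexing would read as "last row", so fail loudly instead.
        if (train_pos == -1).any() or (test_pos == -1).any():
            raise KeyError("fold contains labels missing from the index")
        positional_folds.append((train_pos, test_pos))
    return positional_folds


# Same setup as in the question, with a seeded generator for repeatability.
rng = np.random.default_rng(0)
df2 = pd.DataFrame(rng.random((8, 3)),
                   columns=["feature_1", "feature_2", "feature_3"])
train_index_list = [0, 1, 2, 5, 6, 7]
X_train = df2.loc[train_index_list, ["feature_1", "feature_2"]]
y_train = df2.loc[train_index_list, "feature_3"]

labels = [([0, 1, 2], [5, 6, 7]), ([5, 6, 7], [0, 1, 2])]
cv = label_folds_to_positions(X_train.index, labels)

model = LinearRegression()  # assumption: the question does not name its model
cv_output = cross_validate(model, X_train, y_train, cv=cv,
                           scoring=["neg_mean_squared_error"])
print(cv_output["test_neg_mean_squared_error"])
```

Returning a list rather than a generator also means the same cv object can be reused across several cross_validate calls, whereas a generator is exhausted after the first one.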