Tags: python-3.x, validation, scikit-learn, cross-validation, train-test-split

How to use an explicit validation set with a predefined split fold?


I have explicit train, test and validation sets as 2d arrays:

X_train.shape
(1400, 38785)
X_val.shape
(200, 38785)
X_test.shape
(400, 38785)

I am tuning the alpha parameter and need advice on how to use my predefined validation set for the tuning:


from sklearn.naive_bayes import MultinomialNB
from sklearn.model_selection import GridSearchCV, PredefinedSplit

nb = MultinomialNB()
nb.fit(X_train, y_train)

params = {'alpha': [0.1, 1, 3, 5, 10,12,14]}
# how to use on my validation set?
# ps = PredefinedSplit(test_fold=?)

gs = GridSearchCV(nb, param_grid=params, cv = ps,  return_train_score=True, scoring='f1')

gs.fit(X_train, y_train)

My results so far are as follows.

# on my validation set, alpha = 5
gs.fit(X_val, y_val)
print('Grid best parameter', gs.best_params_)
Grid best parameter:  {'alpha': 5}

# on my training set, alpha = 10
Grid best parameter:  {'alpha': 10}

I have read the following questions and documentation, yet I am still not sure how to use PredefinedSplit() in my case. Thank you.

Order between using validation, training and test sets

https://scikit-learn.org/stable/modules/cross_validation.html#predefined-fold-splits-validation-sets


Solution

  • You can achieve your desired outcome by concatenating X_train and X_val and passing PredefinedSplit an array of fold labels, with -1 marking rows that always stay in the training set and 1 marking rows that belong to the validation fold, i.e.:

    
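    import numpy as np

    # Stack the training and validation rows into one dataset.
    X = np.concatenate((X_train, X_val))
    y = np.concatenate((y_train, y_val))

    # -1: row is never used for validation; 1: row belongs to the validation fold.
    test_fold = np.concatenate((np.full(len(X_train), -1), np.ones(len(X_val), dtype=int)))
    ps = PredefinedSplit(test_fold)

    gs = GridSearchCV(nb, param_grid=params, cv=ps, return_train_score=True, scoring='f1')

    gs.fit(X, y)  # not X_train, y_train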
    

    However, unless you have a very good reason to hold out a separate validation set, tuning with k-fold cross-validation is likely to overfit the hyperparameters less than tuning against a single dedicated validation set.
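
To check that the split behaves as intended, and to compare it with plain k-fold tuning, here is a minimal self-contained sketch. The data is synthetic and much smaller than in the question (the sizes, the seed and the `make_counts` helper are assumptions made up for illustration), so the selected alpha values say nothing about the real dataset.

```python
import numpy as np
from sklearn.model_selection import GridSearchCV, PredefinedSplit
from sklearn.naive_bayes import MultinomialNB

# Hypothetical stand-in data: small count matrices with binary labels.
rng = np.random.default_rng(0)
n_train, n_val, n_features = 140, 20, 50


def make_counts(n):
    """Illustrative helper: each class favours a different half of the features."""
    y = rng.integers(0, 2, size=n)
    rates = np.ones((n, n_features))
    rates[y == 1, : n_features // 2] = 2.0
    rates[y == 0, n_features // 2:] = 2.0
    return rng.poisson(rates), y


X_train, y_train = make_counts(n_train)
X_val, y_val = make_counts(n_val)

# Stack both sets; the fold labels tell PredefinedSplit which rows to validate on.
X = np.concatenate((X_train, X_val))
y = np.concatenate((y_train, y_val))
test_fold = np.concatenate((np.full(n_train, -1), np.ones(n_val, dtype=int)))
ps = PredefinedSplit(test_fold)

# There is exactly one split: train on the first n_train rows, validate on the rest.
train_idx, val_idx = next(ps.split())

params = {"alpha": [0.1, 1, 3, 5, 10, 12, 14]}

# Option 1: tune against the dedicated validation set only.
gs_holdout = GridSearchCV(MultinomialNB(), param_grid=params, cv=ps, scoring="f1")
gs_holdout.fit(X, y)

# Option 2: tune with 5-fold cross-validation over the same stacked data.
gs_kfold = GridSearchCV(MultinomialNB(), param_grid=params, cv=5, scoring="f1")
gs_kfold.fit(X, y)

print("hold-out best:", gs_holdout.best_params_)
print("5-fold best:  ", gs_kfold.best_params_)
```

Note that with the default `refit=True`, GridSearchCV refits the best estimator on everything passed to `fit`, so here `best_estimator_` is trained on the training and validation rows together. Either way, the test set should stay out of the search and only be used for the final evaluation.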