python, scikit-learn, keyword-argument, iterable-unpacking

Setting kwargs for train_test_split


I have a notebook that refits the same model on a growing set of features. I'd like to pass train_test_split() a dict of the relevant arguments rather than typing them out each time. For my Random Forest model, for example, I've created a hyperparameter dict:

rf_params = {
    'class_weight':'balanced',
    'max_depth':2,
    'n_estimators':1000,
    'n_jobs':-1,
    'random_state':42
}

I unpack that for each random forest classifier: rf_clf = RandomForestClassifier(**rf_params). I'd like to do the same with train_test_split(), whose arguments will always be X, y, test_size=0.3, random_state=42. According to the docs, however, X and y are not keyword arguments; they are collected by the *arrays parameter. How do I put that in a dict to unpack?

Something along the lines of

split_args = {
    '*arrays':['X','y'],
    'test_size':0.3,
    'random_state':42
}

train_test_split(**split_args)

Solution

  • Put the positional arguments in a list and unpack it with *, then unpack the keyword arguments with ** as before:

    split_args = [X, y]
    split_kwargs = {
        'test_size':0.3,
        'random_state':42
    }
    
    train_test_split(*split_args, **split_kwargs)
    

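    *arrays is a variable-length positional parameter, so X and y can only be supplied positionally; a dict key cannot fill it by name. Pass them as an unpacked sequence (a list or tuple) and keep the dict for the keyword arguments.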

    See: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html
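
A minimal runnable sketch of the pattern above, plus one possible alternative using functools.partial to bind the keyword arguments once. The toy X and y arrays, the helper name split, and the shapes dict are made up for illustration and are not part of the question; the feature counts in the loop only stand in for the "growing set of features" described there.

```python
from functools import partial

import numpy as np
from sklearn.model_selection import train_test_split

# Toy data, purely illustrative: 10 samples, 4 features, binary labels.
X = np.arange(40).reshape(10, 4)
y = np.array([0, 1] * 5)

split_args = [X, y]
split_kwargs = {"test_size": 0.3, "random_state": 42}

# Pattern from the answer: arrays via *, settings via **.
X_train, X_test, y_train, y_test = train_test_split(*split_args, **split_kwargs)

# Alternative: bind the settings once, then pass only the arrays each time.
# `split` is a hypothetical helper name.
split = partial(train_test_split, **split_kwargs)

# Reuse the same settings while the feature set grows.
shapes = {}
for n_features in (2, 4):
    Xa_train, Xa_test, ya_train, ya_test = split(X[:, :n_features], y)
    shapes[n_features] = (Xa_train.shape, Xa_test.shape)
```

Because random_state is fixed and the number of samples does not change, each call should select the same rows for the test set, which keeps runs with different feature sets comparable.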