scikit-learn, train-test-split

stratify argument in train_test_split vs StratifiedShuffleSplit


What is the difference between using the stratify argument of sklearn's train_test_split function and using StratifiedShuffleSplit? Don't they do the same thing?


Solution

  • These two serve different purposes: train_test_split is a function, while StratifiedShuffleSplit is a cross-validator class.

    train_test_split, as its name implies, splits the data into a single training subset and a single test subset, and the stratify argument makes that split stratified, i.e. both subsets keep the class proportions of the array passed to stratify.
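
To make the single stratified split concrete, here is a minimal sketch. The toy data (80 samples of class 0, 20 of class 1) and the choices test_size=0.25 and random_state=0 are assumptions made up for illustration, not part of the question.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy imbalanced data, invented for illustration: 80 samples of class 0, 20 of class 1.
X = np.arange(100).reshape(-1, 1)
y = np.array([0] * 80 + [1] * 20)

# One call, one train/test pair; stratify=y keeps the 80/20 class ratio in both parts.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

train_counts = np.bincount(y_train)  # per-class counts in the training subset
test_counts = np.bincount(y_test)    # per-class counts in the test subset
print(train_counts, test_counts)     # [60 15] [20  5]
```

Which samples land in each subset depends on random_state, but the per-class counts do not, because the 80/20 ratio divides evenly into 75 and 25 samples.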

    StratifiedShuffleSplit, on the other hand, provides splits for cross-validation; from the docs:

    Stratified ShuffleSplit cross-validator

    Provides train/test indices to split data in train/test sets.

    Notice the plural "sets" at the end of that sentence.

    So StratifiedShuffleSplit is meant to be used in place of a cross-validator such as KFold when we want the CV splits to be stratified; it is not meant to replace train_test_split.
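
As a rough sketch of the cross-validation use described above, the snippet below builds several stratified splits and passes the splitter as the cv argument of cross_val_score. The toy data, n_splits=5, test_size=0.25, and the DummyClassifier (a placeholder model chosen only so the scores are predictable) are all assumptions for illustration.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.model_selection import StratifiedShuffleSplit, cross_val_score

# Toy imbalanced data, invented for illustration: 80 samples of class 0, 20 of class 1.
X = np.arange(100).reshape(-1, 1)
y = np.array([0] * 80 + [1] * 20)

# A cross-validator object: it yields several train/test index pairs, not one split.
sss = StratifiedShuffleSplit(n_splits=5, test_size=0.25, random_state=0)

splits = list(sss.split(X, y))
for i, (train_idx, test_idx) in enumerate(splits):
    # Every split keeps the 80/20 class ratio in both the train and test indices.
    print(i, np.bincount(y[train_idx]), np.bincount(y[test_idx]))

# The same object can be passed wherever a CV splitter such as KFold is accepted.
# DummyClassifier is only a placeholder; it always predicts the majority class.
scores = cross_val_score(DummyClassifier(strategy="most_frequent"), X, y, cv=sss)
print(scores)
```

Note that the splitter returns index arrays rather than the data itself, and that, unlike KFold, the test sets of different splits are drawn independently and may overlap.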