python, pandas, machine-learning, scikit-learn, multiclass-classification

How to stratify the training and testing data in Scikit-Learn?


I am trying to implement a classification algorithm for the Iris dataset (downloaded from Kaggle). In the Species column, the classes (Iris-setosa, Iris-versicolor, Iris-virginica) are in sorted order. How can I stratify the training and test data using Scikit-Learn?
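
To see why the sorted order matters, here is a minimal sketch. It assumes scikit-learn's bundled copy of the Iris data (`load_iris`) as a stand-in for the Kaggle CSV, and re-creates Kaggle-style labels such as "Iris-setosa" in a `Species` column. If you split this data without shuffling, the test set holds only the last rows, which all belong to one class.

```python
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Assumption: the bundled scikit-learn dataset stands in for the Kaggle CSV.
iris = load_iris(as_frame=True)
df = iris.frame.drop(columns="target")
# Rebuild Kaggle-style label strings, e.g. "Iris-setosa".
df["Species"] = ["Iris-" + iris.target_names[t] for t in iris.target]

X = df.drop(columns="Species")
y = df["Species"]

# Rows are ordered setosa, versicolor, virginica (50 each), so taking the
# last 30% without shuffling leaves a single class in the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, shuffle=False
)
print(y_test.value_counts())
```

Only Iris-virginica appears in the printed counts, so a model evaluated on this split would never be tested on the other two species.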


Solution

  • If you want to shuffle and split your data with a test ratio of 0.3, you can use

    from sklearn.model_selection import train_test_split
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=True)
    

    where X is your data, y is the corresponding labels, test_size is the fraction of the data held out for testing (0.3 means 30%), and shuffle=True shuffles the data before splitting (this is already the default).

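    Shuffling alone only makes the class proportions roughly equal. To make sure each class appears in the same proportion in both the training and test sets, pass the labels (for this question, the Species column in y) to the stratify parameter.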

    # Stratify on the class labels; any array-like with one label per
    # row also works, e.g. stratify=X['YOUR_COLUMN_LABEL'].
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
                                                        shuffle=True,
                                                        stratify=y)
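
A quick way to check that stratification works is the sketch below. Like the earlier example, it assumes scikit-learn's bundled Iris data as a stand-in for the Kaggle CSV, with the labels rewritten in the "Iris-..." style. The `random_state=42` argument is an arbitrary addition that only makes the shuffle reproducible.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Assumption: the bundled scikit-learn dataset stands in for the Kaggle CSV.
iris = load_iris(as_frame=True)
X = iris.data
# Map integer targets to Kaggle-style strings such as "Iris-setosa".
y = iris.target.map(lambda t: "Iris-" + iris.target_names[t]).rename("Species")

# stratify=y keeps each species' share identical in train and test.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, shuffle=True, stratify=y, random_state=42
)

print(y_train.value_counts().sort_index())
print(y_test.value_counts().sort_index())
```

Each species has 50 rows and the split is 70/30, so every species should end up with 35 training rows and 15 test rows, whatever the value of `random_state`.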