Search code examples
python · pandas · multilabel-classification · skmultilearn

Iterative split of multilabel classification dataset in pandas dataframe


I have a dataset that contains a text column with string values and multiple label columns whose values are 1 or 0 (classified or not). I want to use skmultilearn to split this data with an even label distribution, but I get this error:

KeyError: 'key of type tuple not found and not a MultiIndex'

And this is my code:

import pandas as pd
from skmultilearn.model_selection import iterative_train_test_split


# Load the whole dataset; after the pop below, y holds only the remaining columns
# (described in the question as the 0/1 label columns).
y = pd.read_csv("dataset.csv")
# Remove the "text" column from y and keep it separately as x.
x = y.pop("text")

# NOTE(review): this is the call the question reports as failing with
# "KeyError: 'key of type tuple not found and not a MultiIndex'".
# x and y are still pandas objects at this point; presumably the splitter indexes
# them like NumPy arrays and therefore needs ndarray input -- confirm against the
# skmultilearn documentation.
x_train, x_test, y_train, y_test = iterative_train_test_split(x, y, test_size=0.1)

Solution

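  • Here is what worked for me, using MultilabelStratifiedShuffleSplit from the iterstrat package instead of skmultilearn (this produces a 98/1/1 train/test/validation split):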

    import os
    import pandas as pd
    from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
    
    
    def _build_frame(snippets, labels, tag_names):
        """Reassemble one split as a DataFrame with the text in a leading "snippet" column."""
        frame = pd.DataFrame(data=labels, columns=tag_names)
        frame.insert(0, "snippet", snippets)
        return frame


    def main(csv_path="dataset.csv"):
        """Split a multilabel dataset into stratified train / validation / test sets.

        The CSV is expected to contain a "text" column; every other column is
        treated as a label column. 2% of the rows are held out first and that
        hold-out is then halved, giving a 98/1/1 train/test/validation split.

        Args:
            csv_path: Path of the CSV file to load.

        Returns:
            A ``(df_train, df_val, df_test)`` tuple of DataFrames. Each one has
            the text in a leading "snippet" column followed by the label columns
            under their original names.
        """
        # load dataset
        y = pd.read_csv(csv_path)
        x = y.pop("text")

        # save tag names to reuse them later for creating pandas DataFrames
        tag_names = y.columns

        # Data has to be in ndarray format
        y = y.to_numpy()
        x = x.to_numpy()

        # split to train / test
        # A single split is all that is needed, so request exactly one instead of
        # generating several and keeping only the last.
        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.02, random_state=42)
        train_index, test_index = next(msss.split(x, y))
        x_train, x_test_temp = x[train_index], x[test_index]
        y_train, y_test_temp = y[train_index], y[test_index]

        # make some memory space
        del x
        del y

        # split to test / validation
        msss = MultilabelStratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
        test_index, val_index = next(msss.split(x_test_temp, y_test_temp))
        x_test, x_val = x_test_temp[test_index], x_test_temp[val_index]
        y_test, y_val = y_test_temp[test_index], y_test_temp[val_index]

        # Hand the three datasets back to the caller; previously they were built
        # and then discarded when the function returned.
        df_train = _build_frame(x_train, y_train, tag_names)
        df_val = _build_frame(x_val, y_val, tag_names)
        df_test = _build_frame(x_test, y_test, tag_names)
        return df_train, df_val, df_test


    if __name__ == "__main__":
        main()