python random scikit-learn sampling resampling

How to separate a dataset into more than 2 random samples


If I have a dataset with, say, 1000 rows, what is the best way to separate it into, say, 5 random samples (i.e. each sample will have 200 rows)?

I know there are functions like model_selection.train_test_split() and utils.resample() but these functions only separate the dataset into 2 samples.

Do I first need to generate a list of random numbers, in this case a list of 1000 random numbers (say from 1 to 1000), and then take the rows whose numbers are 1 to 200 as the first random sample, 201 to 400 as the second, 401 to 600 as the third, and so on?

Or is there a function in Python somewhere that I could use (to make my life easier)?


Solution

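  • You can use KFold from scikit-learn to generate the indices you're asking for. Each split yields a pair of (train, test) index arrays; with 5 splits, if you keep only the smaller test part of each split (the 20%), you get the 5 non-overlapping slices of data you need: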

    from sklearn.model_selection import KFold
    import numpy as np
    
    data = range(10)
    kf = KFold(n_splits=5, shuffle=True)
    # Each split yields (train_indices, test_indices); keep only the test part.
    for _, test_indices in kf.split(data):
        print(test_indices)
    

    This prints pseudo-random, non-overlapping index arrays that you can use to select the corresponding portions of your data and labels (the exact values vary from run to run):

    [4 9]
    [1 3]
    [6 7]
    [0 2]
    [5 8]
    

    If you want stratified sampling, use StratifiedKFold in a similar way; note that its split() also needs the labels.
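
As a rough sketch of the stratified variant, assuming a made-up toy label array with a 20/40/40 class ratio (the data, the variable names and the fixed random_state are illustrative only), each of the 5 segments should keep roughly the same class proportions as the full dataset:

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy data (assumption for illustration): 100 rows, three classes in a 20/40/40 ratio.
labels = np.array(["L1"] * 20 + ["L2"] * 40 + ["L3"] * 40)
data = np.arange(100)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# As with KFold, keep only the test indices of each split.
segments = [test_idx for _, test_idx in skf.split(data, labels)]

for idx in segments:
    values, counts = np.unique(labels[idx], return_counts=True)
    # Per-segment class counts, converted to plain Python types for printing.
    print({str(v): int(c) for v, c in zip(values, counts)})
```

Unlike KFold, the split() call here takes the labels as its second argument, since they determine how rows are distributed across segments.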

    If you want it as a function, I would probably create it as a generator:

    def segment_data(data, labels, no_segments=5, shuffle=True):
        """Yield (data, labels) pairs for each of no_segments disjoint segments."""
        # data and labels must be NumPy arrays so they can be indexed with index arrays.
        kf = KFold(n_splits=no_segments, shuffle=shuffle)
        for _, indices in kf.split(range(data.shape[0])):
            yield data[indices], labels[indices]

    my_labels = ["L1", "L2", "L3"]
    all_labels = np.random.choice(my_labels, size=100, replace=True, p=(0.1, 0.45, 0.45))
    all_data = np.random.uniform(size=100)
    
    for data, labels in segment_data(all_data, all_labels):
        print(data)
        print(labels)
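
The manual approach described in the question (shuffle the row numbers, then cut them into consecutive blocks) can also be done with NumPy alone. This is only a minimal sketch, not part of the KFold approach above; the helper name random_segments and its seed parameter are hypothetical:

```python
import numpy as np

def random_segments(n_rows, n_segments=5, seed=None):
    """Return n_segments disjoint index arrays that together cover range(n_rows)."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n_rows)            # random ordering of 0 .. n_rows-1
    return np.array_split(shuffled, n_segments)   # near-equal consecutive blocks

# 1000 rows split into 5 random samples of 200 row indices each.
segments = random_segments(1000, n_segments=5, seed=0)
for idx in segments:
    print(len(idx), idx[:5])
```

Each index array can then be used to select rows, for example data[idx] for a NumPy array or df.iloc[idx] for a pandas DataFrame. np.array_split also copes with row counts that are not divisible by the number of segments, in which case the first few segments get one extra row.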