python, python-3.x, pandas, dataframe, cross-validation

How can I evenly split up a pandas.DataFrame into n-groups?


I need to perform n-fold (in my particular case, 5-fold) cross validation on a dataset that I've stored in a pandas.DataFrame. My current approach seems to rearrange the row labels:

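import pandas as pd

spreadsheet1 = pd.ExcelFile("Testing dataset.xlsx")
dataset = spreadsheet1.parse('Sheet1')

# One empty DataFrame per fold (a comprehension, so each entry is a distinct object)
data = [pd.DataFrame() for _ in range(5)]

i = 0
while i < len(dataset):
    j = 0
    # Deal the rows out round-robin: row i goes to fold j (i.e. i % 5)
    while j < 5 and i < len(dataset):
        # DataFrame.append was removed in pandas 2.0; pd.concat replaces it.
        # iloc[[i]] keeps the row as a one-row DataFrame rather than a Series.
        data[j] = pd.concat([data[j], dataset.iloc[[i]]]).reset_index(drop=True)
        i += 1
        j += 1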
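The loop above deals row i into fold i % 5. As a sketch (an alternative technique, not the asker's code), the same round-robin assignment can be written with strided iloc slices, which avoids repeated appending and keeps both the original column order and the original row labels. The small dataset below is made-up stand-in data, since the spreadsheet isn't available.

```python
import pandas as pd

# Hypothetical stand-in for the spreadsheet: 12 rows, with columns in a
# deliberately non-alphabetical order so we can see it is preserved.
dataset = pd.DataFrame({"b": range(12), "a": range(100, 112)})

n_folds = 5

# Fold j takes rows j, j + n_folds, j + 2 * n_folds, ...
# i.e. row i lands in fold i % n_folds, the same order as the while loop.
folds = [dataset.iloc[j::n_folds] for j in range(n_folds)]

for j, fold in enumerate(folds):
    print(j, list(fold.index), list(fold.columns))
```

Each fold keeps the original row labels; call .reset_index(drop=True) on a fold if you want 0-based labels like the loop produces.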

How can I split my DataFrame efficiently/intelligently without tampering with the order of the columns?


Solution

  • Use np.array_split to break it up into a list of "evenly" sized DataFrames (chunk sizes differ by at most one row). You can also shuffle the rows first by sampling the full DataFrame with df.sample(frac=1):

    import pandas as pd
    import numpy as np
    
    df = pd.DataFrame(np.arange(24).reshape(-1,2), columns=['A', 'B'])
    N = 5
    
    np.array_split(df, N)
    #np.array_split(df.sample(frac=1), N)  # Shuffle and split
    

    [   A  B
     0  0  1
     1  2  3
     2  4  5,
         A   B
     3   6   7
     4   8   9
     5  10  11,
         A   B
     6  12  13
     7  14  15,
         A   B
     8  16  17
     9  18  19,
          A   B
     10  20  21
     11  22  23]
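Building on the array_split answer, here is a sketch that splits the row positions instead of the DataFrame itself, then forms one train/test pair per fold for 5-fold cross validation. Splitting a plain integer array sidesteps the FutureWarning about DataFrame.swapaxes that np.array_split(df, N) can trigger on newer pandas releases (2.1+). The names fold_positions and splits are illustrative only.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(np.arange(24).reshape(-1, 2), columns=["A", "B"])
N = 5

# Split the row *positions*; chunk sizes differ by at most one.
# To shuffle, use np.random.default_rng().permutation(len(df)) instead of np.arange.
fold_positions = np.array_split(np.arange(len(df)), N)

# The folds themselves, same contents as np.array_split(df, N)
folds = [df.iloc[pos] for pos in fold_positions]

# Each fold serves as the test set once; the other folds form the training set.
splits = []
for k, test_pos in enumerate(fold_positions):
    train_pos = np.concatenate([p for i, p in enumerate(fold_positions) if i != k])
    splits.append((df.iloc[train_pos], df.iloc[test_pos]))

for k, (train, test) in enumerate(splits):
    print(k, len(train), len(test))
```

If scikit-learn is available, sklearn.model_selection.KFold performs the same index bookkeeping (with an optional shuffle) and yields train/test position arrays you can pass to df.iloc.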