Random stratified sampling in pandas


I have created a pandas dataframe as follows:

import pandas as pd
import numpy as np

ds = {'col1' : [1,1,1,1,1,1,1,2,2,2,2,3,3,3,3,3,4,4,4,4,4,4,4,4,4],
      'col2' : [12,3,4,5,4,3,2,3,4,6,7,8,3,3,65,4,3,2,32,1,2,3,4,5,32],
      }

df = pd.DataFrame(data=ds)

The dataframe looks as follows:

print(df)

    col1  col2
0      1    12
1      1     3
2      1     4
3      1     5
4      1     4
5      1     3
6      1     2
7      2     3
8      2     4
9      2     6
10     2     7
11     3     8
12     3     3
13     3     3
14     3    65
15     3     4
16     4     3
17     4     2
18     4    32
19     4     1
20     4     2
21     4     3
22     4     4
23     4     5
24     4    32

Based on the values of column col1, I need to extract:

  • 3 random records where col1 == 1
  • 2 random records where col1 == 2
  • 2 random records where col1 == 3
  • 3 random records where col1 == 4

Can anyone help me please?


Solution

  • I would shuffle the whole input with sample(frac=1), then use groupby.cumcount to number the rows within each group in shuffled order, and keep the first N rows per group (with map and boolean indexing), where N is defined in a dictionary:

    # {col1: number of samples}
    n = {1: 3, 2: 2, 3: 2, 4: 3}
    
    out = df[df[['col1']].sample(frac=1)
                         .groupby('col1').cumcount()
                         .lt(df['col1'].map(n))]
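
The one-liner above packs several steps into a single expression. A minimal sketch of the same logic unrolled, using the question's data; the intermediate names (shuffled, rank, quota, mask) are introduced here only for illustration:

```python
import pandas as pd

df = pd.DataFrame({
    'col1': [1,1,1,1,1,1,1,2,2,2,2,3,3,3,3,3,4,4,4,4,4,4,4,4,4],
    'col2': [12,3,4,5,4,3,2,3,4,6,7,8,3,3,65,4,3,2,32,1,2,3,4,5,32],
})

# {col1: number of samples}
n = {1: 3, 2: 2, 3: 2, 4: 3}

# 1. Shuffle the grouping column; each row keeps its original index label.
shuffled = df[['col1']].sample(frac=1)

# 2. Number the rows within each group (0, 1, 2, ...) in shuffled order.
rank = shuffled.groupby('col1').cumcount()

# 3. Look up how many rows each row's group is allowed to keep.
quota = df['col1'].map(n)

# 4. Keep a row when its rank is below its quota.
#    The comparison aligns the two Series on their index labels.
mask = rank.lt(quota)

out = df[mask]
print(out)
```

Because the shuffle assigns the within-group ranks at random, "rank below quota" picks a random subset of each group, and the result comes back in the original row order.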
    

    Shorter, but probably less efficient: a custom groupby.apply that draws a separate sample for each group:

    n = {1: 3, 2: 2, 3: 2, 4: 3}
    
    out = (df.groupby('col1', group_keys=False)
             .apply(lambda g: g.sample(n=n[g.name]))
          )
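
As far as I know, recent pandas releases have deprecated letting groupby.apply operate on the grouping column, so the snippet above may emit a warning or behave differently depending on the installed version. A possible alternative, not part of the original answer, is to iterate over the groups and concatenate the per-group samples:

```python
import pandas as pd

df = pd.DataFrame({
    'col1': [1,1,1,1,1,1,1,2,2,2,2,3,3,3,3,3,4,4,4,4,4,4,4,4,4],
    'col2': [12,3,4,5,4,3,2,3,4,6,7,8,3,3,65,4,3,2,32,1,2,3,4,5,32],
})

# {col1: number of samples}
n = {1: 3, 2: 2, 3: 2, 4: 3}

# Iterating over a groupby yields (key, sub-frame) pairs;
# draw n[key] rows from each sub-frame.
parts = [group.sample(n=n[key]) for key, group in df.groupby('col1')]

# Stitch the pieces together and restore the original row order.
out = pd.concat(parts).sort_index()
print(out)
```

This sketch assumes every value of col1 has an entry in n; a missing key raises a KeyError.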
    

    Example output (the selected rows vary between runs):

        col1  col2
    0      1    12
    3      1     5
    4      1     4
    7      2     3
    8      2     4
    11     3     8
    13     3     3
    17     4     2
    18     4    32
    24     4    32
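
Since the rows are drawn at random, it can be handy to fix the seed and check the per-group counts. The sketch below wraps the cumcount approach in a hypothetical helper, stratified_sample (the name and its seed parameter are assumptions for illustration), and passes the seed to sample through its random_state argument:

```python
import pandas as pd

def stratified_sample(df, col, n, seed=None):
    """Hypothetical helper: keep n[value] random rows for each value of `col`."""
    # Shuffle, then number the rows within each group in shuffled order.
    rank = df[[col]].sample(frac=1, random_state=seed).groupby(col).cumcount()
    # Keep rows whose within-group rank is below the group's quota.
    return df[rank.lt(df[col].map(n))]

df = pd.DataFrame({
    'col1': [1,1,1,1,1,1,1,2,2,2,2,3,3,3,3,3,4,4,4,4,4,4,4,4,4],
    'col2': [12,3,4,5,4,3,2,3,4,6,7,8,3,3,65,4,3,2,32,1,2,3,4,5,32],
})
n = {1: 3, 2: 2, 3: 2, 4: 3}

a = stratified_sample(df, 'col1', n, seed=42)
b = stratified_sample(df, 'col1', n, seed=42)

print(a.equals(b))                                      # True: same seed, same rows
print(a['col1'].value_counts().sort_index().to_dict())  # {1: 3, 2: 2, 3: 2, 4: 3}
```

One difference between the two approaches is worth keeping in mind: if a group has fewer rows than requested, the cumcount version simply returns all rows of that group, whereas g.sample(n=...) raises a ValueError unless replace=True is passed.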