python pandas random sample

How to sample data with as equal a number of rows as possible from each category of multiple columns?


Data:

import numpy as np
import pandas as pd

df1 = pd.DataFrame(np.random.randint(0, 1000, size=(100, 4)),
                   columns=list('ABCD'))

df1["cat1"] = np.random.choice(['a', 'b'], len(df1))
df1["cat2"] = np.random.choice(['32782', '35871', '35865'], len(df1))
df1["cat3"] = np.random.choice(['pq', 'xy', 'ab', 'hq'], len(df1))

I want to take a sample of this dataset, i.e., 200 rows, with as close to an equal number of rows as possible from each category of the 3 columns.

We can validate it like this:

# count the rows per category in each of the three columns of the sample
c1, c2, c3 = (sample[c].value_counts() for c in ['cat1', 'cat2', 'cat3'])
assert c1['a'] == c1['b']
assert c2['32782'] == c2['35871'] == c2['35865']
assert c3['pq'] == c3['xy'] == c3['ab'] == c3['hq']

In my real data there are 100M rows and I want to take 200K rows.

I tried sample = df1.sample(n=200000, replace=True, random_state=123) to take the 200K rows, but I am not sure how to use random sampling, i.e., df1.sample, to get as close as possible to an equal number of rows from each group.

Having exactly the same number of rows is not a strict condition; an error of +/- 5% is also fine.

Update:

replace=True is used to allow repeated rows, if n is small.


Solution

  • You can use df1.groupby() and then apply sample:

    n = 1
    df1.groupby(['cat1', 'cat2', 'cat3']).apply(lambda s: s.sample(n))
    

    You can use .reset_index(drop=True) to drop the index if you wish.

    Only n=1 works with my dummy example, as there was a combination of categories that existed only once. If the dataset is larger, larger values of n may be possible.
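A minimal sketch of the same per-group sampling, assuming pandas 1.1 or later, where `DataFrameGroupBy.sample` is available. It is a substitute for the `apply` call above, not part of the original answer; the seed and the names `keys` and `per_group` are illustrative only.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
df1 = pd.DataFrame(rng.integers(0, 1000, size=(100, 4)), columns=list("ABCD"))
df1["cat1"] = rng.choice(["a", "b"], len(df1))
df1["cat2"] = rng.choice(["32782", "35871", "35865"], len(df1))
df1["cat3"] = rng.choice(["pq", "xy", "ab", "hq"], len(df1))

keys = ["cat1", "cat2", "cat3"]  # illustrative name for the grouping columns
n = 1

# draw n rows from every observed combination, then flatten the index
sample = df1.groupby(keys).sample(n=n, random_state=123).reset_index(drop=True)

# number of sampled rows per combination
per_group = sample.groupby(keys).size()
```

Unlike `apply`, this keeps a flat index and does not add the group keys as index levels, so `reset_index(drop=True)` only renumbers the rows.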

    To find the maximum value for n, group by the three categories and count the number of occurrences (including combinations with zero occurrences). The minimum of these counts is the maximum value for n:

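    from itertools import product

    keys = ['cat1', 'cat2', 'cat3']
    # every possible combination of the observed category values
    combs = pd.DataFrame(list(product(*(df1[k].unique() for k in keys))), columns=keys)
    # number of rows per combination that actually occurs
    counts = df1.groupby(keys).size().reset_index(name='count')
    # the right merge keeps combinations that never occur; they get a count of 0
    result = counts.merge(combs, how='right').fillna(0)

    n_max = int(result['count'].min())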
    

    You can verify that this is indeed the maximum value (when every combination occurs at least once) by using n = n_max + 1 in the first snippet, which raises a ValueError because at least one group has fewer than n rows.
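The check described above can be sketched as follows. This assumes pandas 1.1 or later for `DataFrameGroupBy.sample` and a dataset in which every combination occurs at least once, so the smallest observed group size equals n_max. The 500-row data, the seed, and the names `ok` and `raised` are made up for illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)  # arbitrary seed, for reproducibility only
df1 = pd.DataFrame({
    "cat1": rng.choice(["a", "b"], 500),
    "cat2": rng.choice(["32782", "35871", "35865"], 500),
    "cat3": rng.choice(["pq", "xy", "ab", "hq"], 500),
})
keys = ["cat1", "cat2", "cat3"]

# smallest observed group size; equals n_max only if no combination is missing
n_max = int(df1.groupby(keys).size().min())

# sampling n_max rows per group without replacement works
ok = df1.groupby(keys).sample(n=n_max, random_state=123)

# asking for one more row than the smallest group holds fails
try:
    df1.groupby(keys).sample(n=n_max + 1, random_state=123)
    raised = False
except ValueError:
    raised = True
```

If some combination never occurs, n_max from the merge is 0 and this check no longer applies, because groupby only iterates over observed combinations.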

    Output of the first snippet with n = 1 (the unlabeled fourth index level is the original row index):

                          A    B    C    D cat1   cat2 cat3
    cat1 cat2  cat3                                        
    a    32782 ab   60  369  281  970  277    a  32782   ab
               hq   8    94  933  560  622    a  32782   hq
               pq   65  369  356  120  533    a  32782   pq
               xy   3   227  267  664  161    a  32782   xy
         35865 ab   45  991  929  664  400    a  35865   ab
               hq   10   52  337  303  804    a  35865   hq
               pq   2   639  557  828   90    a  35865   pq
               xy   57  823  882   11  574    a  35865   xy
         35871 ab   98  900  331  527  966    a  35871   ab
               hq   70  132  394  235  177    a  35871   hq
               pq   9   660  411  342  752    a  35871   pq
               xy   79  617  780  555  649    a  35871   xy
    b    32782 ab   35  820  962  374  180    b  32782   ab
               hq   22  813   53  919  840    b  32782   hq
               pq   18  682  449  660  226    b  32782   pq
               xy   73  471  578  267   29    b  32782   xy
         35865 ab   77  301  953  121  525    b  35865   ab
               hq   43  700  312  947  339    b  35865   hq
               pq   59  307  259  287  749    b  35865   pq
               xy   61  552  164  129   53    b  35865   xy
         35871 ab   68  113  678  805  226    b  35871   ab
               hq   88  533  732  359  891    b  35871   hq
               pq   74  416  279  407  387    b  35871   pq
               xy   7   848  776  779  719    b  35871   xy
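The question also allows repeated rows (replace=True) and a tolerance of about 5%. A hedged sketch of how the approach above could be adapted to a target sample size follows. It is not part of the original answer and assumes pandas 1.1 or later for `DataFrameGroupBy.sample`; the names `target`, `n_groups`, and `spread` are illustrative.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)  # arbitrary seed, for reproducibility only
df1 = pd.DataFrame(rng.integers(0, 1000, size=(100, 4)), columns=list("ABCD"))
df1["cat1"] = rng.choice(["a", "b"], len(df1))
df1["cat2"] = rng.choice(["32782", "35871", "35865"], len(df1))
df1["cat3"] = rng.choice(["pq", "xy", "ab", "hq"], len(df1))

keys = ["cat1", "cat2", "cat3"]
target = 200                          # desired sample size, as in the question
n_groups = df1.groupby(keys).ngroups  # combinations that actually occur
n = target // n_groups                # rows to draw per combination

# replace=True lets a combination with fewer than n rows repeat its rows
sample = df1.groupby(keys).sample(n=n, replace=True, random_state=123)

# relative spread of the category counts in each column: (max - min) / mean
spread = {}
for col in keys:
    counts = sample[col].value_counts()
    spread[col] = (counts.max() - counts.min()) / counts.mean()
```

If every combination occurs, the counts within each column are exactly equal and each spread is 0. A missing combination shows up as a nonzero spread, which can be compared with the 5% tolerance. The sample has n * n_groups rows, which can fall slightly below the target.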