Data:
import numpy as np
import pandas as pd

df1 = pd.DataFrame(np.random.randint(0, 1000, size=(100, 4)),
                   columns=list('ABCD'))
df1["cat1"] = np.random.choice(['a', 'b'], len(df1))
df1["cat2"] = np.random.choice(['32782', '35871', '35865'], len(df1))
df1["cat3"] = np.random.choice(['pq', 'xy', 'ab', 'hq'], len(df1))
I want to take a sample of this dataset (e.g., 200 rows) that contains the maximum possible equal number of rows from each category of the three category columns.
We can validate it like this:
counts1 = sample['cat1'].value_counts()
assert counts1['a'] == counts1['b']
counts2 = sample['cat2'].value_counts()
assert counts2['32782'] == counts2['35871'] == counts2['35865']
counts3 = sample['cat3'].value_counts()
assert counts3['pq'] == counts3['xy'] == counts3['ab'] == counts3['hq']
In my real dataset there are 100M rows, and I want to take 200K rows.
I tried
sample = df1.sample(n=200000, replace=True, random_state=123)
to take the 200K rows, but I'm not sure how to use random sampling (df1.sample) to get the best possible equal number of rows from each group.
Having exactly equal counts is not a strict condition; a +/- 5% error is also fine.
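The asserts above demand exact equality, while a +/- 5% error is acceptable here. A minimal sketch of a tolerance-based check, assuming "5%" means each category's count may deviate by at most 5% from the mean count of its column; `is_balanced` is a hypothetical helper, not a pandas function:

```python
import pandas as pd


def is_balanced(sample, reference, columns, tol=0.05):
    """Hypothetical helper: True if, in every column of `columns`, each
    category's row count in `sample` lies within `tol` (relative) of that
    column's mean count. Categories of `reference` absent from `sample`
    count as 0, so a category that was never sampled makes the check fail."""
    for col in columns:
        counts = (sample[col].value_counts()
                  .reindex(reference[col].unique(), fill_value=0))
        mean = counts.mean()
        # Relative deviation of each category's count from the column mean
        if mean == 0 or ((counts - mean).abs() / mean > tol).any():
            return False
    return True


balanced = pd.DataFrame({"cat1": ["a", "b"] * 6, "cat2": ["x", "y", "z"] * 4})
skewed = pd.DataFrame({"cat1": ["a"] * 9 + ["b"] * 3, "cat2": ["x", "y", "z"] * 4})
print(is_balanced(balanced, balanced, ["cat1", "cat2"]))  # True
print(is_balanced(skewed, skewed, ["cat1", "cat2"]))      # False
```

In practice this would be called as `is_balanced(sample, df1, ['cat1', 'cat2', 'cat3'])`.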
Update:
replace=True
is used so that rows can be repeated when there are too few rows to draw n without replacement.
You can use DataFrame.groupby() and then apply sample:
n = 1
# Draw n rows from every (cat1, cat2, cat3) combination
df1.groupby(['cat1', 'cat2', 'cat3']).apply(lambda s: s.sample(n))
Append .reset_index(drop=True) to discard the resulting MultiIndex if you wish.
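On pandas 1.1 or later, `DataFrameGroupBy.sample` draws n rows per group directly, without `apply` and without the extra index levels. A sketch on a hypothetical toy frame in which every combination occurs 5 times; it also shows that equal counts per combination give equal counts per category within each column:

```python
from itertools import product

import numpy as np
import pandas as pd

cats = ["cat1", "cat2", "cat3"]

# Hypothetical toy frame: each of the 2 * 3 * 4 = 24 combinations appears 5 times
combos = list(product(["a", "b"], ["32782", "35871", "35865"], ["pq", "xy", "ab", "hq"]))
df = pd.DataFrame(combos * 5, columns=cats)
df["A"] = np.random.default_rng(0).integers(0, 1000, size=len(df))

n = 2
# n rows from every observed combination, then a fresh 0..k-1 index
sample = df.groupby(cats).sample(n=n, random_state=123).reset_index(drop=True)

print(len(sample))                                          # 48
print(sample["cat1"].value_counts().sort_index().tolist())  # [24, 24]
print(sample["cat3"].value_counts().sort_index().tolist())  # [12, 12, 12, 12]
```

Each of 'a' and 'b' appears in 12 of the 24 combinations, so each gets 12 * n rows; the same reasoning gives 8 * n rows per cat2 value and 6 * n per cat3 value.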
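Only n=1 works with my dummy example because one combination of categories occurred only once. With a larger dataset, larger values of n are likely possible. Because every combination contributes exactly n rows, each category within a column ends up with the same count (provided all 24 combinations are present, 'a' and 'b' each get 12n rows), which is what the asserts in the question check.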
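To aim for a target size such as 200K rows instead of a fixed n, one option is to divide the target by the number of observed combinations and, following the update in the question, sample with replacement only in groups that are too small. A sketch under those assumptions; `balanced_sample` is a hypothetical name, and passing a `numpy.random.Generator` as `random_state` requires pandas 1.4 or later:

```python
import math

import numpy as np
import pandas as pd


def balanced_sample(df, cols, target, seed=None):
    """Hypothetical helper: draw roughly `target` rows, the same number from
    every observed combination of `cols`. A group smaller than that number is
    sampled with replacement, so its rows repeat."""
    rng = np.random.default_rng(seed)
    groups = df.groupby(cols)
    # ceil() overshoots `target` by fewer than `groups.ngroups` rows
    per_group = math.ceil(target / groups.ngroups)
    parts = [
        g.sample(n=per_group, replace=len(g) < per_group, random_state=rng)
        for _, g in groups
    ]
    return pd.concat(parts, ignore_index=True)


# Usage on dummy data shaped like the question's (200 rows from 100)
rng = np.random.default_rng(123)
df1 = pd.DataFrame(rng.integers(0, 1000, size=(100, 4)), columns=list("ABCD"))
df1["cat1"] = rng.choice(["a", "b"], len(df1))
df1["cat2"] = rng.choice(["32782", "35871", "35865"], len(df1))
df1["cat3"] = rng.choice(["pq", "xy", "ab", "hq"], len(df1))

sample = balanced_sample(df1, ["cat1", "cat2", "cat3"], target=200, seed=123)
print(len(sample))
```

With all 24 combinations present and a 200K target, per_group is ceil(200000 / 24) = 8334, giving 200,016 rows, well inside the 5% tolerance. If some combination never occurs, the per-column counts are no longer exactly equal.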
To find the maximum value of n, group by the three categories and count the occurrences of each combination (including combinations that never occur, which count as 0). The minimum of those counts is the maximum value of n:
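from itertools import product

# Every possible (cat1, cat2, cat3) combination, including ones absent from df1
combs = pd.DataFrame(list(product(df1['cat1'].unique(), df1['cat2'].unique(), df1['cat3'].unique())),
                     columns=['cat1', 'cat2', 'cat3'])
# Observed count per combination; the count column is named 0
counts = df1.groupby(['cat1', 'cat2', 'cat3']).size().reset_index()
# Right merge keeps all combinations; missing ones get NaN, replaced by 0
result = counts.merge(combs, how='right').fillna(0)
n_max = int(result[0].min())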
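The same number can be computed without building `combs` and merging, by reindexing the group sizes onto the full Cartesian product of categories with `pd.MultiIndex.from_product`. A sketch, written for two or more grouping columns; `max_equal_n` is a hypothetical helper name:

```python
import numpy as np
import pandas as pd


def max_equal_n(df, cols):
    """Hypothetical helper: smallest row count over ALL combinations of the
    categories in `cols` (two or more columns); unseen combinations count as 0."""
    full = pd.MultiIndex.from_product([df[c].unique() for c in cols], names=cols)
    counts = df.groupby(cols).size().reindex(full, fill_value=0)
    return int(counts.min())


rng = np.random.default_rng(123)
df1 = pd.DataFrame({
    "cat1": rng.choice(["a", "b"], 100),
    "cat2": rng.choice(["32782", "35871", "35865"], 100),
    "cat3": rng.choice(["pq", "xy", "ab", "hq"], 100),
})
print(max_equal_n(df1, ["cat1", "cat2", "cat3"]))
```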
You can verify that this is indeed the maximum by plugging n = n_max + 1 into the groupby/apply code above: as long as every combination is present, it raises a ValueError.
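A small sketch of that check on a hypothetical frame, using `groupby(...).sample` (pandas 1.1+) in place of `apply`. Note the caveat: if some combination never occurs, n_max from the merge is 0, and n = 1 still runs without error because groupby only sees observed groups.

```python
import pandas as pd

# Hypothetical frame: every (p, q) pair occurs 3 times, except ("a", "x") twice
df = pd.DataFrame([(p, q) for p in "ab" for q in "xy"] * 3, columns=["p", "q"]).iloc[1:]

n_max = int(df.groupby(["p", "q"]).size().min())
print(n_max)  # 2

df.groupby(["p", "q"]).sample(n=n_max, random_state=0)  # works

try:
    df.groupby(["p", "q"]).sample(n=n_max + 1, random_state=0)
    raised = False
except ValueError:
    raised = True  # the ("a", "x") group has only 2 rows
print(raised)  # True
```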
Output:
                      A    B    C    D  cat1   cat2  cat3
cat1 cat2  cat3
a    32782 ab   60  369  281  970  277     a  32782    ab
           hq   8    94  933  560  622     a  32782    hq
           pq   65  369  356  120  533     a  32782    pq
           xy   3   227  267  664  161     a  32782    xy
     35865 ab   45  991  929  664  400     a  35865    ab
           hq   10   52  337  303  804     a  35865    hq
           pq   2   639  557  828   90     a  35865    pq
           xy   57  823  882   11  574     a  35865    xy
     35871 ab   98  900  331  527  966     a  35871    ab
           hq   70  132  394  235  177     a  35871    hq
           pq   9   660  411  342  752     a  35871    pq
           xy   79  617  780  555  649     a  35871    xy
b    32782 ab   35  820  962  374  180     b  32782    ab
           hq   22  813   53  919  840     b  32782    hq
           pq   18  682  449  660  226     b  32782    pq
           xy   73  471  578  267   29     b  32782    xy
     35865 ab   77  301  953  121  525     b  35865    ab
           hq   43  700  312  947  339     b  35865    hq
           pq   59  307  259  287  749     b  35865    pq
           xy   61  552  164  129   53     b  35865    xy
     35871 ab   68  113  678  805  226     b  35871    ab
           hq   88  533  732  359  891     b  35871    hq
           pq   74  416  279  407  387     b  35871    pq
           xy   7   848  776  779  719     b  35871    xy