I have a DataFrame c containing a column a:
import numpy as np
import pandas as pd

a = np.random.randint(0, 10, size=(100))
c = pd.DataFrame(a, columns=['a'])
I want to make a random grouping of the rows of c such that there are 5 rows in each group and 1 row in each group with a < 3, so for example:
[1,2,3,2,10] <-- good group
[1,1,3,4,6] <-- good group
[2,4,7,3,7] <-- bad group
And if I run out of rows that meet this criterion (for example, I run out of rows with a < 3), then ignore the rest of the DataFrame.
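To make the rule concrete, here is a minimal sketch of such a check, assuming the intended criterion is "at least one row with a < 3" (the helper name has_small_value is mine, not part of the original question):
def has_small_value(values, threshold=3):
    # assumed rule: the group qualifies if at least one value is below the threshold
    return any(v < threshold for v in values)

has_small_value([1, 2, 3, 2, 10])  # True
has_small_value([4, 5, 7, 3, 7])   # False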
Currently I do this by creating a new column group_id, splitting c by condition, and then iteratively sampling from the two subsets until I run out of candidates:
c['group_id'] = None
c_w_small_a = c[c.a < 3].copy()
c_w_large_a = c[c.a >= 3].copy()
group_id = 0
while len(c_w_small_a) >= 1 and len(c_w_large_a) >= 4:
    small_sample = c_w_small_a.sample(1, replace=False)
    large_sample = c_w_large_a.sample(4, replace=False)
    c.loc[small_sample.index, 'group_id'] = group_id
    c.loc[large_sample.index, 'group_id'] = group_id
    # drop the sampled rows so they cannot be reused in later groups
    c_w_small_a = c_w_small_a.drop(small_sample.index)
    c_w_large_a = c_w_large_a.drop(large_sample.index)
    group_id += 1
c = c[c.group_id.notna()]  # filter rows without a group id
c_groups = c.groupby('group_id')
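To inspect the result, one can then iterate over the groups, for example (just an illustrative snippet, assuming the code above has already run):
for gid, grp in c_groups:
    # print each group's id and its values of a
    print(gid, grp['a'].tolist())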
The problem with this approach is that I can't generalize it to more complex conditions where the subsets overlap each other, such as: at most 2 rows with a > 2 and at least 1 row with a == 3.
I don't know how to code it in a way that maximizes the number of groups I can get from this grouping. For example, if rows with a == 3 are very limited, then I don't want the a > 2 condition to take the 3s, even though they satisfy its condition.
I am not sure, but I think the problem you are describing is NP-complete, so I suggest you use a heuristic to find a satisfying solution. For this purpose you can write a greedy heuristic that will look like this:
def is_satisfying(group):
    # at most 2 values greater than 2 and at least 1 value equal to 3
    if (np.sum(group > 2) > 2) or (np.sum(group == 3) < 1):
        return False
    else:
        return True
and then to construct a group, you can write something like:
group = []
while len(group) != 5:
    # np.append returns a new array, so the result must be re-assigned
    group = np.append(group, c['a'].sample(n=1))
    if not is_satisfying(group):
        group = group[:-1]  # drop the last value if it breaks the condition
and in order to mark elements that have already been added to groups, you can use some data structure that will enable you to filter the DataFrame before sampling.
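For example, here is a minimal sketch of that bookkeeping using a plain set of row indices (the names used and sample_unused are only illustrative, not a complete solution):
used = set()  # row indices that have already been placed in a group

def sample_unused(c, used, n=1):
    # filter the DataFrame down to rows not yet used, then sample from those
    available = c[~c.index.isin(used)]
    return available.sample(n=n)

row = sample_unused(c, used)
used.update(row.index)  # mark the sampled row as taken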