
Sort a dataframe by a `label` column, shuffle within each `label`, preserve order per `group`


Given the following dataframe:

import pandas as pd

df = pd.DataFrame(data={'value': ['all', 'moon', 'less', 'cat', 'pen', 'dark', 'pile'],
                        'label': [0, 1, 1, 0, 1, 0, 0],
                        'group': ['A', 'B', 'B', 'B', 'A', 'B', 'A']})

output:

    value      label  group
0   'all'      0      'A'
1   'moon'     1      'B'
2   'less'     1      'B'
3   'cat'      0      'B'
4   'pen'      1      'A'
5   'dark'     0      'B'
6   'pile'     0      'A'

I'd like to generate a new dataframe that meets the following conditions:

  1. Rows are sorted by label
  2. Within each label, rows are shuffled
  3. But rows that share the same label and group stay ordered by value

For example, here is one possible outcome:

    value      label  group
0   'all'      0      'A'
3   'cat'      0      'B'
5   'dark'     0      'B'
6   'pile'     0      'A'
2   'less'     1      'B'
4   'pen'      1      'A'
1   'moon'     1      'B'

Regarding condition 3: 'all' and 'pile' have the same label and come from the same group, so 'pile' must come after 'all'. No shuffle should ever place 'pile' before 'all'.

Or another one, with a different shuffle:

    value      label  group
3   'cat'      0      'B'
0   'all'      0      'A'
6   'pile'     0      'A'
5   'dark'     0      'B'
4   'pen'      1      'A'
2   'less'     1      'B'
1   'moon'     1      'B'

Any thoughts on a clean way to achieve this?
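Condition 2 is about randomness and cannot be judged from a single result, but conditions 1 and 3 can be checked mechanically. A minimal sketch, using a hypothetical helper `is_valid` (not part of pandas), applied to the two outcomes above (written as row-label orders) and to one that puts 'pile' before 'all':

```python
import pandas as pd


def is_valid(result: pd.DataFrame) -> bool:
    """Hypothetical helper: True if `result` satisfies conditions 1 and 3."""
    # Condition 1: labels never decrease from one row to the next
    if not result['label'].is_monotonic_increasing:
        return False
    # Condition 3: inside each (label, group) pair, values appear in ascending order
    return all(
        g['value'].is_monotonic_increasing
        for _, g in result.groupby(['label', 'group'], sort=False)
    )


df = pd.DataFrame(data={'value': ['all', 'moon', 'less', 'cat', 'pen', 'dark', 'pile'],
                        'label': [0, 1, 1, 0, 1, 0, 0],
                        'group': ['A', 'B', 'B', 'B', 'A', 'B', 'A']})

print(is_valid(df.loc[[0, 3, 5, 6, 2, 4, 1]]))  # first outcome: True
print(is_valid(df.loc[[3, 0, 6, 5, 4, 2, 1]]))  # second outcome: True
print(is_valid(df.loc[[6, 3, 5, 0, 2, 4, 1]]))  # 'pile' before 'all': False
```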


Solution

  • This is not trivial, but it can be done in three steps.

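    First, shuffle the whole dataframe with `sample(frac=1)`, then sort it by label. A stable sort keeps the shuffled order inside each label:

    # e.g. df.sample(frac=1, random_state=0) for a reproducible shuffle
    df2 = df.sample(frac=1).sort_values(by='label', kind='stable', ignore_index=True)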
    

    output:

        value  label group
    0  'pile'      0   'A'
    1   'cat'      0   'B'
    2   'all'      0   'A'
    3  'dark'      0   'B'
    4  'less'      1   'B'
    5  'moon'      1   'B'
    6   'pen'      1   'A'
    

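    Then compute the new row order: each (label, group) pair keeps the positions it already occupies in `df2`, but its rows are reassigned to those positions in order of `value`: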

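    idx = (df2.reset_index()                      # save the current position as an 'index' column
           .sort_values(by='value')               # order rows by value
           .groupby(['label', 'group'])['index']  # within each (label, group) pair...
           .transform(sorted)                     # ...hand out the pair's positions smallest-first
           .sort_values()                         # order rows by their new position
           .index                                 # keep only the row labels
          )
    # Index([2, 1, 0, 3, 4, 5, 6], dtype='int64')  (printed as Int64Index in pandas < 2.0)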
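Why `transform(sorted)` does the job: after sorting by value, each group receives its own positions back in ascending order, assigned to its rows in value order. The group keeps its slots (so the shuffle between groups survives) while its rows line up by value. A small sketch on a hypothetical toy frame (`toy`, `pos` and `new_pos` are illustrative names) that mimics the label-0 part of `df2` above:

```python
import pandas as pd

# Hypothetical stand-in for the label-0 rows of df2; 'pos' plays the role of
# the 'index' column created by reset_index() (the row's position in df2)
toy = pd.DataFrame({'value': ['pile', 'cat', 'all', 'dark'],
                    'group': ['A', 'B', 'A', 'B'],
                    'pos':   [0, 1, 2, 3]})

new_pos = (toy.sort_values(by='value')   # rows now ordered all, cat, dark, pile
              .groupby('group')['pos']
              .transform(sorted))        # group A owns {0, 2}, group B owns {1, 3}

# New position of each original row: pile -> 2, cat -> 1, all -> 0, dark -> 3
print(new_pos.sort_index().tolist())
# Rows read off in new-position order
print(toy.loc[new_pos.sort_values().index, 'value'].tolist())
```

The second print gives `['all', 'cat', 'pile', 'dark']`, the same label-0 order as the final output of the answer.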
    

    Finally, use this order to reindex `df2`:

    df2.loc[idx]
    

    output:

        value  label group
    2   'all'      0   'A'
    1   'cat'      0   'B'
    0  'pile'      0   'A'
    3  'dark'      0   'B'
    4  'less'      1   'B'
    5  'moon'      1   'B'
    6   'pen'      1   'A'
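The three steps can be wrapped into one reusable function. A minimal sketch, assuming a hypothetical name `shuffle_within_label` and an optional `seed` that is passed to the `random_state` parameter of `DataFrame.sample`; the final `reset_index(drop=True)` is an optional extra that renumbers the rows 0..n-1:

```python
import pandas as pd


def shuffle_within_label(df: pd.DataFrame, seed=None) -> pd.DataFrame:
    """Hypothetical helper wrapping the three steps of the answer.

    Rows are sorted by 'label', shuffled within each label, and rows that
    share a (label, group) pair stay in ascending 'value' order.
    """
    # Step 1: full shuffle, then a stable sort so the shuffle survives inside each label
    df2 = (df.sample(frac=1, random_state=seed)
             .sort_values(by='label', kind='stable', ignore_index=True))
    # Step 2: each (label, group) pair reuses its own positions, filled in value order
    idx = (df2.reset_index()
              .sort_values(by='value', kind='stable')
              .groupby(['label', 'group'])['index']
              .transform(sorted)
              .sort_values()
              .index)
    # Step 3: reorder df2 and renumber the rows
    return df2.loc[idx].reset_index(drop=True)


df = pd.DataFrame(data={'value': ['all', 'moon', 'less', 'cat', 'pen', 'dark', 'pile'],
                        'label': [0, 1, 1, 0, 1, 0, 0],
                        'group': ['A', 'B', 'B', 'B', 'A', 'B', 'A']})

for seed in range(3):
    print(shuffle_within_label(df, seed=seed)['value'].tolist())
```

Passing the same integer `seed` reproduces the same result; leaving it as `None` gives a fresh shuffle on every call.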