python, pandas, group-by, shuffle

Identical shuffling of pandas dataframe column after groupby


I have a dataframe with the following structure:

ID  col2    col3
1   101001  a
1   101001  b
1   101001  c
1   101002  a
1   101002  b
1   101002  c
1   101003  a
1   101003  b
1   101003  c
1   101004  a
1   101004  b
1   101004  c
2   101001  a
2   101001  b
2   101001  d
2   101002  a
2   101002  b
2   101002  d
2   101003  a
2   101003  b
2   101003  d
2   101004  a
2   101004  b
2   101004  d
3   101001  b
3   101001  c
3   101001  d
3   101002  b
3   101002  c
3   101002  d
3   101003  b
3   101003  c
3   101003  d
3   101004  b
3   101004  c
3   101004  d

I need to group by column ID, reshuffle the entries in column col3 within each ID, and write the results to a new column col4. In addition, the shuffling must be identical for a given ID regardless of the entry in col2; e.g., for ID=1 the entry a in col3 should always become b after reshuffling:

ID  col2    col3    col4
1   101001  a   b
1   101001  b   a
1   101001  c   c
1   101002  a   b
1   101002  b   a
1   101002  c   c
1   101003  a   b
1   101003  b   a
1   101003  c   c
1   101004  a   b
1   101004  b   a
1   101004  c   c
2   101001  a   d
2   101001  b   a
2   101001  d   b
2   101002  a   d
2   101002  b   a
2   101002  d   b
2   101003  a   d
2   101003  b   a
2   101003  d   b
2   101004  a   d
2   101004  b   a
2   101004  d   b
3   101001  b   b
3   101001  c   d
3   101001  d   c
3   101002  b   b
3   101002  c   d
3   101002  d   c
3   101003  b   b
3   101003  c   d
3   101003  d   c
3   101004  b   b
3   101004  c   d
3   101004  d   c

Following the answers to Shuffle column in panda dataframe with groupby, I tried df['col4'] = df.groupby('ID')['col3'].transform(np.random.permutation), but it does not work in my case: it shuffles all rows of an ID group, and because each col3 value is repeated once per col2 entry, the same col3 value ends up with different col4 values within a given ID.
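To make the failure concrete, here is a minimal sketch that rebuilds the sample frame above and states the requirement as a checkable property: within each ID, every col3 value must map to exactly one col4 value, and the col4 values must be a reordering of the col3 values. The helper is_identical_shuffle is hypothetical, and the seeded np.random.default_rng(0) generator stands in for np.random.permutation only so the run is repeatable. Because the transform permutes all twelve rows of each ID rather than its three distinct values, the check reports False except by a very unlikely coincidence.

```python
import numpy as np
import pandas as pd

# Rebuild the sample frame from the question: each ID has three col3 values,
# repeated once for each of the four col2 codes 101001..101004.
col3_values = {1: ["a", "b", "c"], 2: ["a", "b", "d"], 3: ["b", "c", "d"]}
df = pd.DataFrame(
    [(i, code, v)
     for i, vs in col3_values.items()
     for code in range(101001, 101005)
     for v in vs],
    columns=["ID", "col2", "col3"],
)


def is_identical_shuffle(frame, src="col3", dst="col4"):
    """Hypothetical helper: True if, within every ID, each src value maps to
    exactly one dst value and the dst values are a reordering of the src values."""
    for _, group in frame.groupby("ID"):
        pairs = group[[src, dst]].drop_duplicates()
        if pairs[src].duplicated().any():  # one src value, several dst values
            return False
        if set(pairs[dst]) != set(pairs[src]):  # not a permutation of the values
            return False
    return True


# Row-level shuffle from the linked question; seeded only for repeatability.
rng = np.random.default_rng(0)
df["col4"] = df.groupby("ID")["col3"].transform(rng.permutation)
print("identical shuffle per ID:", is_identical_shuffle(df))
```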


Solution

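    from random import sample

    import pandas as pd

    def f(ser):
        # distinct values of this ID group, in order of first appearance
        elements = list(ser.unique())
        # one random permutation of those values
        replacements = sample(elements, len(elements))
        # every occurrence of a value gets the same replacement,
        # so the mapping is identical across all rows of the group
        return ser.replace(elements, replacements)

    df['col3'] = df.groupby('ID')['col2'].transform(f)

The snippet shuffles col2 into col3 to match the example below. For the question's frame, swap the column names: group col3 by ID and write the result to col4.

Example:

    df = pd.DataFrame({'ID': [1, 1, 1, 1, 1, 2, 2, 2, 2],
                       'col2': ['a', 'b', 'a', 'c', 'b', 'a', 'f', 'a', 'f']})

Result (one possible outcome, since the permutation is random):

       ID col2 col3
    0   1    a    a
    1   1    b    c
    2   1    a    a
    3   1    c    b
    4   1    b    c
    5   2    a    f
    6   2    f    a
    7   2    a    f
    8   2    f    a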
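The same idea can also be written with an explicit lookup table. This is a sketch of a variant, not the answerer's code: for each ID it draws one permutation of the distinct col3 values with a seeded NumPy Generator, stores it in a dict, and maps every row through it with Series.map, writing the result to col4 as the question asks. The helper name shuffle_mapping and the seed are illustrative assumptions.

```python
import numpy as np
import pandas as pd

# Sample frame from the question (ID, col2, col3).
col3_values = {1: ["a", "b", "c"], 2: ["a", "b", "d"], 3: ["b", "c", "d"]}
df = pd.DataFrame(
    [(i, code, v)
     for i, vs in col3_values.items()
     for code in range(101001, 101005)
     for v in vs],
    columns=["ID", "col2", "col3"],
)

rng = np.random.default_rng(42)  # assumed seed, only for repeatability


def shuffle_mapping(ser):
    """Hypothetical helper: permute the distinct values of one ID group once
    and send every row through that single permutation."""
    uniques = ser.unique()  # order of first appearance
    mapping = dict(zip(uniques, rng.permutation(uniques)))
    return ser.map(mapping)


df["col4"] = df.groupby("ID")["col3"].transform(shuffle_mapping)
print(df.head(6))
```

Series.unique() returns values in order of first appearance, so a fixed seed reproduces the same shuffle across runs; the iteration order of a Python set of strings can differ between interpreter runs because string hashing is randomized.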