Tags: python, pandas, dataframe, group-by

How do I group by and aggregate a dataframe when the columns to be aggregated are discovered dynamically?


I have a dataframe with a key column, some value columns and some timestamp columns. For some keys, there may be multiple rows with differing values in the value and timestamp columns.

I want to find the keys that have multiple rows and the columns that have distinct values for a particular key, and then get an aggregate where, for each key, the columns with differing values are summed or averaged and max or min values are chosen for the timestamp columns. Sample data is below:

import pandas as pd

data = [['key1', 10, 10, 10, pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-09'), 'A'],
        ['key1', 10, 20, 10, pd.Timestamp('2024-05-11'), pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-06')],
        ['key1', 10, 30, 10, pd.Timestamp('2024-05-11'), pd.Timestamp('2024-05-08'), pd.Timestamp('2024-05-12')],
        ['key2', 10, 10, 10, pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-09')],
        ['key2', 12, 10, 10, pd.Timestamp('2024-05-13'), pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-09')],
        ['key3', 14, 11, 17, pd.Timestamp('2024-06-09'), pd.Timestamp('2024-05-04'), pd.Timestamp('2024-05-01')],
        ['key4', 10, 10, 12, pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-11'), pd.Timestamp('2024-05-29')],
        ['key5', 10, 10, 10, pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-11')],
        ['key5', 10, 10, 10, pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-11')],
        ['key5', 12, 11, 10, pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-11')],
        ['key5', 10, 11, 10, pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-09'), pd.Timestamp('2024-05-11')]
       ]
columns = ['Key', 'Value1', 'Value2', 'Value3', 'Timestamp1', 'Timestamp2', 'Timestamp3', 'ID']
# rows shorter than the columns list are padded with NaN by pandas,
# so 'ID' is NaN everywhere except the first row
sample_df = pd.DataFrame(data, columns=columns)

The actual file has millions of rows and hundreds of columns.

I want to get an output which looks something like this:

Key   CountOfrows  Value1  Value2  Timestamp1    Timestamp2    Timestamp3    ID
key1  3            10      60      '2024-05-09'  '2024-05-09'  '2024-05-12'  'A'
key2  2            22      10      '2024-05-09'  '2024-05-09'  '2024-05-09'  'B'
key5  4            42      42      '2024-05-09'  '2024-05-09'  '2024-05-11'  'B'

Here the values are aggregated by averaging if they are the same for a column and by summing if they differ. The timestamps are aggregated with min for Timestamp1 and max for Timestamp2 and Timestamp3. For the ID column, the rows are sorted by Value1 and the ID corresponding to the max Value1 is taken.

Since Value3 has the same value in all rows of every key, it is not included in the final table.

I have been able to get midway: I have the keys together with only those columns whose values change.

# keep only the keys that appear in more than one row
multiple_rows_sample = sample_df.groupby(['Key']).size().reset_index(name='counts')
multiple_rows_sample = multiple_rows_sample[multiple_rows_sample['counts'] > 1]

mult_val_cols_sample = pd.DataFrame()

for index, row in multiple_rows_sample.iterrows():
    # all rows belonging to this key
    joined_slice = sample_df[sample_df['Key'] == row['Key']]
    count_slice = row.to_frame().transpose().reset_index(drop=True)
    # 'key' is a dummy column used to cross-join the two frames
    count_slice['key'] = 1
    diff_cols = cols_having_unique(joined_slice)
    diff_cols['key'] = 1
    output_df = pd.merge(count_slice, diff_cols, how='outer')
    output_df = output_df.drop('key', axis=1)
    mult_val_cols_sample = pd.concat([mult_val_cols_sample, output_df], ignore_index=True)
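
Here cols_having_unique is a small helper that keeps only the columns with more than one distinct value in the slice. Its exact implementation isn't important; a minimal version could look like this:

def cols_having_unique(df_slice):
    # keep only the columns that have more than one distinct value
    varying = df_slice.columns[df_slice.nunique() > 1]
    # return a one-row frame so it can be cross-joined with count_slice
    return df_slice[varying].head(1).reset_index(drop=True)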

The mult_val_cols_sample table has the key column and only those columns where the values have changed for at least one of the keys. Now how do I run a groupby on these columns when I don't know their names beforehand?

Any help is appreciated.


Solution

  • You can use a custom groupby.agg and post-process the output:

    • to filter out the rows/columns
    • to replace the index (from idxmax) by the actual ID
    g = sample_df.groupby('Key', as_index=False)
    
    def avg_or_sum(vals):
        # all values identical -> return the common value (mean);
        # otherwise return the sum
        if vals.nunique() == 1:
            return vals.mean()
        else:
            return vals.sum()
    
    out = (g
           # aggregate with custom functions
           .agg(**{'CountOfrows': ('Key', 'size'),
                   'Value1': ('Value1', avg_or_sum),
                   'Value2': ('Value2', avg_or_sum),
                   'Value3': ('Value3', avg_or_sum),
                   'Timestamp1': ('Timestamp1', 'min'),
                   'Timestamp2': ('Timestamp2', 'max'),
                   'Timestamp3': ('Timestamp3', 'max'),
                   'ID': ('Value1', 'idxmax') # this will need to be post-processed
                  })
           # only keep the rows with more than 1 item
           .query('CountOfrows > 1')
           # filter out the columns with all identical values within all groups
           .loc[:, lambda x: g.nunique().ne(1).any()
                              .reindex(x.columns, fill_value=True)]
           # replace the index with the actual ID
           .assign(ID=lambda d: d['ID'].map(sample_df['ID']))
          )
    

    Output:

        Key  CountOfrows  Value1  Value2 Timestamp1 Timestamp2 Timestamp3   ID
    0  key1            3    10.0    60.0 2024-05-09 2024-05-09 2024-05-12    A
    1  key2            2    22.0    10.0 2024-05-09 2024-05-09 2024-05-09  NaN
    4  key5            4    42.0    42.0 2024-05-09 2024-05-09 2024-05-11  NaN
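
    To make the ID post-processing concrete: idxmax returns the row label of the maximum Value1 in each group, not the ID itself, which is why the final assign maps it back through sample_df['ID']:

    # idxmax yields the row label of the max Value1 per group...
    idx = sample_df.groupby('Key')['Value1'].idxmax()
    # ...and mapping those labels through the ID column recovers the IDs
    idx.map(sample_df['ID'])   # key1 -> 'A', key2/key5 -> NaN in this sample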
    

    Note that if you have hundreds of columns you will need to generate the dictionary passed to agg programmatically. Since the exact logic is unspecified, here is a generic example:

    # first define the new/fixed columns
    dic = {'CountOfrows': ('Key', 'size'),
           'ID': ('Value1', 'idxmax')
          }
    
    # then programmatically define the repeated columns
    # for example by matching by name
    for col in sample_df.columns:
        if col.startswith('Value'):
            dic[col] = (col, avg_or_sum)
        elif col.startswith('Timestamp'):
            # use max by default;
            # we'll change it for selected cols later
            dic[col] = (col, 'max')
    
    # finally, fine-tune some specific changes
    dic['Timestamp1'] = ('Timestamp1', 'min')
    

    Then in the above code, use:

    .agg(**dic)
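
    If the name prefixes are not reliable in your real data, the same dictionary can also be built from the dtypes. A sketch, assuming numeric columns should use avg_or_sum and datetime columns max by default:

    dic = {'CountOfrows': ('Key', 'size'),
           'ID': ('Value1', 'idxmax')
          }

    for col in sample_df.columns:
        if col == 'Key':
            continue
        if pd.api.types.is_numeric_dtype(sample_df[col]):
            dic[col] = (col, avg_or_sum)
        elif pd.api.types.is_datetime64_any_dtype(sample_df[col]):
            dic[col] = (col, 'max')

    # fine-tune specific columns as before
    dic['Timestamp1'] = ('Timestamp1', 'min')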