
Recursive groupby with quantiles


I have a DataFrame of floats:

    a           b           c           d           e
0   0.085649    0.236811    0.801274    0.582162    0.094129
1   0.433127    0.479051    0.159739    0.734577    0.113672
2   0.391228    0.516740    0.430628    0.586799    0.737838
3   0.956267    0.284201    0.648547    0.696216    0.292721
4   0.001490    0.973460    0.298401    0.313986    0.891711
5   0.585163    0.471310    0.773277    0.030346    0.706965
6   0.374244    0.090853    0.660500    0.931464    0.207191
7   0.630090    0.298163    0.741757    0.722165    0.218715

I can divide it into quantiles for a single column like so:

import numpy as np
import pandas as pd

def groupby_quantiles(df, column, groups: int):
    quantiles = df[column].quantile(np.linspace(0, 1, groups + 1))
    bins = pd.cut(df[column], quantiles, include_lowest=True)
    return df.groupby(bins)

>>> df.pipe(groupby_quantiles, "a", 2).apply(lambda x: print(x))
          a         b         c         d         e
0  0.085649  0.236811  0.801274  0.582162  0.094129
2  0.391228  0.516740  0.430628  0.586799  0.737838
4  0.001490  0.973460  0.298401  0.313986  0.891711
6  0.374244  0.090853  0.660500  0.931464  0.207191
          a         b         c         d         e
1  0.433127  0.479051  0.159739  0.734577  0.113672
3  0.956267  0.284201  0.648547  0.696216  0.292721
5  0.585163  0.471310  0.773277  0.030346  0.706965
7  0.630090  0.298163  0.741757  0.722165  0.218715
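Calling print inside .apply works, but iterating over the groupby object directly is a more conventional way to inspect the groups. The sketch below retypes the sample data from the printout above and reuses the question's groupby_quantiles with one labeled change: observed=True is passed to groupby, so only non-empty bins are produced and newer pandas does not warn about the deprecated default. It also checks that pd.qcut puts every row in the same bin as quantile followed by cut.

```python
import numpy as np
import pandas as pd

# Sample data retyped from the printout in the question.
df = pd.DataFrame(
    {
        "a": [0.085649, 0.433127, 0.391228, 0.956267, 0.001490, 0.585163, 0.374244, 0.630090],
        "b": [0.236811, 0.479051, 0.516740, 0.284201, 0.973460, 0.471310, 0.090853, 0.298163],
        "c": [0.801274, 0.159739, 0.430628, 0.648547, 0.298401, 0.773277, 0.660500, 0.741757],
        "d": [0.582162, 0.734577, 0.586799, 0.696216, 0.313986, 0.030346, 0.931464, 0.722165],
        "e": [0.094129, 0.113672, 0.737838, 0.292721, 0.891711, 0.706965, 0.207191, 0.218715],
    }
)


def groupby_quantiles(df, column, groups: int):
    # Bin edges at evenly spaced quantiles (min, ..., max) of `column`.
    quantiles = df[column].quantile(np.linspace(0, 1, groups + 1))
    bins = pd.cut(df[column], quantiles, include_lowest=True)
    # Change from the question: observed=True keeps only non-empty bins.
    return df.groupby(bins, observed=True)


# Each iteration yields the bin's Interval and the rows that fall in it.
for interval, group in groupby_quantiles(df, "a", 2):
    print(interval)
    print(group)

# pd.qcut computes the same quantile edges and assigns the same bin codes.
edges = df["a"].quantile(np.linspace(0, 1, 3))
same = pd.cut(df["a"], edges, include_lowest=True).cat.codes.equals(
    pd.qcut(df["a"], 2).cat.codes
)
print(same)  # True
```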

Now I want to repeat the same operation on each of these groups using the next column. The code quickly becomes unwieldy:

>>> (
        df
        .pipe(groupby_quantiles, "a", 2)
        .apply(
            lambda df_group: (
                df_group
                .pipe(groupby_quantiles, "b", 2)
                .apply(lambda x: print(x))
            )
        )
    )
          a         b         c         d         e
0  0.085649  0.236811  0.801274  0.582162  0.094129
6  0.374244  0.090853  0.660500  0.931464  0.207191
          a        b         c         d         e
2  0.391228  0.51674  0.430628  0.586799  0.737838
4  0.001490  0.97346  0.298401  0.313986  0.891711
          a         b         c         d         e
3  0.956267  0.284201  0.648547  0.696216  0.292721
7  0.630090  0.298163  0.741757  0.722165  0.218715
          a         b         c         d         e
1  0.433127  0.479051  0.159739  0.734577  0.113672
5  0.585163  0.471310  0.773277  0.030346  0.706965
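The nested .apply calls amount to a loop inside a loop. Written with plain for loops, the structure is easier to see, and reducing each innermost group to its column means reproduces the desired result shown further down. This is a sketch that retypes the sample data and reuses the question's groupby_quantiles (again with observed=True added as a labeled change); each extra column would still add another level of nesting, which is what motivates a recursive version.

```python
import numpy as np
import pandas as pd

# Sample data retyped from the printout in the question.
df = pd.DataFrame(
    {
        "a": [0.085649, 0.433127, 0.391228, 0.956267, 0.001490, 0.585163, 0.374244, 0.630090],
        "b": [0.236811, 0.479051, 0.516740, 0.284201, 0.973460, 0.471310, 0.090853, 0.298163],
        "c": [0.801274, 0.159739, 0.430628, 0.648547, 0.298401, 0.773277, 0.660500, 0.741757],
        "d": [0.582162, 0.734577, 0.586799, 0.696216, 0.313986, 0.030346, 0.931464, 0.722165],
        "e": [0.094129, 0.113672, 0.737838, 0.292721, 0.891711, 0.706965, 0.207191, 0.218715],
    }
)


def groupby_quantiles(df, column, groups: int):
    quantiles = df[column].quantile(np.linspace(0, 1, groups + 1))
    bins = pd.cut(df[column], quantiles, include_lowest=True)
    return df.groupby(bins, observed=True)  # observed=True added (not in the question)


# Two levels as plain nested loops: split on "a", split each half on "b",
# then reduce every innermost group to its column means.
rows = []
for _, df_a in groupby_quantiles(df, "a", 2):
    for _, df_ab in groupby_quantiles(df_a, "b", 2):
        rows.append(df_ab.mean())

result = pd.DataFrame(rows).reset_index(drop=True)
print(result)
```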

My goal is to repeat this operation for as many columns as I want, then aggregate the groups at the end. Here's what the final function could look like, along with the desired result when aggregating with the mean:

>>> groupby_quantiles(df, columns=["a", "b"], groups=[2, 2], agg="mean")
    a           b           c           d           e
0   0.229947    0.163832    0.730887    0.756813    0.150660
1   0.196359    0.745100    0.364515    0.450392    0.814774
2   0.793179    0.291182    0.695152    0.709190    0.255718
3   0.509145    0.475180    0.466508    0.382462    0.410319

Any ideas on how to achieve this?


Solution

  • Here is one way. First, computing the quantiles and then calling cut can be replaced by a single pd.qcut call. Then apply the grouping recursively, one column per level of recursion.

    import pandas as pd

    def groupby_quantiles(df, cols, grs, agg_func):
        # to store all the results
        _dfs = []

        # recursive function
        def recurse(_df, depth):
            col = cols[depth]
            gr = grs[depth]
            # iterate over the quantile groups of the current column;
            # observed=True skips empty bins
            for _, _dfgr in _df.groupby(pd.qcut(_df[col], gr), observed=True):
                if depth != -1:
                    recurse(_dfgr, depth + 1)  # recurse if not at the last column
                else:
                    _dfs.append(_dfgr.agg(agg_func))  # else perform the aggregation

        # a negative depth makes it easier to access the right column and quantile
        depth = -len(cols)
        recurse(df, depth)  # start the recursion
        return pd.concat(_dfs, axis=1).T  # concat the results and transpose
    
    print(groupby_quantiles(df, cols = ['a','b'], grs = [2,2], agg_func='mean'))
    #           a         b         c         d         e
    # 0  0.229946  0.163832  0.730887  0.756813  0.150660
    # 1  0.196359  0.745100  0.364515  0.450392  0.814774
    # 2  0.793179  0.291182  0.695152  0.709190  0.255718
    # 3  0.509145  0.475181  0.466508  0.382462  0.410318
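If recursion is not essential, a non-recursive variant is also possible. This is a different technique from the answer above, sketched under the assumption that no group contains tied values that would make pd.qcut raise on duplicate bin edges (its default is duplicates="raise"). Each row's bin code is computed level by level with groupby(...).transform and pd.qcut(..., labels=False), and the frame is then grouped once on all the codes. The name groupby_quantiles_flat is hypothetical, and the sample data is retyped from the question.

```python
import pandas as pd

# Sample data retyped from the printout in the question.
df = pd.DataFrame(
    {
        "a": [0.085649, 0.433127, 0.391228, 0.956267, 0.001490, 0.585163, 0.374244, 0.630090],
        "b": [0.236811, 0.479051, 0.516740, 0.284201, 0.973460, 0.471310, 0.090853, 0.298163],
        "c": [0.801274, 0.159739, 0.430628, 0.648547, 0.298401, 0.773277, 0.660500, 0.741757],
        "d": [0.582162, 0.734577, 0.586799, 0.696216, 0.313986, 0.030346, 0.931464, 0.722165],
        "e": [0.094129, 0.113672, 0.737838, 0.292721, 0.891711, 0.706965, 0.207191, 0.218715],
    }
)


def groupby_quantiles_flat(df, cols, grs, agg_func):
    # Hypothetical helper: label each row with its quantile bin at every
    # level, then group once on all of those labels.
    keys = []
    for col, n in zip(cols, grs):
        if keys:
            # Bin `col` separately inside each group formed by the earlier levels.
            codes = df.groupby(keys)[col].transform(
                lambda s: pd.qcut(s, n, labels=False)
            )
        else:
            codes = pd.qcut(df[col], n, labels=False)
        keys.append(codes.rename(f"{col}_bin"))
    # The resulting MultiIndex records which bin each aggregated row came from.
    return df.groupby(keys).agg(agg_func)


out = groupby_quantiles_flat(df, ["a", "b"], [2, 2], "mean")
print(out)
```

Unlike the recursive version, the result keeps an (a_bin, b_bin) MultiIndex identifying each group; calling reset_index(drop=True) gives the plain 0 to 3 index used in the output above.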