
min/max value of a column based on values of another column, grouped by and transformed in pandas


I'd like to know if I can do all this in one line, rather than multiple lines.

My dataframe:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({
        'ID': [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2],
        'A': [1, 2, 3, 10, np.nan, 5, 20, 6, 7, np.nan, np.nan, np.nan],
        'B': [0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0],
        'desired_output': [5, 5, 5, 5, 5, 5, 20, 20, 20, 20, 20, 20],
    })
    df

        ID     A  B  desired_output
    0    1   1.0  0               5
    1    1   2.0  1               5
    2    1   3.0  1               5
    3    1  10.0  0               5
    4    1   NaN  1               5
    5    1   5.0  1               5
    6    2  20.0  1              20
    7    2   6.0  1              20
    8    2   7.0  1              20
    9    2   NaN  0              20
    10   2   NaN  1              20
    11   2   NaN  0              20

I'm trying to find the maximum value of column A among the rows where B == 1, grouped by column ID, and to write the result straight back into the dataframe with transform, without an extra merge step.

Something like the following (but without the error!):

    df['desired_output'] = df.groupby('ID').A.where(df.B == 1).transform('max')  # this raises an error

The max function should ignore NaNs as well. Maybe I'm trying to do too much in one line, but I'm hoping there is an elegant way to write it.

EDIT: I can get very similar output by moving the where clause:

    df['desired_output'] = df.where(df.B == 1).groupby('ID').A.transform('max')  # this runs, but the output is not what I want

However, the output is not exactly what I want: desired_output should not contain any NaN unless all values of A are NaN where B == 1.
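A likely explanation for the NaNs, as a minimal sketch: DataFrame.where masks every column, including ID, so the rows where B != 1 lose their group key, and groupby skips NaN keys by default. The data is the example from the question; the names masked, attempt and nan_rows are only illustrative, and the behaviour shown assumes a recent pandas version.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'ID': [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2],
    'A': [1, 2, 3, 10, np.nan, 5, 20, 6, 7, np.nan, np.nan, np.nan],
    'B': [0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0],
})

# DataFrame.where masks every column, so ID itself becomes NaN where B != 1.
masked = df.where(df.B == 1)

# Rows whose group key is NaN belong to no group, so transform returns NaN for them.
attempt = masked.groupby('ID').A.transform('max')

# Index labels of the rows that ended up without a value.
nan_rows = attempt[attempt.isna()].index.tolist()
print(nan_rows)
```

Only the rows with B == 0 end up as NaN, which matches the symptom described above.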


Solution

  • Here is a way to do it:

    import pandas as pd
    import numpy as np
    
    df = pd.DataFrame({
        'ID': [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2],
        'A': [1, 2, 3, 10, np.nan, 5, 20, 6, 7, np.nan, np.nan, np.nan],
        'B': [0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0],
        'desired_output': [5, 5, 5, 5, 5, 5, 20, 20, 20, 20, 20, 20],
    })
    
    df['output'] = df[df.B == 1].groupby('ID').A.max()[df.ID].array
    
    df
    

    Result:

    
        ID     A  B  desired_output  output
    0    1   1.0  0               5     5.0
    1    1   2.0  1               5     5.0
    2    1   3.0  1               5     5.0
    3    1  10.0  0               5     5.0
    4    1   NaN  1               5     5.0
    5    1   5.0  1               5     5.0
    6    2  20.0  1              20    20.0
    7    2   6.0  1              20    20.0
    8    2   7.0  1              20    20.0
    9    2   NaN  0              20    20.0
    10   2   NaN  1              20    20.0
    11   2   NaN  0              20    20.0
    
    

    Decomposition:

    (
        df[df.B == 1]   # start by filtering on B
        .groupby('ID')  # group by ID
        .A.max()        # get the max of column A per ID (NaN is skipped)
        [df.ID]         # look up each row's ID, giving one value per row of df
        .array          # fetch the raw values, dropping the ID-based index
    )
    

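    Important note: because .array drops the index, the values are assigned to the new column by position, in the row order of df, so the result does not depend on the DataFrame's index and no reset_index() is needed. The lookup [df.ID] does require every ID to have at least one row with B == 1; in recent pandas versions an ID without one raises a KeyError.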
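As a possible alternative, here is a minimal sketch that stays closer to the transform-based one-liner from the question. It assumes the same example data, and masked_a is only an illustrative name. The idea is to mask column A alone with Series.where and group by the untouched ID column, so the result is aligned on the index rather than on position.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'ID': [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2],
    'A': [1, 2, 3, 10, np.nan, 5, 20, 6, 7, np.nan, np.nan, np.nan],
    'B': [0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0],
})

# Mask only column A (not the whole frame), so ID stays intact for grouping.
masked_a = df['A'].where(df['B'] == 1)

# Group the masked values by ID and broadcast each group's max to every row.
# max skips NaN, so masked rows and missing values are ignored.
df['output'] = masked_a.groupby(df['ID']).transform('max')

print(df)
```

With this variant, an ID that has no rows with B == 1 would get NaN in every row instead of raising an error, which seems to match the behaviour the question asks for.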