
Pandas hierarchical pivot: get the column with the max


    df.head().info()
    
    RangeIndex: 5 entries, 0 to 4
    Data columns (total 4 columns):
    id         5 non-null object
    date-hr    5 non-null object
    channel    5 non-null object
    hr         5 non-null int64
    dtypes: int64(1), object(3)

The actual date-hr values look something like '2017-02-14--15', and id is a string.

I have a df like:

    User-ID  Date-hr  Channel  Hr
    U1       D1-10    C1       10
    U1       D1-11    C2       11
    U1       D1-10    C1       10
    U1       D1-10    C3       10
    U1       D1-10    C1       10
    U1       D1-11    C3       11
    U1       D1-11    C2       11
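A minimal sketch that rebuilds the sample rows above as a DataFrame, so the steps can be reproduced. The column names are assumed from the table header; the info() output calls the first column id instead, so rename to match the real frame.

```python
import pandas as pd

# Sample rows from the question; column names follow the table header (assumed).
df = pd.DataFrame(
    {
        "User-ID": ["U1"] * 7,
        "Date-hr": ["D1-10", "D1-11", "D1-10", "D1-10", "D1-10", "D1-11", "D1-11"],
        "Channel": ["C1", "C2", "C1", "C3", "C1", "C3", "C2"],
        "Hr": [10, 11, 10, 10, 10, 11, 11],
    }
)
print(df)
```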

..

When I pivot with user-id as the index, ['date-hr', 'channel'] as the columns and count as the aggregation function, I get one row per user, with date-hr as the outer column level and all of its channels nested under each date-hr value:

        D1-10    D1-11  .....
        C1  C3   C2  C3 .....
    U1  3   1    2   1  .....
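The pivot described above can be written with pivot_table. This is a minimal sketch, assuming the column names from the sample table and counting the Hr column (count only needs some non-null column). fill_value=0 is an assumption added here so that channel/hour combinations a user never has show 0 instead of NaN.

```python
import pandas as pd

df = pd.DataFrame(
    {
        "User-ID": ["U1"] * 7,
        "Date-hr": ["D1-10", "D1-11", "D1-10", "D1-10", "D1-10", "D1-11", "D1-11"],
        "Channel": ["C1", "C2", "C1", "C3", "C1", "C3", "C2"],
        "Hr": [10, 11, 10, 10, 10, 11, 11],
    }
)

# One row per user; the columns form a (Date-hr, Channel) MultiIndex of row counts.
pivoted = df.pivot_table(
    index="User-ID",
    columns=["Date-hr", "Channel"],
    values="Hr",
    aggfunc="count",
    fill_value=0,
)
print(pivoted)
```

Because values is a single column name rather than a list, pivot_table does not add an extra "Hr" level on top of the columns.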

What I need instead is the channel with the maximum count under every date-hr, together with that count:

        D1-10    D1-11  .....
        C1       C2     .....
    U1  (C1,3)   (C2,2) .....

I can't figure out how to get this transformation from my data.


Solution

  • You can create a custom function:

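    import pandas as pd
    
    # Sample data: the same rows as the printed frame below
    df = pd.DataFrame({
        'User-ID': ['U1', 'U1', 'U1', 'U1', 'U2', 'U2', 'U2', 'U4'],
        'Date-hr': ['D1-10', 'D1-11', 'D1-10', 'D1-10', 'D1-10', 'D1-11', 'D1-11', 'D7-11'],
        'Channel': ['C1', 'C2', 'C1', 'C3', 'C1', 'C3', 'C2', 'C2'],
        'Hr': [10, 11, 10, 10, 10, 11, 11, 11],
    })
    print (df)
      User-ID Date-hr Channel  Hr
    0      U1   D1-10      C1  10
    1      U1   D1-11      C2  11
    2      U1   D1-10      C1  10
    3      U1   D1-10      C3  10
    4      U2   D1-10      C1  10
    5      U2   D1-11      C3  11
    6      U2   D1-11      C2  11
    7      U4   D7-11      C2  11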
    
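    # Count rows per (User-ID, Date-hr, Channel), then move Date-hr and Channel
    # into a column MultiIndex; combinations a user never has are filled with 0
    df = df.groupby(['User-ID','Date-hr', 'Channel'])['Hr'].count().unstack([1,2], fill_value=0)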
    print (df)
    Date-hr D1-10    D1-11    D7-11
    Channel    C1 C3    C2 C3    C2
    User-ID                        
    U1          2  1     1  0     0
    U2          1  0     1  1     0
    U4          0  0     0  0     1
    
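    def f(x):
        # x holds the columns of one Date-hr block, e.g. ('D1-10', 'C1') and ('D1-10', 'C3')
        c = x.idxmax(axis=1).str[1]  # label of the max column per row; [1] keeps the Channel
        m = x.max(axis=1)            # the max count itself
        return pd.Series(list(zip(c, m)), index=x.index)
    
    # Group the columns by their first level (Date-hr). DataFrame.groupby(axis=1) is
    # deprecated since pandas 2.1, so group the transposed frame and join the results
    df = pd.concat({k: f(g.T) for k, g in df.T.groupby(level=0)}, axis=1, names=['Date-hr'])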
    print (df)
    Date-hr    D1-10    D1-11    D7-11
    User-ID                           
    U1       (C1, 2)  (C2, 1)  (C2, 0)
    U2       (C1, 1)  (C2, 1)  (C2, 0)
    U4       (C1, 0)  (C2, 0)  (C2, 1)
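If the zero-count cells such as (C2, 0) are not wanted, a different technique is to stay in long format: count rows per (User-ID, Date-hr, Channel), use idxmax inside each (User-ID, Date-hr) group to keep the top channel, then pivot the (channel, count) tuples. This is an alternative sketch, not the method of the answer above. It only fills cells for hours a user actually has rows in, so pairs like U1/D7-11 come out as NaN. Ties go to the first channel in sorted order, because groupby sorts its keys and idxmax returns the first maximum.

```python
import pandas as pd

# Sample rows from the answer's printed frame.
df = pd.DataFrame(
    {
        "User-ID": ["U1", "U1", "U1", "U1", "U2", "U2", "U2", "U4"],
        "Date-hr": ["D1-10", "D1-11", "D1-10", "D1-10", "D1-10", "D1-11", "D1-11", "D7-11"],
        "Channel": ["C1", "C2", "C1", "C3", "C1", "C3", "C2", "C2"],
        "Hr": [10, 11, 10, 10, 10, 11, 11, 11],
    }
)

# Row counts per (User-ID, Date-hr, Channel) in long format; groupby sorts the keys.
counts = (
    df.groupby(["User-ID", "Date-hr", "Channel"])
    .size()
    .rename("n")
    .reset_index()
)

# Within each (User-ID, Date-hr) pair keep the row with the largest count
# (idxmax returns the first maximum, so ties go to the first sorted channel).
best = counts.loc[counts.groupby(["User-ID", "Date-hr"])["n"].idxmax()]

# Pack (Channel, count) into one tuple per row, then spread Date-hr across the columns.
best = best.assign(top=pd.Series(list(zip(best["Channel"], best["n"])), index=best.index))
result = best.pivot(index="User-ID", columns="Date-hr", values="top")
print(result)
```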