Aggregate dataframe by dimension pairs for heatmap


I have a table with rows being individual visits from users to an online shop. There are multiple columns that are attributes of the visit (could be booleans or categorical with more than 2 possible values). And there is a column that counts how many items were bought during that visit.

I want to create a table that summarises, for every pair of attribute values, the average number of items bought per visit. That is, for a given pair of attribute values, sum the number of items bought and divide by the number of rows that have both values. I would then visualise this table as a heat map.

This is similar to preparing a separate heatmap for every attribute pair, but I want to do it in a single table or heatmap for easier digestibility.

Note that along the diagonal, where an attribute value is compared with itself, the result is always 1, and where mutually exclusive attribute values are compared the result is NA. The cells left empty mirror the ones filled in above the diagonal.

Input table:

visit_id  age    is_website  is_US  items_bought
aa        young  true        true   0
ab        young  false       false  2
ac        old    true        true   0
ad        old    true        false  3

Desired output table:

                  age young  age old  is_website true  is_website false  is_US true     is_US false
age young         1          NA       0                2                 0              2
age old           NA         1        1.5              null division     0              3
is_website true                       1                NA                0              3
is_website false                      NA               1                 null division  2
is_US true                                                               1              NA
is_US false                                                              NA             1
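
To make the definition concrete, a single cell of the table can be checked by hand. The sketch below rebuilds the input table (the name `df` is only used for this illustration) and computes the cell for age old and is_website true, as well as the row count behind one of the "null division" cells.

```python
import pandas as pd

# Same four visits as in the input table (df is an illustrative name)
df = pd.DataFrame({
    'visit_id': ['aa', 'ab', 'ac', 'ad'],
    'age': ['young', 'young', 'old', 'old'],
    'is_website': [True, False, True, True],
    'is_US': [True, False, True, False],
    'items_bought': [0, 2, 0, 3],
})

# Rows that have both attribute values: age == 'old' and is_website == True
mask = (df['age'] == 'old') & df['is_website']

# Sum of items bought divided by the number of matching rows: (0 + 3) / 2
cell = df.loc[mask, 'items_bought'].sum() / mask.sum()
print(cell)  # 1.5

# A combination with no matching rows ("null division" in the table)
empty = (df['age'] == 'old') & ~df['is_website']
print(empty.sum())  # 0
```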

What I tried:

  1. Hard-code a list of the attributes (i.e. the column names that are attributes):
import pandas as pd

data = {
    'visit_id': ['aa', 'ab', 'ac', 'ad'],
    'age': ['young', 'young', 'old', 'old'],
    'is_website': [True, False, True, True],
    'is_US': [True, False, True, False],
    'items_bought': [0, 2, 0, 3]
}

df1 = pd.DataFrame(data)

dim = ['age', 'is_website', 'is_US']
  2. Create a list of all the unique attribute pairs without repetition, using combinations from itertools:
from itertools import combinations
dim_pairs = list(combinations(dim, 2))
  3. Run a for loop:
  • group the dataframe by the attribute pair, sum up the items bought (total_items_bought) and count the number of rows (total_visits)
  • add a new column items_bought_per_visit to the resulting dataframe, which is total_items_bought / total_visits
  • save the resulting dataframe to a dictionary
dfs = {}

for i, (a, b) in enumerate(dim_pairs):
    # total items bought and number of visits for each value combination of (a, b)
    grouped = (df1.groupby([a, b])
                  .agg(total_items_bought=('items_bought', 'sum'),
                       total_visits=('visit_id', 'count'))
                  .reset_index())

    grouped['items_bought_per_visit'] = grouped['total_items_bought'] / grouped['total_visits']

    # one small pivot table per attribute pair; combinations with no rows become 0 here
    pivot_df = grouped.pivot_table(index=a, columns=b, values='items_bought_per_visit', aggfunc='sum').fillna(0)

    dfs[f"pivot_df{i}"] = pivot_df

So I have a dictionary of dataframes that contains all the values I need to fill in the single desired output table, but I am not sure how to combine them, or whether there is an easier way.


Solution

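  • You could first reshape df1 with melt, so that every attribute value (for example 'age old') becomes its own column holding items_bought for the visits that have that value, and then use corr with a custom function. With a callable, corr evaluates the function on every pair of columns using only the rows where both columns are non-null, so the mean of those values is the average number of items bought for that pair. pandas sets the diagonal to 1 and returns NaN for pairs that share no rows: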

    tmp = (df1
       .melt(df1.columns.difference(dim))
       .assign(variable=lambda d: d['variable']+' '+d['value'].astype(str))
       .pivot(index='visit_id', columns='variable', values='items_bought')
    )
    
    def f(a, b):
        m = a == b
        return a[m].mean()
    
    out = tmp.corr(f)
    

    Output:

    variable          age old  age young  is_US False  is_US True  is_website False  is_website True
    variable                                                                                        
    age old               1.0        NaN          3.0         0.0               NaN              1.5
    age young             NaN        1.0          2.0         0.0               2.0              0.0
    is_US False           3.0        2.0          1.0         NaN               2.0              3.0
    is_US True            0.0        0.0          NaN         1.0               NaN              0.0
    is_website False      NaN        2.0          2.0         NaN               1.0              NaN
    is_website True       1.5        0.0          3.0         0.0               NaN              1.0
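
The question's table lists the attribute values in a specific order and leaves the cells below the diagonal empty. A minimal sketch of how the result could be brought into that shape, assuming the same `df1` and `dim` as above; the `order` list is hard-coded here for illustration and `upper` is a hypothetical name:

```python
import numpy as np
import pandas as pd

df1 = pd.DataFrame({
    'visit_id': ['aa', 'ab', 'ac', 'ad'],
    'age': ['young', 'young', 'old', 'old'],
    'is_website': [True, False, True, True],
    'is_US': [True, False, True, False],
    'items_bought': [0, 2, 0, 3],
})

# One column per attribute value, holding items_bought for matching visits
tmp = (df1
   .melt(id_vars=['visit_id', 'items_bought'])
   .assign(variable=lambda d: d['variable'] + ' ' + d['value'].astype(str))
   .pivot(index='visit_id', columns='variable', values='items_bought')
)

def f(a, b):
    m = a == b
    return a[m].mean()

out = tmp.corr(f)

# Row/column order taken from the desired table (hard-coded assumption)
order = ['age young', 'age old',
         'is_website True', 'is_website False',
         'is_US True', 'is_US False']
out = out.reindex(index=order, columns=order)

# Keep the diagonal and everything above it; blank out the mirrored lower half
upper = out.where(np.triu(np.ones(out.shape, dtype=bool)))
print(upper)
```

Since the matrix is symmetric, masking the lower triangle loses no information; a heatmap drawn from `upper` would show each pair once.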
    

    Intermediate tmp:

    variable  age old  age young  is_US False  is_US True  is_website False  is_website True
    visit_id                                                                                
    aa            NaN        0.0          NaN         0.0               NaN              0.0
    ab            NaN        2.0          2.0         NaN               2.0              NaN
    ac            0.0        NaN          NaN         0.0               NaN              0.0
    ad            3.0        NaN          3.0         NaN               NaN              3.0
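
For comparison, the loop-based attempt from the question could also be finished directly, without corr. The following is only a sketch under the same assumptions (`df1` and `dim` as in the question; `labels` is an illustrative name): it takes the group mean for every attribute pair and writes each value into one square table, mirrored across the diagonal. Combinations with no rows simply stay NaN, whereas fillna(0) would make them indistinguishable from a real average of 0.

```python
from itertools import combinations

import pandas as pd

df1 = pd.DataFrame({
    'visit_id': ['aa', 'ab', 'ac', 'ad'],
    'age': ['young', 'young', 'old', 'old'],
    'is_website': [True, False, True, True],
    'is_US': [True, False, True, False],
    'items_bought': [0, 2, 0, 3],
})
dim = ['age', 'is_website', 'is_US']

# One label per attribute value, e.g. 'age young' or 'is_US True'
labels = [f"{col} {str(val)}" for col in dim for val in df1[col].unique()]

# Square table, initially all NaN
out = pd.DataFrame(float('nan'), index=labels, columns=labels)

for a, b in combinations(dim, 2):
    # Mean items bought per visit for each observed value combination of (a, b)
    means = df1.groupby([a, b])['items_bought'].mean()
    for (va, vb), m in means.items():
        row, col = f"{a} {str(va)}", f"{b} {str(vb)}"
        out.loc[row, col] = m
        out.loc[col, row] = m  # mirror across the diagonal

# Diagonal set to 1, following the convention in the question
for lab in labels:
    out.loc[lab, lab] = 1.0

print(out)
```

The mean is the same quantity as the sum divided by the row count used in the question, so the separate total_items_bought and total_visits columns are not needed here.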