Tags: python, apache-spark, pyspark, aggregate, intersection

How to make intersection table on PySpark dataframe


Here's my dataset

Name      Order
A         Coffee
A         Tea
A         Burger
A         Fried Chicken
B         Coffee
B         Tea
B         Hot Dog
B         Fried Chicken
C         Coffee
C         Hot Dog
C         Fried Chicken
D         Tea
D         Burger

Here's the Info

Food = ['Hot Dog', 'Burger', 'Fried Chicken'] # Rows
Drink = ['Coffee', 'Tea'] # Columns

Here's the expected output

            Hot Dog    Burger   Fried Chicken
Tea             0.5        1             0.67
Coffee            1      0.5               1

Hot Dog x Tea is 0.5 because, of the 2 people who bought Hot Dog, only one also bought Tea, and so on.
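The ratio logic can be checked in plain Python before touching Spark. This is just a sketch of the math (the `orders` dict and variable names are mine, not part of the question's data model): for each (drink, food) cell, divide the number of people who bought both by the number of people who bought the food.

```python
# Each person's set of orders, taken from the dataset above
orders = {
    'A': {'Coffee', 'Tea', 'Burger', 'Fried Chicken'},
    'B': {'Coffee', 'Tea', 'Hot Dog', 'Fried Chicken'},
    'C': {'Coffee', 'Hot Dog', 'Fried Chicken'},
    'D': {'Tea', 'Burger'},
}
foods = ['Hot Dog', 'Burger', 'Fried Chicken']
drinks = ['Coffee', 'Tea']

table = {}
for drink in drinks:
    for food in foods:
        # People who bought this food
        food_buyers = [n for n, o in orders.items() if food in o]
        # Of those, the ones who also bought this drink
        both = [n for n in food_buyers if drink in orders[n]]
        table[(drink, food)] = round(len(both) / len(food_buyers), 2)

print(table[('Tea', 'Hot Dog')])  # 0.5
```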


Solution

  • Somehow it feels there should be a nice way to do it, maybe using crosstab...

    But I could only think of the following rather long algorithm. Even though it uses the standard-library itertools module, it should perform well, because that part runs inside a pandas_udf, which is vectorized for performance.

    Input:

    from pyspark.sql import functions as F, Window as W
    import pandas as pd
    from itertools import combinations
    df = spark.createDataFrame(
        [('A', 'Coffee'),
         ('A', 'Tea'),
         ('A', 'Burger'),
         ('A', 'Fried Chicken'),
         ('B', 'Coffee'),
         ('B', 'Tea'),
         ('B', 'Hot Dog'),
         ('B', 'Fried Chicken'),
         ('C', 'Coffee'),
         ('C', 'Hot Dog'),
         ('C', 'Fried Chicken'),
         ('D', 'Tea'),
         ('D', 'Burger')],
        ['Name', 'Order'])
    

    Script:

    @F.pandas_udf('array<array<string>>')
    def pudf(c: pd.Series) -> pd.Series:
        # All unordered 2-item combinations of one person's orders
        return c.apply(lambda x: list(combinations(x, 2)))
    
    df = df.groupBy('Name').agg(F.collect_set('Order').alias('Order'))
    beverages = ['Coffee', 'Tea']
    # Reorder each pair so that the beverage (if any) comes first
    df = df.withColumn('Order', F.transform(pudf('Order'), lambda x: F.when(x[1].isin(beverages), F.array(x[1], x[0])).otherwise(x)))
    # Keep only (beverage, food) pairs; drop beverage-beverage and food-food pairs
    df = df.withColumn('Order', F.filter('Order', lambda x: x[0].isin(beverages) & ~x[1].isin(beverages)))
    df = df.withColumn('Order', F.explode('Order'))
    df = df.select('Name', F.col('Order')[0].alias('beverage'), F.col('Order')[1].alias('food'))
    # Number of distinct people who bought each food
    df = df.withColumn('food_cnt', F.size(F.collect_set('Name').over(W.partitionBy('food'))))
    # Share of each food's buyers who also bought the beverage
    df = (df
        .groupBy('beverage')
        .pivot('food')
        .agg(F.round(F.count(F.lit(1)) / F.first('food_cnt'), 2))
    )
    df.show()
    # +--------+------+-------------+-------+
    # |beverage|Burger|Fried Chicken|Hot Dog|
    # +--------+------+-------------+-------+
    # |     Tea|   1.0|         0.67|    0.5|
    # |  Coffee|   0.5|          1.0|    1.0|
    # +--------+------+-------------+-------+
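    For small data, the crosstab idea mused about above does work in plain pandas. This is a hedged sketch, not the PySpark solution (the self-join and variable names are mine): join the table to itself on `Name` to pair every drink with every food per person, count people per (drink, food) with `pd.crosstab`, then divide by each food's buyer count.

    ```python
    import pandas as pd

    rows = [('A', 'Coffee'), ('A', 'Tea'), ('A', 'Burger'), ('A', 'Fried Chicken'),
            ('B', 'Coffee'), ('B', 'Tea'), ('B', 'Hot Dog'), ('B', 'Fried Chicken'),
            ('C', 'Coffee'), ('C', 'Hot Dog'), ('C', 'Fried Chicken'),
            ('D', 'Tea'), ('D', 'Burger')]
    pdf = pd.DataFrame(rows, columns=['Name', 'Order'])

    drinks = ['Coffee', 'Tea']
    foods = ['Hot Dog', 'Burger', 'Fried Chicken']

    # Self-join on Name to get every (order, order) pair per person,
    # then keep only drink x food pairs
    pairs = pdf.merge(pdf, on='Name')
    pairs = pairs[pairs['Order_x'].isin(drinks) & pairs['Order_y'].isin(foods)]

    # crosstab counts people per (drink, food); divide by each food's buyer count
    counts = pd.crosstab(pairs['Order_x'], pairs['Order_y'])
    food_buyers = pdf[pdf['Order'].isin(foods)].groupby('Order')['Name'].nunique()
    result = (counts / food_buyers).round(2)
    print(result)
    ```

    The division works because pandas aligns the `food_buyers` Series index against the crosstab's columns, so each food column is divided by that food's distinct buyer count.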