Tags: python, pandas, join, merge, combinations

How to join dataframe on itself creating all combinations inside groups


Some mock data:

import pandas as pd

df = pd.DataFrame({
    'date': {0: pd.Timestamp('2021-08-01'),
             1: pd.Timestamp('2022-08-01'),
             2: pd.Timestamp('2021-08-02'),
             3: pd.Timestamp('2021-08-01'),
             4: pd.Timestamp('2022-08-01'),
             5: pd.Timestamp('2022-08-01'),
             6: pd.Timestamp('2022-08-01')},
    'product_nr': {0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7},
    'Category': {0: 'Cars', 1: 'Cars', 2: 'Cats', 3: 'Dogs', 4: 'Dogs', 5: 'Cats', 6: 'Cats'},
    'price': {0: '34', 1: '24', 2: '244', 3: '284', 4: '274', 5: '354', 6: '250'}
})

How do I do an inner join of the dataframe with itself under a specific condition? I want to compare prices between rows that belong to the same category. Desired output:

pd.DataFrame({
    'product_nr': {0: 1, 1: 3, 2: 5, 3: 7, 4: 7},
    'Category': {0: 'Cars', 1: 'Cats', 2: 'Dogs', 3: 'Cats', 4: 'Cats'},
    'price': {0: '34', 1: '244', 2: '274', 3: '250', 4: '250'},
    'product_to_be_compared': {0: 2, 1: 6, 2: 4, 3: 3, 4: 6}
})

In other words, I want an inner join / cross join (I'm not sure which is most suitable). I have a large dataframe and I want to pair up rows that share the same category and date. Ideally I would also remove duplicated (reversed) pairs, so my desired output would be 4 rows.
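For reference, the grouped pairing can be sketched with a plain-pandas self-merge. This is a sketch, not the accepted approach: the dates below are adjusted so each intended pair actually shares a date (the mock data above mixes 2021 and 2022), and the `_cmp` suffix is an arbitrary choice.

```python
import pandas as pd

df = pd.DataFrame({
    'date': pd.to_datetime(['2021-08-01', '2021-08-01', '2021-08-02',
                            '2021-08-03', '2021-08-03', '2021-08-02', '2021-08-02']),
    'product_nr': [1, 2, 3, 4, 5, 6, 7],
    'Category': ['Cars', 'Cars', 'Cats', 'Dogs', 'Dogs', 'Cats', 'Cats'],
    'price': ['34', '24', '244', '284', '274', '354', '250'],
})

# Cross join within groups: merging the frame with itself on the keys
# pairs every row with every other row sharing Category and date
pairs = df.merge(df, on=['Category', 'date'], suffixes=('', '_cmp'))

# Keep each unordered pair once (this also drops self-pairs)
pairs = pairs[pairs['product_nr'] < pairs['product_nr_cmp']]
pairs = pairs.rename(columns={'product_nr_cmp': 'product_to_be_compared'})
```

The `product_nr < product_nr_cmp` filter is what deduplicates: of each reversed pair, only the one with the smaller product number on the left survives.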


Solution

  • From your questions I know you're familiar with PySpark, so this is how it could be done with PySpark dataframes. Even though it uses the itertools module (part of Python's standard library), it should still perform well, because that part runs inside a pandas_udf, which lets Spark ship data to Python in Arrow batches rather than row by row.

    Input df:

    import pandas as pd
    
    pdf = pd.DataFrame({
        'date': {
            0: pd.Timestamp('2021-08-01'),
            1: pd.Timestamp('2021-08-01'),
            2: pd.Timestamp('2021-08-02'),
            3: pd.Timestamp('2021-08-03'),
            4: pd.Timestamp('2021-08-03'),
            5: pd.Timestamp('2021-08-02'),
            6: pd.Timestamp('2021-08-02')
        },
        'product_nr': {0: '1', 1: '2', 2: '3', 3: '4', 4: '5', 5: '6', 6: '7'},
        'Category': {0: 'Cars', 1: 'Cars', 2: 'Cats', 3: 'Dogs', 4: 'Dogs', 5: 'Cats', 6: 'Cats'},
        'price': {
            0: '34',
            1: '24',
            2: '244',
            3: '284',
            4: '274',
            5: '354',
            6: '250'
        }
    })
    df = spark.createDataFrame(pdf)
    

    Script:

    from pyspark.sql import functions as F
    from itertools import combinations
    
    # For each group's list of products, emit every unordered pair once
    @F.pandas_udf('array<array<string>>')
    def arr_combinations(c: pd.Series) -> pd.Series:
        return c.apply(lambda x: list(combinations(x, 2)))
    
    # Collect products per (Category, date) group, then explode the pairs
    df2 = df.groupBy('Category', 'date').agg(F.collect_list('product_nr').alias('ps'))
    df2 = df2.withColumn('ps', F.explode(arr_combinations('ps')))
    df2 = df2.select(
        'Category', 'date',
        F.col('ps')[0].alias('product_nr'),
        F.col('ps')[1].alias('product_to_be_compared')
    )
    # Join back to the original frame to recover the price column
    df3 = df.join(df2, ['product_nr', 'Category', 'date'])
    
    df3.show()
    # +----------+--------+-------------------+-----+----------------------+
    # |product_nr|Category|               date|price|product_to_be_compared|
    # +----------+--------+-------------------+-----+----------------------+
    # |         3|    Cats|2021-08-02 00:00:00|  244|                     7|
    # |         3|    Cats|2021-08-02 00:00:00|  244|                     6|
    # |         1|    Cars|2021-08-01 00:00:00|   34|                     2|
    # |         6|    Cats|2021-08-02 00:00:00|  354|                     7|
    # |         4|    Dogs|2021-08-03 00:00:00|  284|                     5|
    # +----------+--------+-------------------+-----+----------------------+
    
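    As an aside, `itertools.combinations` is what removes the reversed duplicates the question mentions: it yields each unordered pair exactly once, in input order. A minimal illustration with the Cats group's product numbers:

```python
from itertools import combinations

# Each unordered pair appears once: ('3', '6') is emitted,
# but the reversed duplicate ('6', '3') is not
pairs = list(combinations(['3', '6', '7'], 2))
print(pairs)  # [('3', '6'), ('3', '7'), ('6', '7')]
```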

    If you want to compare prices directly in this table, use the following:

    from pyspark.sql import functions as F
    from itertools import combinations
    
    # Same idea, but each collected element is now a (product_nr, price)
    # pair, so the UDF's return type gains one level of array nesting
    @F.pandas_udf('array<array<array<string>>>')
    def arr_combinations(c: pd.Series) -> pd.Series:
        return c.apply(lambda x: list(combinations(x, 2)))
    
    df2 = df.groupBy('Category', 'date').agg(F.collect_list(F.array('product_nr', 'price')).alias('ps'))
    df2 = df2.withColumn('ps', F.explode(arr_combinations('ps')))
    df2 = df2.select(
        F.col('ps')[0][0].alias('product_nr'),
        'Category',
        'date',
        F.col('ps')[0][1].alias('product_price'),
        F.col('ps')[1][0].alias('product_to_be_compared'),
        F.col('ps')[1][1].alias('product_to_be_compared_price'),
    )
    df2.show()
    df2.show()
    # +----------+--------+-------------------+-------------+----------------------+----------------------------+
    # |product_nr|Category|               date|product_price|product_to_be_compared|product_to_be_compared_price|
    # +----------+--------+-------------------+-------------+----------------------+----------------------------+
    # |         1|    Cars|2021-08-01 00:00:00|           34|                     2|                          24|
    # |         3|    Cats|2021-08-02 00:00:00|          244|                     6|                         354|
    # |         3|    Cats|2021-08-02 00:00:00|          244|                     7|                         250|
    # |         6|    Cats|2021-08-02 00:00:00|          354|                     7|                         250|
    # |         4|    Dogs|2021-08-03 00:00:00|          284|                     5|                         274|
    # +----------+--------+-------------------+-------------+----------------------+----------------------------+
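
    For completeness, the same price-comparison table can be built in plain pandas, with no Spark session required. This is a sketch using the same `pdf` data as above; the loop over groups is fine for modest data but won't scale like the Spark version.

```python
import pandas as pd
from itertools import combinations

pdf = pd.DataFrame({
    'date': pd.to_datetime(['2021-08-01', '2021-08-01', '2021-08-02',
                            '2021-08-03', '2021-08-03', '2021-08-02', '2021-08-02']),
    'product_nr': ['1', '2', '3', '4', '5', '6', '7'],
    'Category': ['Cars', 'Cars', 'Cats', 'Dogs', 'Dogs', 'Cats', 'Cats'],
    'price': ['34', '24', '244', '284', '274', '354', '250'],
})

rows = []
for (cat, date), grp in pdf.groupby(['Category', 'date']):
    # Every unordered pair of (product_nr, price) tuples within the group
    for (p1, pr1), (p2, pr2) in combinations(zip(grp['product_nr'], grp['price']), 2):
        rows.append((p1, cat, date, pr1, p2, pr2))

out = pd.DataFrame(rows, columns=[
    'product_nr', 'Category', 'date', 'product_price',
    'product_to_be_compared', 'product_to_be_compared_price'])
```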