Tags: dataframe, apache-spark, pyspark, apache-spark-sql

Generate all possible combinations of column elements that share the same date (and other columns) in PySpark


I have a DataFrame in PySpark:

data = [
    ("2023-09-03", "NA", "US", 1, "us1"),
    ("2023-09-03", "NA", "US", 1, "us2"),
    ("2023-09-03", "NA", "MX", 1, "mx1"),
    ("2023-09-03", "NA","CA", 2, "ca1"),
    ("2023-09-03", "NA","CA", 2, "ca2")
]

columns = ["date", "region" ,"country", "id", "person"]

df = spark.createDataFrame(data, columns).drop("country")
df.display()
date        region  id  person
2023-09-03  NA      1   us1
2023-09-03  NA      1   us2
2023-09-03  NA      1   mx1
2023-09-03  NA      2   ca1
2023-09-03  NA      2   ca2

I am trying to list all persons that share the same values in the columns date, region, and id, pairing each person with every other person in the same group. I am expecting a result like this:

id  date        region  person_x  person_y
1   2023-09-03  NA      us1       us2
1   2023-09-03  NA      us1       mx1
1   2023-09-03  NA      us2       us1
1   2023-09-03  NA      us2       mx1
1   2023-09-03  NA      mx1       us1
1   2023-09-03  NA      mx1       us2
2   2023-09-03  NA      ca1       ca2
2   2023-09-03  NA      ca2       ca1

I have tried a self-join, but is there a more efficient solution? Here is what I've tried:

join_cols = ["id", "date", "region"]
df_x = df.select(join_cols + ["person"]).withColumnRenamed("person", "person_x")
df_y = df.select(join_cols + ["person"]).withColumnRenamed("person", "person_y")

df_mix = df_x.join(df_y, on=join_cols, how="inner").filter("person_x != person_y")
df_mix.display()
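
As a reference for the expected output, the same pairing can be sketched in plain Python on the sample rows (country already dropped). This is only an illustration, not Spark code; the names `rows` and `expected_pairs` are made up for this sketch. It groups persons by (id, date, region) and keeps every ordered pair of different persons, which is what the inner join plus the `person_x != person_y` filter is meant to return. A group of n distinct persons therefore contributes n * (n - 1) rows.

```python
from collections import defaultdict

# Sample rows as (date, region, id, person); hypothetical plain-Python copy of df.
rows = [
    ("2023-09-03", "NA", 1, "us1"),
    ("2023-09-03", "NA", 1, "us2"),
    ("2023-09-03", "NA", 1, "mx1"),
    ("2023-09-03", "NA", 2, "ca1"),
    ("2023-09-03", "NA", 2, "ca2"),
]


def expected_pairs(rows):
    """Return (id, date, region, person_x, person_y) for every ordered pair
    of different persons that share the same id, date and region."""
    groups = defaultdict(list)
    for date, region, id_, person in rows:
        groups[(id_, date, region)].append(person)

    pairs = []
    for key, persons in groups.items():
        for x in persons:
            for y in persons:
                if x != y:  # mirrors filter("person_x != person_y")
                    pairs.append((*key, x, y))
    return pairs


pairs = expected_pairs(rows)
for row in pairs:
    print(row)
```

For the sample data this gives 3 * 2 = 6 rows for id 1 and 2 * 1 = 2 rows for id 2, so any solution has to produce 8 rows in total.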

Solution

  • Using a group by instead of a join:

    Step 1: Group the data by join_cols and use collect_list as the aggregation function (column persons in the code below).

    Step 2: Use this answer to get all ordered pairs of elements of the persons array, then filter out the pairs that match a person with themselves (-> combinations).

    Step 3: Explode the combinations column so that each pair gets its own row (-> pairs).

    Step 4: Split the pairs into two columns, person_x and person_y.

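    # F.transform and F.filter with Python lambdas require Spark 3.1 or later.
    from pyspark.sql import functions as F

    c = "persons"  # name of the collected array column

    # Step 1: one row per group, with all persons collected into an array.
    df.groupBy(join_cols).agg(F.collect_list('person').alias(c)) \
        .withColumn(
            "combinations",
            F.filter(
                F.transform(
                    # Step 2: for each person x, zip size(persons) copies of x
                    # with persons, then flatten into one array of structs.
                    F.flatten(F.transform(
                        c,
                        lambda x: F.arrays_zip(F.array_repeat(x, F.size(c)), c)
                    )),
                    # Turn each struct into a two-element array; the struct
                    # fields are accessed here as "0" and "persons".
                    lambda x: F.array(x["0"], x[c])
                ),
                # Drop pairs that match a person with themselves.
                lambda x: x[0] != x[1]
            )
        ) \
        .withColumn('pairs', F.explode('combinations')) \
        .withColumn('person_x', F.col('pairs')[0]) \
        .withColumn('person_y', F.col('pairs')[1]) \
        .drop('persons', 'combinations', 'pairs') \
        .show(truncate=False)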
    

    Output:

    +---+----------+------+--------+--------+
    |id |date      |region|person_x|person_y|
    +---+----------+------+--------+--------+
    |1  |2023-09-03|NA    |us1     |us2     |
    |1  |2023-09-03|NA    |us1     |mx1     |
    |1  |2023-09-03|NA    |us2     |us1     |
    |1  |2023-09-03|NA    |us2     |mx1     |
    |1  |2023-09-03|NA    |mx1     |us1     |
    |1  |2023-09-03|NA    |mx1     |us2     |
    |2  |2023-09-03|NA    |ca1     |ca2     |
    |2  |2023-09-03|NA    |ca2     |ca1     |
    +---+----------+------+--------+--------+
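
To make Step 2 easier to follow, here is a rough plain-Python sketch of what the nested array expression does to a single persons array. It is not Spark code, and the function name `ordered_pairs` is made up for this illustration; the comments name the Spark function each line is meant to mirror.

```python
def ordered_pairs(persons):
    """Plain-Python mirror of the 'combinations' expression for one group."""
    # F.transform + F.array_repeat + F.arrays_zip:
    # for each x, pair len(persons) copies of x with every element of persons.
    nested = [list(zip([x] * len(persons), persons)) for x in persons]

    # F.flatten: merge the per-person lists into one list of (x, y) tuples.
    flat = [pair for group in nested for pair in group]

    # Outer F.transform + F.filter: build two-element arrays and drop
    # the pairs where both elements are equal.
    return [[x, y] for x, y in flat if x != y]


combinations = ordered_pairs(["us1", "us2", "mx1"])
for pair in combinations:
    print(pair)
```

Before the filter, the flattened list holds n * n pairs for an array of n persons, so the intermediate array grows quadratically with the group size, just as the self-join does within each group.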