I have a DataFrame in PySpark:
data = [
("2023-09-03", "NA", "US", 1, "us1"),
("2023-09-03", "NA", "US", 1, "us2"),
("2023-09-03", "NA", "MX", 1, "mx1"),
("2023-09-03", "NA","CA", 2, "ca1"),
("2023-09-03", "NA","CA", 2, "ca2")
]
columns = ["date", "region" ,"country", "id", "person"]
df = spark.createDataFrame(data, columns).drop("country")
df.display()
date       region id person
2023-09-03 NA     1  us1
2023-09-03 NA     1  us2
2023-09-03 NA     1  mx1
2023-09-03 NA     2  ca1
2023-09-03 NA     2  ca2
I am trying to list all persons that share the same values in the columns date, region, and id. I am expecting a result like:
id date       region person_x person_y
1  2023-09-03 NA     us1      us2
1  2023-09-03 NA     us1      mx1
1  2023-09-03 NA     us2      us1
1  2023-09-03 NA     us2      mx1
1  2023-09-03 NA     mx1      us1
1  2023-09-03 NA     mx1      us2
2  2023-09-03 NA     ca1      ca2
2  2023-09-03 NA     ca2      ca1
I have tried a self-join, but is there a more efficient solution for this? Here is what I've tried:
join_cols = ["id", "date", "region"]
df_x = df.select(join_cols + ["person"]).withColumnRenamed("person", "person_x")
df_y = df.select(join_cols + ["person"]).withColumnRenamed("person", "person_y")
df_mix = df_x.join(df_y, on=join_cols, how="inner").filter("person_x != person_y")
df_mix.display()
Using a group by instead of a join:

Step 1: Group the data by join_cols and use collect_list as the aggregation function (column persons in the code below; this step is shown in isolation in the sketch after this list).
Step 2: Use this answer to get all permutations of the persons array (-> combinations).
Step 3: Explode the combinations column (-> pairs).
Step 4: Split the pairs into two columns person_x and person_y.
from pyspark.sql import functions as F

(
    # Step 1: one row per (id, date, region) group, persons as an array.
    df.groupBy(join_cols).agg(F.collect_list('person').alias('persons'))
    # Step 2: build all ordered pairs of distinct persons. For each element x,
    # zip an array of x repeated size(persons) times with persons itself,
    # flatten the nested result, turn each struct into a two-element array,
    # and drop the pairs that match a person with itself.
    .withColumn(
        "combinations",
        F.filter(
            F.transform(
                F.flatten(F.transform(
                    c := "persons",
                    lambda x: F.arrays_zip(F.array_repeat(x, F.size(c)), c)
                )),
                lambda x: F.array(x["0"], x[c])
            ),
            lambda x: x[0] != x[1]
        )
    )
    # Step 3: explode to one row per pair.
    .withColumn('pairs', F.explode('combinations'))
    # Step 4: split each pair into two columns.
    .withColumn('person_x', F.col('pairs')[0])
    .withColumn('person_y', F.col('pairs')[1])
    .drop('persons', 'combinations', 'pairs')
    .show(truncate=False)
)
Output:
+---+----------+------+--------+--------+
|id |date |region|person_x|person_y|
+---+----------+------+--------+--------+
|1 |2023-09-03|NA |us1 |us2 |
|1 |2023-09-03|NA |us1 |mx1 |
|1 |2023-09-03|NA |us2 |us1 |
|1 |2023-09-03|NA |us2 |mx1 |
|1 |2023-09-03|NA |mx1 |us1 |
|1 |2023-09-03|NA |mx1 |us2 |
|2 |2023-09-03|NA |ca1 |ca2 |
|2 |2023-09-03|NA |ca2 |ca1 |
+---+----------+------+--------+--------+
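As a quick sanity check (a sketch, assuming the chained expression above is assigned to a variable df_pairs instead of ending in .show(), and that df_mix is the self-join result from the question; both have the same column order), the two approaches should return the same rows:

# Both differences empty -> the group-by result matches the self-join result.
assert df_mix.exceptAll(df_pairs).count() == 0
assert df_pairs.exceptAll(df_mix).count() == 0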