I have the following DataFrames in Spark SQL:
items:
|-- item_id: long
|-- clicks: long
|-- recs: decimal(16,6)
|-- data_timestamp: long
campaigns:
|-- item_id: long
|-- campaign_id: long
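For reference, a minimal setup that reproduces these schemas could look like this (the empty placeholder data and the createDataFrame calls are just illustrative; the real DataFrames come from elsewhere):

from pyspark.sql import SparkSession
from pyspark.sql import types as T

spark = SparkSession.builder.getOrCreate()

items_schema = T.StructType([
    T.StructField("item_id", T.LongType()),
    T.StructField("clicks", T.LongType()),
    T.StructField("recs", T.DecimalType(16, 6)),
    T.StructField("data_timestamp", T.LongType()),
])
campaigns_schema = T.StructType([
    T.StructField("item_id", T.LongType()),
    T.StructField("campaign_id", T.LongType()),
])

# Empty placeholder data, just to illustrate the schemas.
items = spark.createDataFrame([], schema=items_schema)
campaigns = spark.createDataFrame([], schema=campaigns_schema)

# Register them so SQL can refer to them by name.
items.createOrReplaceTempView("items")
campaigns.createOrReplaceTempView("campaigns")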
And I have this query:
select /*+ BROADCAST(campaigns) */
campaign_id,
items.item_id,
SUM(clicks) AS clicks,
SUM(recs) AS recs,
MAX(data_timestamp) AS data_timestamp
FROM items
JOIN campaigns
ON items.item_id = campaigns.item_id
GROUP BY campaign_id, items.item_id
DISTRIBUTE BY campaign_id
So basically I want the data to end up distributed by campaign_id with a single shuffle. But it doesn't work that way: Spark shuffles by (item_id, campaign_id) for the aggregation, and then shuffles again by campaign_id.
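This is how I'm inspecting the plan (assuming items and campaigns are available as temp views or tables):

spark.sql("""
select /*+ BROADCAST(campaigns) */
    campaign_id,
    items.item_id,
    SUM(clicks) AS clicks,
    SUM(recs) AS recs,
    MAX(data_timestamp) AS data_timestamp
FROM items
JOIN campaigns
ON items.item_id = campaigns.item_id
GROUP BY campaign_id, items.item_id
DISTRIBUTE BY campaign_id
""").explain("simple")
# The plan shows two Exchange nodes: one on (campaign_id, item_id) for the
# GROUP BY, and a second one on campaign_id for the DISTRIBUTE BY.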
Is there any way I can make Spark avoid the redundant shuffle?
Thanks
It's pretty strange. I've been able to eliminate the extra shuffle using the DataFrame API in PySpark:
items.join(F.broadcast(campaigns), "item_id") \
.repartition("campaign_id") \
.groupBy("item_id", "campaign_id").agg(F.sum("clicks"), F.sum("recs"), F.max("data_timestamp")) \
.explain("simple")
== Physical Plan ==
*(3) HashAggregate(keys=[item_id#1108L, campaign_id#1117L], functions=[sum(clicks#1109L), sum(recs#1110), max(data_timestamp#1111)])
+- *(3) HashAggregate(keys=[item_id#1108L, campaign_id#1117L], functions=[partial_sum(clicks#1109L), partial_sum(recs#1110), partial_max(data_timestamp#1111)])
+- Exchange hashpartitioning(campaign_id#1117L, 200), REPARTITION_BY_COL, [plan_id=3443]
+- *(2) Project [item_id#1108L, clicks#1109L, recs#1110, data_timestamp#1111, campaign_id#1117L]
+- *(2) BroadcastHashJoin [item_id#1108L], [item_id#1116L], Inner, BuildRight, false
:- *(2) Filter isnotnull(item_id#1108L)
: +- *(2) Scan ExistingRDD[item_id#1108L,clicks#1109L,recs#1110,data_timestamp#1111]
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [plan_id=3438]
+- *(1) Filter isnotnull(item_id#1116L)
+- *(1) Scan ExistingRDD[item_id#1116L,campaign_id#1117L]
As you can see, there is only one Exchange, by campaign_id, followed by a HashAggregate by campaign_id and item_id. This works because partitioning by campaign_id alone already co-locates all rows that share the same (campaign_id, item_id) pair, so the final HashAggregate does not need another exchange.
If I try to replicate it with SQL, the plan includes an extra exchange:
spark.sql("""
select /*+ REPARTITION(campaign_id), BROADCAST(campaigns) */
campaign_id,
items.item_id,
SUM(clicks) AS clicks,
SUM(recs) AS recs,
MAX(data_timestamp) AS data_timestamp
FROM items
JOIN campaigns
ON items.item_id = campaigns.item_id
GROUP BY campaign_id, items.item_id
""").explain("simple")
== Physical Plan ==
Exchange hashpartitioning(campaign_id#61L, 200), REPARTITION_BY_COL, [plan_id=2971]
+- *(3) HashAggregate(keys=[campaign_id#61L, item_id#52L], functions=[sum(clicks#53L), sum(recs#54), max(data_timestamp#55)])
+- Exchange hashpartitioning(campaign_id#61L, item_id#52L, 200), ENSURE_REQUIREMENTS, [plan_id=2967]
+- *(2) HashAggregate(keys=[campaign_id#61L, item_id#52L], functions=[partial_sum(clicks#53L), partial_sum(recs#54), partial_max(data_timestamp#55)])
+- *(2) Project [item_id#52L, clicks#53L, recs#54, data_timestamp#55, campaign_id#61L]
+- *(2) BroadcastHashJoin [item_id#52L], [item_id#60L], Inner, BuildRight, false
:- *(2) Filter isnotnull(item_id#52L)
: +- *(2) Scan ExistingRDD[item_id#52L,clicks#53L,recs#54,data_timestamp#55]
+- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [plan_id=2961]
+- *(1) Filter isnotnull(item_id#60L)
+- *(1) Scan ExistingRDD[item_id#60L,campaign_id#61L]
The DataFrame API produces a different physical plan than the equivalent SQL.
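One way I can think of to work around it (a sketch, not verified on every Spark version) is to express the broadcast join in SQL and push the repartition and aggregation through the DataFrame API, which should give the same single-exchange plan as above:

from pyspark.sql import functions as F

# Broadcast join expressed in SQL, aggregation expressed with the DataFrame API.
joined = spark.sql("""
    select /*+ BROADCAST(campaigns) */ items.*, campaign_id
    FROM items
    JOIN campaigns
    ON items.item_id = campaigns.item_id
""")

result = (joined
          .repartition("campaign_id")
          .groupBy("item_id", "campaign_id")
          .agg(F.sum("clicks").alias("clicks"),
               F.sum("recs").alias("recs"),
               F.max("data_timestamp").alias("data_timestamp")))

result.explain("simple")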