
Optimize spark sql query with aggregate over join


I have the following DataFrames in Spark SQL:

items:

 |-- item_id: long
 |-- clicks: long
 |-- recs: decimal(16,6)
 |-- data_timestamp: long

campaigns:

 |-- item_id: long 
 |-- campaign_id: long 
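
For reference, here is a minimal way to reproduce the setup: empty DataFrames with the same schemas, registered as temp views (the names and the empty data are just for illustration):

from pyspark.sql import SparkSession, types as T

spark = SparkSession.builder.getOrCreate()

# Empty DataFrames that only carry the schemas above.
items = spark.createDataFrame([], T.StructType([
    T.StructField("item_id", T.LongType()),
    T.StructField("clicks", T.LongType()),
    T.StructField("recs", T.DecimalType(16, 6)),
    T.StructField("data_timestamp", T.LongType()),
]))
campaigns = spark.createDataFrame([], T.StructType([
    T.StructField("item_id", T.LongType()),
    T.StructField("campaign_id", T.LongType()),
]))

items.createOrReplaceTempView("items")
campaigns.createOrReplaceTempView("campaigns")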

And I have this query:

SELECT /*+ BROADCAST(campaigns) */
  campaign_id,
  items.item_id,
  SUM(clicks) AS clicks,
  SUM(recs) AS recs,
  MAX(data_timestamp) AS data_timestamp
FROM items
JOIN campaigns
  ON items.item_id = campaigns.item_id
GROUP BY campaign_id, items.item_id
DISTRIBUTE BY campaign_id

So basically I want to distribute the output by campaign_id and end up with a single shuffle. But it doesn't work that way: Spark shuffles by (campaign_id, item_id) for the aggregation and then shuffles again by campaign_id.

Is there any way I can make Spark avoid the redundant shuffle?
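
This is how I check where the shuffles happen (a minimal sketch; query is just the SQL above as a Python string, and both tables are registered as temp views):

# Two Exchange nodes show up in the physical plan: one on
# (campaign_id, item_id) for the GROUP BY, and one on campaign_id
# for the DISTRIBUTE BY.
spark.sql(query).explain("simple")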

Thanks


Solution

  • It's pretty strange. I've been able to eliminate the extra shuffle using the DataFrame API in PySpark:

    from pyspark.sql import functions as F

    # Repartition by campaign_id before the groupBy so one exchange covers both.
    items.join(F.broadcast(campaigns), "item_id") \
      .repartition("campaign_id") \
      .groupBy("item_id", "campaign_id").agg(F.sum("clicks"), F.sum("recs"), F.max("data_timestamp")) \
      .explain("simple")
    
    == Physical Plan ==
    *(3) HashAggregate(keys=[item_id#1108L, campaign_id#1117L], functions=[sum(clicks#1109L), sum(recs#1110), max(data_timestamp#1111)])
    +- *(3) HashAggregate(keys=[item_id#1108L, campaign_id#1117L], functions=[partial_sum(clicks#1109L), partial_sum(recs#1110), partial_max(data_timestamp#1111)])
       +- Exchange hashpartitioning(campaign_id#1117L, 200), REPARTITION_BY_COL, [plan_id=3443]
          +- *(2) Project [item_id#1108L, clicks#1109L, recs#1110, data_timestamp#1111, campaign_id#1117L]
             +- *(2) BroadcastHashJoin [item_id#1108L], [item_id#1116L], Inner, BuildRight, false
                :- *(2) Filter isnotnull(item_id#1108L)
                :  +- *(2) Scan ExistingRDD[item_id#1108L,clicks#1109L,recs#1110,data_timestamp#1111]
                +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [plan_id=3438]
                   +- *(1) Filter isnotnull(item_id#1116L)
                      +- *(1) Scan ExistingRDD[item_id#1116L,campaign_id#1117L]
    

    As you can see, there is only one Exchange, the REPARTITION_BY_COL on campaign_id, followed by the partial and final HashAggregate keyed on item_id and campaign_id. Since campaign_id is one of the grouping keys, partitioning by it already satisfies the distribution the aggregation needs, so Spark does not insert a second exchange.

    If I try to replicate it with SQL, it includes an extra exchange:

    spark.sql("""
      SELECT /*+ REPARTITION(campaign_id), BROADCAST(campaigns) */
        campaign_id,
        items.item_id,
        SUM(clicks) AS clicks,
        SUM(recs) AS recs,
        MAX(data_timestamp) AS data_timestamp
      FROM items
      JOIN campaigns
        ON items.item_id = campaigns.item_id
      GROUP BY campaign_id, items.item_id
    """).explain("simple")
    
    == Physical Plan ==
    Exchange hashpartitioning(campaign_id#61L, 200), REPARTITION_BY_COL, [plan_id=2971]
    +- *(3) HashAggregate(keys=[campaign_id#61L, item_id#52L], functions=[sum(clicks#53L), sum(recs#54), max(data_timestamp#55)])
       +- Exchange hashpartitioning(campaign_id#61L, item_id#52L, 200), ENSURE_REQUIREMENTS, [plan_id=2967]
          +- *(2) HashAggregate(keys=[campaign_id#61L, item_id#52L], functions=[partial_sum(clicks#53L), partial_sum(recs#54), partial_max(data_timestamp#55)])
             +- *(2) Project [item_id#52L, clicks#53L, recs#54, data_timestamp#55, campaign_id#61L]
                +- *(2) BroadcastHashJoin [item_id#52L], [item_id#60L], Inner, BuildRight, false
                   :- *(2) Filter isnotnull(item_id#52L)
                   :  +- *(2) Scan ExistingRDD[item_id#52L,clicks#53L,recs#54,data_timestamp#55]
                   +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]),false), [plan_id=2961]
                      +- *(1) Filter isnotnull(item_id#60L)
                         +- *(1) Scan ExistingRDD[item_id#60L,campaign_id#61L]
    

    The DataFrame API produces a different physical plan. With repartition("campaign_id") applied between the join and the aggregation, the hash partitioning on campaign_id already satisfies the aggregate's required distribution, so no extra exchange is needed. In SQL, the REPARTITION hint applies to the output of the query block, so the exchange is placed after the GROUP BY and Spark still adds an ENSURE_REQUIREMENTS exchange on (campaign_id, item_id) for the aggregation.
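
    If you need to keep the join itself in SQL, a possible workaround (just a sketch; I have only checked the plans shown above) is to do the join in SQL and move the repartition and aggregation to the DataFrame API, which builds the same logical plan as the pure DataFrame version:

    from pyspark.sql import functions as F

    # Join in SQL, then repartition by campaign_id before the aggregation.
    joined = spark.sql("""
      SELECT /*+ BROADCAST(campaigns) */
        campaign_id, items.item_id, clicks, recs, data_timestamp
      FROM items
      JOIN campaigns ON items.item_id = campaigns.item_id
    """)

    result = (joined
      .repartition("campaign_id")
      .groupBy("item_id", "campaign_id")
      .agg(F.sum("clicks").alias("clicks"),
           F.sum("recs").alias("recs"),
           F.max("data_timestamp").alias("data_timestamp")))

    # Should show a single Exchange (REPARTITION_BY_COL on campaign_id), as above.
    result.explain("simple")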