
Distribute dataset evenly by range of id in PySpark


I'm very new to PySpark and have been having a challenge with partitioning data.

I have 2 datasets:

  • An ad dataset (very big) with ad_id and some attribute columns
  • An ad transactions dataset (smaller), which includes ad_id and a transaction date

It appears to me that I can only partition by ad_id. My question is: how can I evenly distribute the data by ranges of ad_id for both datasets, so that when I need to compute a join between the two, it will be faster?

Here is what I'm trying to do:

ads.write.partitionBy("ad_id").mode('overwrite').parquet(os.path.join(output_data, 'ads_table'))

Thanks!


Solution

  • Use Bucketing

    If you are using Spark v2.3 or greater, you can use bucketing to avoid the shuffle that would otherwise take place on the join after the write.

    With bucketing you put your data into buckets based on a column (usually the one you join on). When Spark later reads the data back from those buckets, it does not need to perform an exchange.

    1. Sample Data

    Transactions (Fact)

    t1.sample(n=5)
    
    ad_id     impressions
    30        528749
    1         552233
    30        24298
    30        311914
    60        41661
    
    

    Names (Dimension)

    t2.sample(n=5)
    
    ad_id     brand_name
    1         McDonalds
    30        McDonalds
    30        Coca-Cola
    1         Coca-Cola
    30        Levis
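
    The construction of t1 and t2 is not shown here. A minimal sketch, assuming they are pandas DataFrames (consistent with the .sample(n=5) calls above and the spark.createDataFrame(t1) call below):

    import pandas as pd

    # Hypothetical sample data matching the rows shown above; the real
    # datasets would of course be much larger.
    t1 = pd.DataFrame({
        'ad_id':       [30, 1, 30, 30, 60],
        'impressions': [528749, 552233, 24298, 311914, 41661],
    })

    t2 = pd.DataFrame({
        'ad_id':      [1, 30, 30, 1, 30],
        'brand_name': ['McDonalds', 'McDonalds', 'Coca-Cola', 'Coca-Cola', 'Levis'],
    })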
    

    2. Disable Broadcast Join

    Since one table is large and the other is small, Spark would otherwise broadcast the smaller table automatically, so you will need to disable the broadcast join to see the sort-merge join behavior.

    sqlContext.sql("SET spark.sql.autoBroadcastJoinThreshold = -1")
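
    If you prefer the session configuration API over raw SQL, the same setting can be applied through spark.conf (assuming spark is your SparkSession):

    # Equivalent to the SQL statement above: a threshold of -1 disables
    # automatic broadcast joins, forcing a sort-merge join instead.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)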
    

    3. Without Bucketing

    t = spark.createDataFrame(t1)
    b = spark.createDataFrame(t2)
    
    
    t.write.saveAsTable('unbucketed_transactions')
    b.write.saveAsTable('unbucketed_brands')
    
    unbucketed_transactions = sqlContext.table("unbucketed_transactions")
    unbucketed_brands = sqlContext.table("unbucketed_brands")
    
    
    unbucketed_transactions.join(unbucketed_brands, 'ad_id').explain()
    
    
    
    +- Project [ad_id#1842L, impressions#1843L, brand_name#1847]
       +- SortMergeJoin [ad_id#1842L], [ad_id#1846L], Inner
          :- Sort [ad_id#1842L ASC NULLS FIRST], false, 0
          :  +- Exchange hashpartitioning(ad_id#1842L, 200), true, [id=#1336]     <-- exchange here
          :     +- Project [ad_id#1842L, impressions#1843L]
          :        +- Filter isnotnull(ad_id#1842L)
          :           +- FileScan parquet default.unbucketed_transactions
          +- Sort [ad_id#1846L ASC NULLS FIRST], false, 0
             +- Exchange hashpartitioning(ad_id#1846L, 200), true, [id=#1337]     <-- exchange here
                +- Project [ad_id#1846L, brand_name#1847]
                   +- Filter isnotnull(ad_id#1846L)
                      +- FileScan parquet default.unbucketed_brands
    
    
    

    As you can see, an exchange takes place on each side of the join because the tables are not bucketed.

    4. With Bucketing

    # The number 30 tells Spark how many buckets to create.
    # The second parameter is the column the buckets should be based on.
    
    unbucketed_transactions.write \
        .bucketBy(30, 'ad_id') \
        .sortBy('ad_id') \
        .saveAsTable('bucketed_transactions')


    unbucketed_brands.write \
        .bucketBy(30, 'ad_id') \
        .sortBy('ad_id') \
        .saveAsTable('bucketed_brands')
    
    transactions = sqlContext.table("bucketed_transactions")
    brands = sqlContext.table("bucketed_brands")
    
    transactions.join(brands, 'ad_id').explain()
    
    
    +- Project [ad_id#1867L, impressions#1868L, brand_name#1872]
       +- SortMergeJoin [ad_id#1867L], [ad_id#1871L], Inner
          :- Sort [ad_id#1867L ASC NULLS FIRST], false, 0
          :  +- Project [ad_id#1867L, impressions#1868L]
          :     +- Filter isnotnull(ad_id#1867L)
          :        +- FileScan parquet default.bucketed_transactions
          +- Sort [ad_id#1871L ASC NULLS FIRST], false, 0
             +- Project [ad_id#1871L, brand_name#1872]
                +- Filter isnotnull(ad_id#1871L)
                   +- FileScan parquet default.bucketed_brands
    
    

    As the plan above shows, there are no longer any exchanges. By avoiding that shuffle, the join will perform better.
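
    Applied back to your original datasets, a rough sketch might look like the following. The variable names ads and ad_transactions and the bucket count of 30 are assumptions for illustration; tune the bucket count to your data volume.

    # Sketch only: write both datasets bucketed (and sorted) by ad_id so that
    # a later join on ad_id can skip the shuffle. Note that bucketBy requires
    # saveAsTable rather than a plain path-based parquet write.
    ads.write \
        .bucketBy(30, 'ad_id') \
        .sortBy('ad_id') \
        .mode('overwrite') \
        .saveAsTable('ads_bucketed')

    ad_transactions.write \
        .bucketBy(30, 'ad_id') \
        .sortBy('ad_id') \
        .mode('overwrite') \
        .saveAsTable('ad_transactions_bucketed')

    # Read them back as tables and join without an exchange.
    spark.table('ads_bucketed') \
        .join(spark.table('ad_transactions_bucketed'), 'ad_id') \
        .explain()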