Tags: python, pyspark

Efficient joins across several dataframes in PySpark


I have several dataframes that look like this:

name_df:
| id | name |
| -- | ---- |
| 1  | Mark |
| 2  | Lisa |
| 3  | Josh |

age_df:
| name  | age |
| --    | --- |
| Mark  | 20  |
| John  | 25  |
| Lisa  | 35  |


prescription_df:
| name  | prescription |
| --    | ------------ |
| Lisa  | True         |
| Mark  | False        |

What I am trying to determine is whether there is any name with an age above 30 and a prescription set to True, and then return true or false depending on whether those conditions are met. For the sample data above, Lisa is 35 and has a prescription, so the result should be true.

What I am currently doing is two separate join statements like this:

from pyspark.sql.functions import col

df_1 = name_df.alias('main').join(
    age_df.alias('a'),
    (col('main.name') == col('a.name')) & (col('a.age') > 30)
)
df_2 = df_1.alias('main').join(
    prescription_df.alias('b'),
    (col('main.name') == col('b.name')) & (col('b.prescription') == True)
)

And I return true or false depending on whether the result is empty:

return False if df_2.isEmpty() else True
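
The same check could also be written more directly as:

    return not df_2.isEmpty()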

I want to know if there is a 'better', or more performant, way to do this.


Solution

  • By design, PySpark only executes your code when you perform an operation that requires data to be collected. Before doing so, however, it analyzes all of your prior commands and creates an optimized execution plan before pulling any data. You can validate what's actually going to happen by using .explain():

    name_df.alias('main').join(
        age_df.alias('a'),
        (col('main.name') == col('a.name'))
    ).where(col('a.age') > 30).explain()
    

    and

    name_df.join(age_df.where(age_df.age > 30), on="name").explain()
    

    and

    name_df.join(age_df, on="name").where(age_df.age > 30).explain()
    

    all end up executing the exact same code at runtime.

    The explanation shows you the physical plan:

    == Physical Plan ==
    AdaptiveSparkPlan isFinalPlan=false
    +- SortMergeJoin [name#613], [name#643], Inner
       :- Sort [name#613 ASC NULLS FIRST], false, 0
       :  +- Exchange hashpartitioning(name#613, 200), ENSURE_REQUIREMENTS, [plan_id=2294]
       :     +- Filter isnotnull(name#613)
       :        +- Scan ExistingRDD[name#613,id#614L]
       +- Sort [name#643 ASC NULLS FIRST], false, 0
          +- Exchange hashpartitioning(name#643, 200), ENSURE_REQUIREMENTS, [plan_id=2295]
             +- Filter ((isnotnull(age#644L) AND isnotnull(name#643)) AND (age#644L > 30))
                +- Scan ExistingRDD[name#643,age#644L]
    
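    To check this yourself, here is a minimal, self-contained way to reproduce the comparison (assuming a local SparkSession; the sample data mirrors the tables above):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.getOrCreate()

    name_df = spark.createDataFrame(
        [(1, "Mark"), (2, "Lisa"), (3, "Josh")], ["id", "name"])
    age_df = spark.createDataFrame(
        [("Mark", 20), ("John", 25), ("Lisa", 35)], ["name", "age"])

    # All three variants above should print the same physical plan.
    name_df.alias('main').join(
        age_df.alias('a'), col('main.name') == col('a.name')
    ).where(col('a.age') > 30).explain()
    name_df.join(age_df.where(age_df.age > 30), on="name").explain()
    name_df.join(age_df, on="name").where(age_df.age > 30).explain()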

    Out of personal preference, I would likely write it like this, as it feels very intuitive to understand (note the leading not, so the expression is true when a match exists, matching your original return value):

    not prescription_df.where(prescription_df.prescription == True).join(
        age_df.where(age_df.age > 30), on="name"
    ).join(name_df, on="name").isEmpty()
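
    If this check runs often, it can be wrapped in a small helper. A minimal sketch, assuming the filtered age and prescription tables are small enough to broadcast (has_match is a hypothetical name, not part of any API):

    from pyspark.sql.functions import broadcast

    def has_match(name_df, age_df, prescription_df) -> bool:
        # Filter the small tables before joining; the broadcast hint
        # lets Spark ship the filtered age table to every executor
        # instead of shuffling both sides of the join.
        matches = (
            prescription_df.where(prescription_df.prescription == True)
            .join(broadcast(age_df.where(age_df.age > 30)), on="name")
            .join(name_df, on="name")
        )
        # isEmpty() only has to find a single row, so Spark can stop early.
        return not matches.isEmpty()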