Tags: python, dataframe, apache-spark, join, pyspark

pyspark - join with OR condition


I would like to join two pyspark dataframes if at least one of two conditions is satisfied.

Toy data:

df1 = spark.createDataFrame([
    (10, 1, 666),
    (20, 2, 777),
    (30, 1, 888),
    (40, 3, 999),
    (50, 1, 111),
    (60, 2, 222),
    (10, 4, 333),
    (50, None, 444),
    (10, 0, 555),
    (50, 0, 666)
    ],
    ['var1', 'var2', 'other_var'] 
)

df2 = spark.createDataFrame([
    (10, 1),
    (20, 2),
    (30, None),
    (30, 0)
    ],
    ['var1_', 'var2_'] 
)

I would like to maintain all those rows of df1 where var1 is present in the distinct values of df2.var1_ OR var2 is present in the distinct values of df2.var2_ (but not in the case where such value is 0).

So, the expected output would be

+----+----+---------+-----+-----+
|var1|var2|other_var|var1_|var2_|
+----+----+---------+-----+-----+
|  10|   1|      666|   10|    1|   # join on both var1 and var2
|  20|   2|      777|   20|    2|   # join on both var1 and var2
|  30|   1|      888|   10|    1|   # join on var2
|  50|   1|      111|   10|    1|   # join on var2
|  60|   2|      222|   20|    2|   # join on var2
|  10|   4|      333|   10|    1|   # join on var1
|  10|   0|      555|   10|    1|   # join on var1
+----+----+---------+-----+-----+

Among the other attempts, I tried

cond = [
    (df1.var1 == (df2.select('var1_').distinct()).var1_)
    | (df1.var2 == (df2.filter(F.col('var2_') != 0).select('var2_').distinct()).var2_)
]
df1\
    .join(df2, how='inner', on=cond)\
    .show()

+----+----+---------+-----+-----+
|var1|var2|other_var|var1_|var2_|
+----+----+---------+-----+-----+
|  10|   1|      666|   10|    1|
|  20|   2|      777|   20|    2|
|  30|   1|      888|   10|    1|
|  50|   1|      111|   10|    1|
|  30|   1|      888|   30| null|
|  30|   1|      888|   30|    0|
|  60|   2|      222|   20|    2|
|  10|   4|      333|   10|    1|
|  10|   0|      555|   10|    1|
|  10|   0|      555|   30|    0|
|  50|   0|      666|   30|    0|
+----+----+---------+-----+-----+

but I obtained more rows than expected, and rows where var2 == 0 were preserved even though I filtered them out.

What am I doing wrong?

Note: I'm not using the .isin method because my actual df2 has around 20k rows, and I've read here that this method can perform poorly with a large number of IDs.


Solution

  • Try the condition below:

    cond = (df2.var2_ != 0) & ((df1.var1 == df2.var1_) | (df1.var2 == df2.var2_))
    df1\
        .join(df2, how='inner', on=cond)\
        .show()
    
    +----+----+---------+-----+-----+
    |var1|var2|other_var|var1_|var2_|
    +----+----+---------+-----+-----+
    |  10|   1|      666|   10|    1|
    |  30|   1|      888|   10|    1|
    |  20|   2|      777|   20|    2|
    |  50|   1|      111|   10|    1|
    |  60|   2|      222|   20|    2|
    |  10|   4|      333|   10|    1|
    |  10|   0|      555|   10|    1|
    +----+----+---------+-----+-----+
    

    The join condition should reference the columns of the two dataframes being joined directly; column objects taken from derived dataframes such as df2.select('var1_').distinct() resolve back to df2 itself, so the intermediate filter and distinct are silently ignored. If you want to exclude var2_ = 0, put that check in the join condition, rather than in a filter on a derived dataframe.

    There is also no need for .distinct(): it does not change which rows satisfy the equality condition, and it only adds an unnecessary step.