Tags: dataframe, apache-spark, pyspark, apache-spark-sql

Natural join for dataframes


SQL has a natural join operation, but the DataFrame API appears to have no equivalent. What's the simplest way to implement a natural join function in PySpark?

https://spark.apache.org/docs/latest/sql-ref-syntax-qry-select-join.html

Given:

x = spark.createDataFrame([(1,'a'),(2,'b'),(3,'c')], ['c1','c2'])
z = spark.createDataFrame([(1,'aaaaa'),(2,'bbbbb')], ['c1','c3'])

When:

naturalJoin(x,z) # ???

Expected Result:

+---+---+-----+
| c1| c2|   c3|
+---+---+-----+
|  1|  a|aaaaa|
|  2|  b|bbbbb|
+---+---+-----+
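For reference, a natural join matches rows that agree on every column name the two tables share, and each shared column appears only once in the output. The following is a minimal pure-Python sketch of those semantics (not Spark code), assuming each input is a list of dicts whose rows all have the same keys; `natural_join_rows` is a hypothetical helper name used only for illustration.

```python
# Pure-Python illustration of natural-join semantics (not Spark code).
# Assumption: every row in an input has the same keys as its first row.
from typing import Dict, List

Row = Dict[str, object]


def natural_join_rows(left: List[Row], right: List[Row]) -> List[Row]:
    if not left or not right:
        return []
    # Shared column names, in the order they appear in the left input
    shared = [c for c in left[0] if c in right[0]]
    result = []
    for lrow in left:
        for rrow in right:
            # Rows match when every shared column is equal; like SQL,
            # a NULL (None) never matches anything, not even another NULL
            if all(lrow[c] is not None and lrow[c] == rrow[c] for c in shared):
                merged = dict(lrow)
                # Shared columns appear once; right-only columns are appended
                merged.update({k: v for k, v in rrow.items() if k not in shared})
                result.append(merged)
    return result


x_rows = [{"c1": 1, "c2": "a"}, {"c1": 2, "c2": "b"}, {"c1": 3, "c2": "c"}]
z_rows = [{"c1": 1, "c3": "aaaaa"}, {"c1": 2, "c3": "bbbbb"}]

for row in natural_join_rows(x_rows, z_rows):
    print(row)
# {'c1': 1, 'c2': 'a', 'c3': 'aaaaa'}
# {'c1': 2, 'c2': 'b', 'c3': 'bbbbb'}
```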

Solution

  • You can get the same output a natural join would produce with the DataFrame.join method:

    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.getOrCreate()
    
    x = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["c1", "c2"])
    z = spark.createDataFrame([(1, "aaaaa"), (2, "bbbbb")], ["c1", "c3"])
    

    If the second argument to df.join is a join expression, the output contains the join column twice, once from each DataFrame, which is not what you want:

    >>> x.join(z, x.c1 == z.c1).show()
    +---+---+---+-----+
    | c1| c2| c1|   c3|
    +---+---+---+-----+
    |  1|  a|  1|aaaaa|
    |  2|  b|  2|bbbbb|
    +---+---+---+-----+
    

    Since a natural join matches rows on the columns the two DataFrames share by name (here c1), you can pass that column name as the second argument to join instead of a join expression. Spark then keeps a single copy of the join column, which gives the output you want:

    >>> x.join(z, "c1").show()
    +---+---+-----+
    | c1| c2|   c3|
    +---+---+-----+
    |  1|  a|aaaaa|
    |  2|  b|bbbbb|
    +---+---+-----+
    

    If you don't want to specify the join columns and instead want the columns with matching names to be selected automatically, you'll have to define that function yourself, since PySpark doesn't provide one. You can confirm this in the source code (I checked version 3.5.1, the latest version as of this post):

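    from pyspark.sql import DataFrame

    def natural_join(df1: DataFrame, df2: DataFrame) -> DataFrame:
        # Join on every column name both DataFrames share, in df1's column order
        common_columns = [c for c in df1.columns if c in df2.columns]
        return df1.join(df2, common_columns)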
    
    
    >>> natural_join(x, z).show()
    +---+---+-----+
    | c1| c2|   c3|
    +---+---+-----+
    |  1|  a|aaaaa|
    |  2|  b|bbbbb|
    +---+---+-----+
    

    Using SQL

    You can, however, execute a natural join in Spark using SQL. Since you're asking about PySpark specifically, I'm only adding this as an appendix. The Spark source code contains some example queries:

    SELECT * FROM nt1 natural join nt2;
    SELECT * FROM nt1 natural join nt2 where k = "one";
    SELECT * FROM nt1 natural left join nt2 order by v1, v2;
    SELECT * FROM nt1 natural right join nt2 order by v1, v2;
    SELECT count(*) FROM nt1 natural full outer join nt2;
    SELECT k FROM nt1 natural join nt2;
    SELECT k FROM nt1 natural join nt2 where k = "one";
    SELECT nt1.* FROM nt1 natural join nt2;
    SELECT nt2.* FROM nt1 natural join nt2;