Tags: scala, apache-spark, apache-spark-dataset

Avoiding duplicate columns after nullSafeJoin in Scala Spark


I have a use case where I need to join on nullable columns. I am doing it like this:

  import org.apache.spark.sql.{Column, DataFrame}
  import org.apache.spark.sql.functions.col

  def nullSafeJoin(leftDF: DataFrame, rightDF: DataFrame, joinOnColumns: Seq[String]): DataFrame = {

    val dataset1 = leftDF.alias("dataset1")
    val dataset2 = rightDF.alias("dataset2")

    // Build a null-safe equality condition over all join columns
    val firstColumn = joinOnColumns.head
    val colExpression: Column = col(s"dataset1.$firstColumn").eqNullSafe(col(s"dataset2.$firstColumn"))

    val fullExpr = joinOnColumns.tail.foldLeft(colExpression) {
      (expr, p) => expr && col(s"dataset1.$p").eqNullSafe(col(s"dataset2.$p"))
    }
    dataset1.join(dataset2, fullExpr)
  }

The final joined dataset has duplicate columns. I have tried dropping the columns using the alias, like this:

dataset1.join(dataset2, fullExpr).drop(s"dataset2.$firstColumn")

but it doesn't work.
I understand that instead of dropping we could select the required columns.
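For reference, `drop(colName: String)` matches only a plain, unqualified column name, which is why the qualified string `"dataset2.$firstColumn"` silently matches nothing. Spark 2.x+ also has a `drop(col: Column)` overload that resolves through the aliased DataFrame. A sketch of that approach inside the function above (untested here, reusing the `dataset1`/`dataset2`/`fullExpr` names from the question):

```scala
// Sketch, assuming Spark 2.x+ where drop(col: Column) is available.
// Dropping by Column reference resolves against the aliased right-hand
// DataFrame, unlike drop(String) which only matches a literal name:
val joined = dataset1.join(dataset2, fullExpr)
val deduped = joinOnColumns.foldLeft(joined) { (df, c) =>
  df.drop(dataset2.col(c)) // drops the right-hand copy of each join column
}
```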

I am trying to keep the code base generic, so I don't want to pass the list of columns to be selected into the function. (In the case of drop, I would only have to drop the joinOnColumns list already passed to the function.)

Any pointers on how to solve this would be really helpful. Thanks!

Edit: (sample data)

leftDF :
+------------------+-----------+---------+---------+-------+
|                 A|          B|        C|        D| status|
+------------------+-----------+---------+---------+-------+
|             14567|         37|        1|     game|Enabled|
|             14567|       BASE|        1|      toy| Paused|
|             13478|       null|        5|     game|Enabled|
|              2001|       BASE|        1|     null| Paused|
|              null|         37|        1|     home|Enabled|
+------------------+-----------+---------+---------+-------+

rightDF :
+------------------+-----------+---------+
|                 A|          B|        C|
+------------------+-----------+---------+
|               140|         37|        1|
|               569|       BASE|        1|
|             13478|       null|        5|
|              2001|       BASE|        1|
|              null|         37|        1|
+------------------+-----------+---------+

Final Join (Required):
+------------------+-----------+---------+---------+-------+
|                 A|          B|        C|        D| status|
+------------------+-----------+---------+---------+-------+
|             13478|       null|        5|     game|Enabled|
|              2001|       BASE|        1|     null| Paused|
|              null|         37|        1|     home|Enabled|
+------------------+-----------+---------+---------+-------+

Solution

  • Your final DataFrame has duplicate columns from both leftDF & rightDF, with no identifier to check whether a column came from leftDF or from rightDF.

    So I have renamed the leftDF & rightDF columns: leftDF columns start with left_[column_name] & rightDF columns start with right_[column_name].

    Hope the code below helps you.

    scala> :paste
    // Entering paste mode (ctrl-D to finish)
    
      val left = Seq(("14567", "37", "1", "game", "Enabled"), ("14567", "BASE", "1", "toy", "Paused"), ("13478", "null", "5", "game", "Enabled"), ("2001", "BASE", "1", "null", "Paused"), ("null", "37", "1", "home", "Enabled")).toDF("a", "b", "c", "d", "status")
      val right = Seq(("140", "37", 1), ("569", "BASE", 1), ("13478", "null", 5), ("2001", "BASE", 1), ("null", "37", 1)).toDF("a", "b", "c")
    
      import org.apache.spark.sql.DataFrame
  def nullSafeJoin(leftDF: DataFrame, rightDF: DataFrame, joinOnColumns: Seq[String]): DataFrame = {
        val leftRenamedDF = leftDF
          .columns
          .map(c => (c, s"left_${c}"))
          .foldLeft(leftDF){ (df, c) =>
            df.withColumnRenamed(c._1, c._2)
          }
        val rightRenamedDF = rightDF
          .columns
          .map(c => (c, s"right_${c}"))
          .foldLeft(rightDF){(df, c) =>
            df.withColumnRenamed(c._1, c._2)
          }
    
    val fullExpr = joinOnColumns
      .tail
      .foldLeft($"left_${joinOnColumns.head}".eqNullSafe($"right_${joinOnColumns.head}")) { (expr, p) =>
        expr && ($"left_${p}".eqNullSafe($"right_${p}"))
      }
    
        val finalColumns = joinOnColumns
          .map(c => col(s"left_${c}").as(c)) ++ // Taking All columns from Join columns
          leftDF.columns.diff(joinOnColumns).map(c => col(s"left_${c}").as(c)) ++ // Taking missing columns from leftDF
          rightDF.columns.diff(joinOnColumns).map(c => col(s"right_${c}").as(c)) // Taking missing columns from rightDF
    
        leftRenamedDF.join(rightRenamedDF, fullExpr).select(finalColumns: _*)
      }
    
    // Exiting paste mode, now interpreting.

    Final DataFrame result is:

    scala> nullSafeJoin(left, right, Seq("a", "b", "c")).show(false)

    +-----+----+---+----+-------+
    |a    |b   |c  |d   |status |
    +-----+----+---+----+-------+
    |13478|null|5  |game|Enabled|
    |2001 |BASE|1  |null|Paused |
    |null |37  |1  |home|Enabled|
    +-----+----+---+----+-------+
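As a compact variant of the same idea (a sketch, not a tested drop-in; `nullSafeJoinDedup` is a hypothetical name, and it assumes the `Column` null-safe equality operator `<=>`), the aliases from the question can be kept and the deduplication done in a single select, so no columns need renaming:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

def nullSafeJoinDedup(leftDF: DataFrame, rightDF: DataFrame, joinOnColumns: Seq[String]): DataFrame = {
  val l = leftDF.alias("l")
  val r = rightDF.alias("r")

  // Null-safe equality over every join column, combined with AND
  val cond = joinOnColumns
    .map(c => col(s"l.$c") <=> col(s"r.$c"))
    .reduce(_ && _)

  // Keep all left-side columns, plus any right-side columns not used in the join
  val outCols =
    leftDF.columns.map(c => col(s"l.$c").as(c)) ++
    rightDF.columns.diff(joinOnColumns).map(c => col(s"r.$c").as(c))

  l.join(r, cond).select(outCols: _*)
}
```

This keeps the asker's alias-based style while still returning a single, deduplicated set of columns.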