Tags: scala, apache-spark, apache-spark-sql, inner-join

Scala Spark: Join DataFrames in a Loop


I am trying to join DataFrames on the fly in a loop. I am using a properties file to get the column details to use in the final DataFrame. Properties file -

a01=status:single,perm_id:multi
a02=status:single,actv_id:multi
a03=status:single,perm_id:multi,actv_id:multi
............................
............................

For each row in the properties file, I need to create a DataFrame and save it to a file. I am loading the properties file using PropertiesReader. If the mode is single, I need to get only the column value from the table; if it is multi, I need to get the list of values.

val propertyColumn = properties.get("a01") //the key (a01, a02, ..., a0n) is passed in as an argument
val columns = propertyColumn.toString.split(",")
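For reference, the loading step looks roughly like this (a sketch using plain java.util.Properties; my actual PropertiesReader is not shown here, and the path is a placeholder) -

import java.io.FileInputStream
import java.util.Properties

//load the key -> "col:mode,col:mode,..." entries; the path is hypothetical
val properties = new Properties()
properties.load(new FileInputStream("/path/to/act.properties"))

//parse one entry, e.g. a01, into (columnName, mode) pairs
val columnModes: Array[(String, String)] =
  properties.getProperty("a01").split(",").map { entry =>
    val Array(name, mode) = entry.split(":")
    (name, mode)
  }
//columnModes: Array((status,single), (perm_id,multi))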

act_det table -

+-------+--------+-----------+-----------+-----------+------------+
|id     |act_id  |status     |perm_id    |actv_id    |debt_id     |
+-------+--------+-----------+-----------+-----------+------------+
| 1     |1       |   4       | 1         | 10        | 1          |
| 2     |1       |   4       | 2         | 20        | 2          |
| 3     |1       |   4       | 3         | 30        | 1          |
| 4     |2       |   4       | 5         | 10        | 3          |
| 5     |2       |   4       | 6         | 20        | 1          |
| 6     |2       |   4       | 7         | 30        | 1          |
| 7     |3       |   4       | 1         | 10        | 3          |
| 8     |3       |   4       | 5         | 20        | 1          |
| 9     |3       |   4       | 2         | 30        | 3          |
+-------+--------+-----------+-----------+-----------+------------+

Main DataFrame -

val data = sqlContext.sql("select * from act_det")

I want the following output -

For a01 -

    +-------+--------+-----------+
    |act_id |status  |perm_id    |
    +-------+--------+-----------+
    |     1 |   4    | [1,2,3]   |
    |     2 |   4    | [5,6,7]   |
    |     3 |   4    | [1,5,2]   |
    +-------+--------+-----------+

For a02 -

    +-------+--------+-----------+
    |act_id |status  |actv_id    |
    +-------+--------+-----------+
    |     1 |   4    | [10,20,30]|
    |     2 |   4    | [10,20,30]|
    |     3 |   4    | [10,20,30]|
    +-------+--------+-----------+

For a03 -

    +-------+--------+-----------+-----------+
    |act_id |status  |perm_id    |actv_id    |
    +-------+--------+-----------+-----------+
    |     1 |   4    | [1,2,3]   |[10,20,30] |
    |     2 |   4    | [5,6,7]   |[10,20,30] |
    |     3 |   4    | [1,5,2]   |[10,20,30] |
    +-------+--------+-----------+-----------+

But the DataFrame creation process should be dynamic.

I have tried the code below, but I am not able to implement the join logic for the DataFrames in the loop.

val finalDF: DataFrame = ??? //empty DataFrame to hold the joined result
for {
    column <- columns
} yield {
    val eachColumn = column.toString.split(":")
    val columnName = eachColumn(0)
    val mode = eachColumn(1)
    if (mode.equalsIgnoreCase("single")) {
        data.select($"act_id", $"status").distinct
        //I want to join finalDF with data.select($"act_id", $"status").distinct
    } else if (mode.equalsIgnoreCase("multi")) {
        data.groupBy($"act_id").agg(collect_list($"perm_id").as("perm_id"))
        //I want to join finalDF with data.groupBy($"act_id").agg(collect_list($"perm_id").as("perm_id"))
    }
}

Any advice or guidance would be greatly appreciated.


Solution

  • Check the code below.
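
    To make the session reproducible, the sample data can be recreated as follows (a sketch, assuming a spark-shell session where spark.implicits._ is pre-imported; the original df presumably comes from the act_det table).

    scala> import org.apache.spark.sql.functions._

    scala> val df = Seq(
         |     (1, 1, 4, 1, 10, 1), (2, 1, 4, 2, 20, 2), (3, 1, 4, 3, 30, 1),
         |     (4, 2, 4, 5, 10, 3), (5, 2, 4, 6, 20, 1), (6, 2, 4, 7, 30, 1),
         |     (7, 3, 4, 1, 10, 3), (8, 3, 4, 5, 20, 1), (9, 3, 4, 2, 30, 3)
         | ).toDF("id", "act_id", "status", "perm_id", "actv_id", "debt_id")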

    scala> df.show(false)
    +---+------+------+-------+-------+-------+
    |id |act_id|status|perm_id|actv_id|debt_id|
    +---+------+------+-------+-------+-------+
    |1  |1     |4     |1      |10     |1      |
    |2  |1     |4     |2      |20     |2      |
    |3  |1     |4     |3      |30     |1      |
    |4  |2     |4     |5      |10     |3      |
    |5  |2     |4     |6      |20     |1      |
    |6  |2     |4     |7      |30     |1      |
    |7  |3     |4     |1      |10     |3      |
    |8  |3     |4     |5      |20     |1      |
    |9  |3     |4     |2      |30     |3      |
    +---+------+------+-------+-------+-------+
    

    Defining primary keys

    scala> val primary_key = Seq("act_id").map(col(_))
    primary_key: Seq[org.apache.spark.sql.Column] = List(act_id)
    

    Configs (built from the properties file entries; the question loads these via PropertiesReader)

    scala> val configs = Map(
         |     "a01" -> "status:single,perm_id:multi",
         |     "a02" -> "status:single,actv_id:multi",
         |     "a03" -> "status:single,perm_id:multi,actv_id:multi"
         | )

    scala> configs.foreach(println)
    /*
    (a01,status:single,perm_id:multi)
    (a02,status:single,actv_id:multi)
    (a03,status:single,perm_id:multi,actv_id:multi)
    */


    Constructing the aggregate expressions: a single column is reduced with first (its value is constant within each act_id group), while a multi column is gathered with collect_list.

    scala> val columns = configs.map { case (_, spec) =>
         |     spec.split(",").map { entry =>
         |         val Array(name, mode) = entry.split(":") //e.g. "perm_id:multi"
         |         if (mode.equalsIgnoreCase("single"))
         |             first(col(name)).as(name)            //constant within a group
         |         else
         |             collect_list(col(name)).as(name)     //gather all group values
         |     }
         | }
    
    /*
    columns: scala.collection.immutable.Iterable[Array[org.apache.spark.sql.Column]] = List(
        Array(first(status, false) AS `status`, collect_list(perm_id) AS `perm_id`), 
        Array(first(status, false) AS `status`, collect_list(actv_id) AS `actv_id`), 
        Array(first(status, false) AS `status`, collect_list(perm_id) AS `perm_id`, collect_list(actv_id) AS `actv_id`)
    )
    */
    
    

    Final result: one groupBy per config computes all of its aggregates at once, so no join loop is needed.

    scala> columns.map(c => df.groupBy(primary_key:_*).agg(c.head,c.tail:_*)).map(_.show(false))
    +------+------+---------+
    |act_id|status|perm_id  |
    +------+------+---------+
    |3     |4     |[1, 5, 2]|
    |1     |4     |[1, 2, 3]|
    |2     |4     |[5, 6, 7]|
    +------+------+---------+
    
    +------+------+------------+
    |act_id|status|actv_id     |
    +------+------+------------+
    |3     |4     |[10, 20, 30]|
    |1     |4     |[10, 20, 30]|
    |2     |4     |[10, 20, 30]|
    +------+------+------------+
    
    +------+------+---------+------------+
    |act_id|status|perm_id  |actv_id     |
    +------+------+---------+------------+
    |3     |4     |[1, 5, 2]|[10, 20, 30]|
    |1     |4     |[1, 2, 3]|[10, 20, 30]|
    |2     |4     |[5, 6, 7]|[10, 20, 30]|
    +------+------+---------+------------+
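
    To also save each result to a file keyed by its property name, as the question requires, the aggregate arrays can be zipped with the config keys (a sketch; the Parquet format and output path are assumptions).

    scala> configs.keys.zip(columns).foreach { case (key, aggs) =>
         |     df.groupBy(primary_key: _*)
         |       .agg(aggs.head, aggs.tail: _*)
         |       .write.mode("overwrite").parquet(s"/tmp/act_det_out/$key") //hypothetical path
         | }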