Tags: scala, apache-spark, spark-streaming, etl

How to iterate grouped rows to produce multiple rows in spark structured streaming?


I have an input data set like this:

id     operation          value
1      null                1
1      discard             0
2      null                1
2      null                2
2      max                 0
3      null                1
3      null                1
3      list                0

I want to group the input by id and produce output rows according to the "operation" column.

For group 1, operation="discard", so no rows are output.

For group 2, operation="max", so the output is:

2      null                2

For group 3, operation="list", so the output is:

3      null                1
3      null                1

So the final output is:

id     operation          value
2      null                2
3      null                1
3      null                1

Is there a solution for this?

I know there is a similar question, how-to-iterate-grouped-data-in-spark, but the differences compared to it are:

    1. I want to produce more than one row for each group. Is that possible, and how? (See the rough sketch after this list for the kind of thing I mean.)
    2. I want the logic to be easy to extend when more operations are added in the future. Are user-defined aggregate functions (UDAFs) the only possible solution?
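
For reference, the kind of per-group iteration I have in mind looks roughly like the batch-mode sketch below (untested; the case class name and the per-operation handling are placeholders, and I am not sure how well this carries over to a streaming query, which might need flatMapGroupsWithState instead):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder.getOrCreate()
import spark.implicits._

case class Rec(id: Int, operation: String, value: Int)

val df = Seq((1,null,1),(1,"discard",0),(2,null,1),(2,null,2),(2,"max",0),
             (3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")

val out = df.as[Rec]
  .groupByKey(_.id)
  .flatMapGroups { (id, rows) =>
    val group = rows.toSeq                              // materialize one group's rows
    val data  = group.filter(_.operation == null)       // the rows carrying values
    group.flatMap(r => Option(r.operation)).headOption match {
      case Some("discard") => Nil
      case Some("max")     => Seq(Rec(id, null, data.map(_.value).max))  // assumes at least one data row
      case Some("list")    => data.map(r => Rec(id, null, r.value))
      case _               => Nil                       // future operations would plug in here
    }
  }

out.show(false)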

Update 1:

Thanks, stack0114106. To add more detail in response to his answer: for a group whose operation is "max" (e.g. id=1 or id=2 in the updated example), I want to iterate over all of that group's rows and find the maximum value, rather than assign a hard-coded value. That is why I want to iterate the rows in each group. Below is an updated example.

The input:

scala> val df = Seq((0,null,1),(0,"discard",0),(1,null,1),(1,null,2),(1,"max",0),(2,null,1),(2,null,3),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]

scala> df.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|0  |null     |1    |
|0  |discard  |0    |
|1  |null     |1    |
|1  |null     |2    |
|1  |max      |0    |
|2  |null     |1    |
|2  |null     |3    |
|2  |max      |0    |
|3  |null     |1    |
|3  |null     |1    |
|3  |list     |0    |
+---+---------+-----+

The expected output:

+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1  |null     |2    |
|2  |null     |3    |
|3  |null     |1    |
|3  |null     |1    |
+---+---------+-----+

Solution

  • You can use the flatMap operation on the dataframe and generate the required rows based on the conditions you mentioned. Check this out:

    scala> val df = Seq((1,null,1),(1,"discard",0),(2,null,1),(2,null,2),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
    df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]
    
    scala> df.show(false)
    +---+---------+-----+
    |id |operation|value|
    +---+---------+-----+
    |1  |null     |1    |
    |1  |discard  |0    |
    |2  |null     |1    |
    |2  |null     |2    |
    |2  |max      |0    |
    |3  |null     |1    |
    |3  |null     |1    |
    |3  |list     |0    |
    +---+---------+-----+
    
    
    scala> df.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => (0,0) case "max" => (1,2) case "list" => (2,1) } ; (0
     until s._1).map( i => (r.getInt(0),null,s._2) ) }).show(false)
    +---+----+---+
    |_1 |_2  |_3 |
    +---+----+---+
    |2  |null|2  |
    |3  |null|1  |
    |3  |null|1  |
    +---+----+---+
    

    Spark assigns _1, _2, etc. as the column names, so you can map them to the actual names with toDF as below:

    scala> val df2 = df.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => (0,0) case "max" => (1,2) case "list" => (2,1) } ; (0 until s._1).map( i => (r.getInt(0),null,s._2) ) }).toDF("id","operation","value")
    df2: org.apache.spark.sql.DataFrame = [id: int, operation: null ... 1 more field]
    
    scala> df2.show(false)
    +---+---------+-----+
    |id |operation|value|
    +---+---------+-----+
    |2  |null     |2    |
    |3  |null     |1    |
    |3  |null     |1    |
    +---+---------+-----+
    
    
    scala>
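
    If you want to avoid the NullType column that the tuple encoder produces (note "operation: null" in the schema of df2 above), one possible variant is to map to a small case class with an Option[String] field instead of a raw null. A sketch (the case class Out and the val name df2b are just illustrative; spark.implicits._ is already in scope in spark-shell):

    case class Out(id: Int, operation: Option[String], value: Int)

    val df2b = df.filter("operation is not null").flatMap { r =>
      val s = r.getString(1) match {
        case "discard" => (0, 0)
        case "max"     => (1, 2)
        case "list"    => (2, 1)
      }
      (0 until s._1).map(_ => Out(r.getInt(0), None, s._2))   // None ends up as a SQL NULL
    }.toDF()                                                  // column names come from the case class fields

    This yields the same rows, but with operation typed as a nullable string rather than NullType, which can matter if you later union this output with dataframes that have a real string operation column.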
    

    EDIT1:

    Since you need the max(value) for each id, you can use window functions to get the max value in a new column, then apply the same flatMap technique to get the results. Check this out:

    scala> val df =   Seq((0,null,1),(0,"discard",0),(1,null,1),(1,null,2),(1,"max",0),(2,null,1),(2,null,3),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
    df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]
    
    scala> df.createOrReplaceTempView("michael")
    
    scala> val df2 = spark.sql(""" select *, max(value) over(partition by id) mx from michael """)
    df2: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 2 more fields]
    
    scala> df2.show(false)
    +---+---------+-----+---+
    |id |operation|value|mx |
    +---+---------+-----+---+
    |1  |null     |1    |2  |
    |1  |null     |2    |2  |
    |1  |max      |0    |2  |
    |3  |null     |1    |1  |
    |3  |null     |1    |1  |
    |3  |list     |0    |1  |
    |2  |null     |1    |3  |
    |2  |null     |3    |3  |
    |2  |max      |0    |3  |
    |0  |null     |1    |1  |
    |0  |discard  |0    |1  |
    +---+---------+-----+---+
    
    
    scala> val df3 = df2.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => 0 case "max" => 1 case "list" => 2 } ; (0 until s).map( i => (r.getInt(0),null,r.getInt(3) )) }).toDF("id","operation","value")
    df3: org.apache.spark.sql.DataFrame = [id: int, operation: null ... 1 more field]
    
    
    scala> df3.show(false)
    +---+---------+-----+
    |id |operation|value|
    +---+---------+-----+
    |1  |null     |2    |
    |3  |null     |1    |
    |3  |null     |1    |
    |2  |null     |3    |
    +---+---------+-----+
    
    
    scala>
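
    For completeness, the window step can also be written with the DataFrame API instead of registering a temp view and using SQL; a sketch that should be equivalent to the query above:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions.{col, max}

    val df2 = df.withColumn("mx", max(col("value")).over(Window.partitionBy("id")))
    // then apply the same filter + flatMap as for df3 above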