I have an input data set like this:
id operation value
1 null 1
1 discard 0
2 null 1
2 null 2
2 max 0
3 null 1
3 null 1
3 list 0
I want to group the input by id and produce rows according to the "operation" column:
for group 1, operation="discard", so there is no output;
for group 2, operation="max", so the output is:
2 null 2
for group 3, operation="list", so the output is:
3 null 1
3 null 1
So the final output should be:
id operation value
2 null 2
3 null 1
3 null 1
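To pin down the semantics before translating them to Spark, the rules above can be written as a plain Scala function over in-memory rows (no Spark involved). This is a sketch under assumptions: each id has exactly one row with a non-null operation, the rows with a null operation carry the data, and "max" keeps the data row with the largest value. Rec, GroupOps, applyOp and run are hypothetical names.

```scala
// Hypothetical in-memory row type mirroring the (id, operation, value) columns;
// a null operation is modeled as None
final case class Rec(id: Int, operation: Option[String], value: Int)

object GroupOps {
  // Apply a group's single operation to its data rows (the rows with no operation)
  def applyOp(rows: Seq[Rec]): Seq[Rec] = {
    val data = rows.filter(_.operation.isEmpty)
    rows.flatMap(_.operation).headOption match {
      case Some("discard")              => Seq.empty
      case Some("max") if data.nonEmpty => Seq(data.maxBy(_.value))
      case Some("list")                 => data
      case _                            => Seq.empty // assumption: unknown or missing op drops the group
    }
  }

  // Group by id, apply each group's operation, and emit the results in id order
  def run(input: Seq[Rec]): Seq[Rec] =
    input.groupBy(_.id).toSeq.sortBy(_._1).flatMap { case (_, rows) => applyOp(rows) }
}
```

Because "max" scans the group's data rows instead of using a fixed value, the same function also covers the updated example in Update 1: it returns (1, None, 2), (2, None, 3), (3, None, 1), (3, None, 1) for that input.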
Is there a solution for this?
I know there is a similar question, how-to-iterate-grouped-data-in-spark, but the differences compared to that one are:
Update 1:
Thanks to stack0114106. Some more details based on his answer: for an id whose operation is "max", I want to iterate over all the items with that id and find the max value, rather than assign a hard-coded value. That's why I want to iterate over the rows in each group. Below is an updated example:
The input:
scala> val df = Seq((0,null,1),(0,"discard",0),(1,null,1),(1,null,2),(1,"max",0),(2,null,1),(2,null,3),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]
scala> df.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|0 |null |1 |
|0 |discard |0 |
|1 |null |1 |
|1 |null |2 |
|1 |max |0 |
|2 |null |1 |
|2 |null |3 |
|2 |max |0 |
|3 |null |1 |
|3 |null |1 |
|3 |list |0 |
+---+---------+-----+
The expected output:
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1 |null |2 |
|2 |null |3 |
|3 |null |1 |
|3 |null |1 |
+---+---------+-----+
You can use the flatMap operation on the DataFrame and generate the required rows based on the conditions you mentioned. Check this out:
scala> val df = Seq((1,null,1),(1,"discard",0),(2,null,1),(2,null,2),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]
scala> df.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1 |null |1 |
|1 |discard |0 |
|2 |null |1 |
|2 |null |2 |
|2 |max |0 |
|3 |null |1 |
|3 |null |1 |
|3 |list |0 |
+---+---------+-----+
scala> df.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => (0,0) case "max" => (1,2) case "list" => (2,1) } ; (0 until s._1).map( i => (r.getInt(0),null,s._2) ) }).show(false)
+---+----+---+
|_1 |_2 |_3 |
+---+----+---+
|2 |null|2 |
|3 |null|1 |
|3 |null|1 |
+---+----+---+
Spark names the columns _1, _2, etc., so you can rename them to the actual names with toDF, as below:
scala> val df2 = df.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => (0,0) case "max" => (1,2) case "list" => (2,1) } ; (0 until s._1).map( i => (r.getInt(0),null,s._2) ) }).toDF("id","operation","value")
df2: org.apache.spark.sql.DataFrame = [id: int, operation: null ... 1 more field]
scala> df2.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|2 |null |2 |
|3 |null |1 |
|3 |null |1 |
+---+---------+-----+
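Emitting a literal null in the tuple is what gives df2 the "operation: null" column type seen in its schema. If a proper nullable string column is preferred, one option (a sketch, not part of the original answer) is to map into a case class with an Option[String] field, which yields named columns and a string type in one step. OutRow, NamedFlatMapSketch, sampleInput and expand are hypothetical names; the hard-coded (count, value) pairs are copied from the answer unchanged.

```scala
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

// Hypothetical output type; Option[String] yields a nullable string column
case class OutRow(id: Int, operation: Option[String], value: Int)

object NamedFlatMapSketch {
  // Same input as the answer, with a null operation on the data rows
  def sampleInput(spark: SparkSession): DataFrame = {
    import spark.implicits._
    Seq[(Int, String, Int)](
      (1, null, 1), (1, "discard", 0),
      (2, null, 1), (2, null, 2), (2, "max", 0),
      (3, null, 1), (3, null, 1), (3, "list", 0)
    ).toDF("id", "operation", "value")
  }

  // Same hard-coded (row count, value) pairs as the answer; only the output type differs.
  // Like the original, an unknown operation string throws a MatchError.
  def expand(spark: SparkSession, df: DataFrame): Dataset[OutRow] = {
    import spark.implicits._
    df.filter("operation is not null").flatMap { r =>
      val (n, v) = r.getString(1) match {
        case "discard" => (0, 0)
        case "max"     => (1, 2)
        case "list"    => (2, 1)
      }
      Seq.fill(n)(OutRow(r.getInt(0), None, v))
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("named-flatMap").getOrCreate()
    val out = expand(spark, sampleInput(spark))
    out.printSchema() // operation is a nullable string, not a NullType column
    out.show(false)
    spark.stop()
  }
}
```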
scala>
EDIT1:
Since you need max(value) for each id, you can use a window function to put the max value in a new column, then use the same flatMap technique to get the results. Check this out:
scala> val df = Seq((0,null,1),(0,"discard",0),(1,null,1),(1,null,2),(1,"max",0),(2,null,1),(2,null,3),(2,"max",0),(3,null,1),(3,null,1),(3,"list",0)).toDF("id","operation","value")
df: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 1 more field]
scala> df.createOrReplaceTempView("michael")
scala> val df2 = spark.sql(""" select *, max(value) over(partition by id) mx from michael """)
df2: org.apache.spark.sql.DataFrame = [id: int, operation: string ... 2 more fields]
scala> df2.show(false)
+---+---------+-----+---+
|id |operation|value|mx |
+---+---------+-----+---+
|1 |null |1 |2 |
|1 |null |2 |2 |
|1 |max |0 |2 |
|3 |null |1 |1 |
|3 |null |1 |1 |
|3 |list |0 |1 |
|2 |null |1 |3 |
|2 |null |3 |3 |
|2 |max |0 |3 |
|0 |null |1 |1 |
|0 |discard |0 |1 |
+---+---------+-----+---+
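The same mx column can also be built with the DataFrame API instead of a temp view and spark.sql. The sketch below additionally assumes that only rows with a null operation should count toward the max: when(...) masks the operation row to null, and max ignores nulls. In the SQL above, the operation row's value of 0 takes part in the max, which only changes the result if every data value in a group is negative. WindowMaxSketch, sampleInput and withGroupMax are hypothetical names.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, max, when}

object WindowMaxSketch {
  // The updated example from the question
  def sampleInput(spark: SparkSession): DataFrame = {
    import spark.implicits._
    Seq[(Int, String, Int)](
      (0, null, 1), (0, "discard", 0),
      (1, null, 1), (1, null, 2), (1, "max", 0),
      (2, null, 1), (2, null, 3), (2, "max", 0),
      (3, null, 1), (3, null, 1), (3, "list", 0)
    ).toDF("id", "operation", "value")
  }

  // Per-id max over the data rows only: when(...) is null for the operation row,
  // so its placeholder value never wins
  def withGroupMax(df: DataFrame): DataFrame =
    df.withColumn(
      "mx",
      max(when(col("operation").isNull, col("value"))).over(Window.partitionBy("id"))
    )

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("window-max").getOrCreate()
    withGroupMax(sampleInput(spark)).show(false)
    spark.stop()
  }
}
```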
scala> val df3 = df2.filter("operation is not null").flatMap( r=> { val x=r.getString(1); val s = x match { case "discard" => 0 case "max" => 1 case "list" => 2 } ; (0 until s).map( i => (r.getInt(0),null,r.getInt(3) )) }).toDF("id","operation","value")
df3: org.apache.spark.sql.DataFrame = [id: int, operation: null ... 1 more field]
scala> df3.show(false)
+---+---------+-----+
|id |operation|value|
+---+---------+-----+
|1 |null |2 |
|3 |null |1 |
|3 |null |1 |
|2 |null |3 |
+---+---------+-----+
scala>
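Note that the flatMap above still hard-codes how many rows each operation produces (2 for "list"), and every emitted row carries mx rather than the group's original values; for id=3 this matches only because both values are 1. If the goal is to iterate over each group's rows, as the question asks, the typed Dataset API offers groupByKey followed by flatMapGroups, which passes each key together with an iterator over that key's rows. A minimal sketch, assuming the discard/max/list rules from the question, exactly one non-null operation per id, and groups small enough to buffer in memory. Rec, FlatMapGroupsSketch, sampleInput and applyOps are hypothetical names.

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

// Hypothetical typed row; .as[Rec] turns a null operation into None
case class Rec(id: Int, operation: Option[String], value: Int)

object FlatMapGroupsSketch {
  // The updated example from the question
  def sampleInput(spark: SparkSession): Dataset[Rec] = {
    import spark.implicits._
    Seq[(Int, String, Int)](
      (0, null, 1), (0, "discard", 0),
      (1, null, 1), (1, null, 2), (1, "max", 0),
      (2, null, 1), (2, null, 3), (2, "max", 0),
      (3, null, 1), (3, null, 1), (3, "list", 0)
    ).toDF("id", "operation", "value").as[Rec]
  }

  // Iterate over each id's rows and apply that group's operation
  def applyOps(spark: SparkSession, ds: Dataset[Rec]): Dataset[Rec] = {
    import spark.implicits._
    ds.groupByKey(_.id).flatMapGroups { (_, rows) =>
      // The iterator is single-pass, so buffer the group (assumes it fits in memory)
      val all = rows.toVector
      val data = all.filter(_.operation.isEmpty)
      all.flatMap(_.operation).headOption match {
        case Some("discard")              => Vector.empty[Rec]
        case Some("max") if data.nonEmpty => Vector(data.maxBy(_.value))
        case Some("list")                 => data
        case _                            => Vector.empty[Rec] // assumption: drop unknown ops
      }
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("flatMapGroups").getOrCreate()
    applyOps(spark, sampleInput(spark)).orderBy("id").show(false)
    spark.stop()
  }
}
```

groupByKey sends all rows of a key to one task and the grouping function works on plain Scala objects, which is flexible but opaque to Catalyst's optimizer, so the window approach may be preferable when the per-group logic can be expressed with built-in functions.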