I have a Spark DataFrame:
id value
3 0
3 1
3 0
4 1
4 0
4 0
I want to create a new DataFrame:
id value
3 0
3 1
4 1
I need to remove, for each id, all the rows that come after the first row where value is 1. I tried window functions on the Spark DataFrame (Scala) but couldn't find a solution; it seems I am going in the wrong direction.
I am looking for a solution in Scala.
Output using monotonically_increasing_id (with dataWithIndex defined as in the answer below):
scala> val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
data: org.apache.spark.sql.DataFrame = [id: int, value: int]
scala> val minIdx = dataWithIndex.filter($"value" === 1).groupBy($"id").agg(min($"idx")).toDF("r_id", "min_idx")
minIdx: org.apache.spark.sql.DataFrame = [r_id: int, min_idx: bigint]
scala> dataWithIndex.join(minIdx,($"r_id" === $"id") && ($"idx" <= $"min_idx")).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
| 3| 0|
| 3| 1|
| 4| 1|
+---+-----+
This solution won't work if a sort was applied to the original DataFrame, because monotonically_increasing_id() is generated based on the original DF rather than the sorted DF. I had missed that requirement earlier.
All suggestions are welcome.
One way is to use monotonically_increasing_id() and a self-join:
import org.apache.spark.sql.functions._   // monotonically_increasing_id, min
// (in spark-shell, spark.implicits._ is already in scope for toDF and $)

val data = Seq((3,0),(3,1),(3,0),(4,1),(4,0),(4,0)).toDF("id", "value")
data.show
+---+-----+
| id|value|
+---+-----+
| 3| 0|
| 3| 1|
| 3| 0|
| 4| 1|
| 4| 0|
| 4| 0|
+---+-----+
Now we generate a column named idx with an increasing Long:
val dataWithIndex = data.withColumn("idx", monotonically_increasing_id())
// dataWithIndex.cache()
Now we get the min(idx) for each id where value = 1:
val minIdx = dataWithIndex
.filter($"value" === 1)
.groupBy($"id")
.agg(min($"idx"))
.toDF("r_id", "min_idx")
Now we join the min(idx) back to the original DataFrame:
dataWithIndex.join(
minIdx,
($"r_id" === $"id") && ($"idx" <= $"min_idx")
).select($"id", $"value").show
+---+-----+
| id|value|
+---+-----+
| 3| 0|
| 3| 1|
| 4| 1|
+---+-----+
Note: monotonically_increasing_id() generates its value based on the partition of the row. This value may change each time dataWithIndex is re-evaluated. In my code above, because of lazy evaluation, it's only when I call the final show that monotonically_increasing_id() is evaluated.
If you want to force the value to stay the same, for example so you can use show to evaluate the above step-by-step, uncomment this line above:
// dataWithIndex.cache()
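If the original DataFrame has been sorted first (the case raised in the question), the idx values can stop matching the intended order, because monotonically_increasing_id() reflects the physical partitioning at the time it is evaluated. As a rough, untested sketch of an alternative: if the data carries an explicit ordering column (a hypothetical ts below), a window function over that column gives the same cutoff without depending on physical row order.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{min, when}

// Hypothetical: the data has an explicit ordering column "ts"
val dataTs = Seq(
  (3, 0, 1), (3, 1, 2), (3, 0, 3),
  (4, 1, 1), (4, 0, 2), (4, 0, 3)
).toDF("id", "value", "ts")

// Earliest ts with value = 1 within each id
// (min ignores the nulls produced by when() for the other rows)
val firstOneTs = min(when($"value" === 1, $"ts")).over(Window.partitionBy($"id"))

dataTs
  .withColumn("first_one_ts", firstOneTs)
  .filter($"ts" <= $"first_one_ts")   // keep rows up to and including the first 1
  .select($"id", $"value")
  .show

Ids that never contain a value of 1 drop out here, just as they do with the inner join above.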