I have a Spark dataframe that looks like this:
+---+-----+------+-------------+
| id|Phase|Switch|InputFileName|
+---+-----+------+-------------+
|  1|    2|     1|        fileA|
|  2|    2|     1|        fileA|
|  3|    2|     1|        fileA|
|  4|    2|     0|        fileA|
|  5|    2|     0|        fileA|
|  6|    2|     1|        fileA|
| 11|    2|     1|        fileB|
| 12|    2|     1|        fileB|
| 13|    2|     0|        fileB|
| 14|    2|     0|        fileB|
| 15|    2|     1|        fileB|
| 16|    2|     1|        fileB|
| 21|    4|     1|        fileB|
| 22|    4|     1|        fileB|
| 23|    4|     1|        fileB|
| 24|    4|     1|        fileB|
| 25|    4|     1|        fileB|
| 26|    4|     0|        fileB|
| 31|    1|     0|        fileC|
| 32|    1|     0|        fileC|
| 33|    1|     0|        fileC|
| 34|    1|     0|        fileC|
| 35|    1|     0|        fileC|
| 36|    1|     0|        fileC|
+---+-----+------+-------------+
For each group (a combination of InputFileName and Phase) I need to run a validation function that checks that Switch equals 1 at the very start and end of the group, and drops to 0 at some point in between. The function should add the validation result as a new column. The expected output is below (gaps are just to highlight the different groups):
+---+-----+------+-------------+-----+
| id|Phase|Switch|InputFileName|Valid|
+---+-----+------+-------------+-----+
|  1|    2|     1|        fileA| true|
|  2|    2|     1|        fileA| true|
|  3|    2|     1|        fileA| true|
|  4|    2|     0|        fileA| true|
|  5|    2|     0|        fileA| true|
|  6|    2|     1|        fileA| true|

| 11|    2|     1|        fileB| true|
| 12|    2|     1|        fileB| true|
| 13|    2|     0|        fileB| true|
| 14|    2|     0|        fileB| true|
| 15|    2|     1|        fileB| true|
| 16|    2|     1|        fileB| true|

| 21|    4|     1|        fileB|false|
| 22|    4|     1|        fileB|false|
| 23|    4|     1|        fileB|false|
| 24|    4|     1|        fileB|false|
| 25|    4|     1|        fileB|false|
| 26|    4|     0|        fileB|false|

| 31|    1|     0|        fileC|false|
| 32|    1|     0|        fileC|false|
| 33|    1|     0|        fileC|false|
| 34|    1|     0|        fileC|false|
| 35|    1|     0|        fileC|false|
| 36|    1|     0|        fileC|false|
+---+-----+------+-------------+-----+
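To make the rule concrete, here is a minimal plain-Scala sketch of the per-group check, with no Spark involved. It assumes the Switch values of one group have already been collected in id order; the names SwitchRule and isValidProfile are made up for illustration.

```scala
object SwitchRule {
  // Hypothetical helper: `switches` holds one group's Switch values, ordered by id.
  // Valid means: first value is 1, last value is 1, and at least one 0 strictly in between.
  def isValidProfile(switches: Seq[Int]): Boolean =
    switches.length >= 3 && // need a start, an end, and something in between
      switches.head == 1 &&
      switches.last == 1 &&
      switches.slice(1, switches.length - 1).contains(0)

  def main(args: Array[String]): Unit = {
    println(isValidProfile(Seq(1, 1, 1, 0, 0, 1))) // fileA, Phase 2: true
    println(isValidProfile(Seq(1, 1, 1, 1, 1, 0))) // fileB, Phase 4: false (ends on 0)
    println(isValidProfile(Seq(0, 0, 0, 0, 0, 0))) // fileC, Phase 1: false (never 1)
  }
}
```

The length guard also keeps head and last from throwing on an empty group.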
I have previously solved this using PySpark and a Pandas UDF:
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# `schema` (defined elsewhere) must describe the returned columns, including "Valid"
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def validate_profile(df: pd.DataFrame):
    first_valid = df["Switch"].iloc[0] == 1
    # At least one 0 strictly between the first and last rows
    during_valid = (df["Switch"].iloc[1:-1] == 0).any()
    last_valid = df["Switch"].iloc[-1] == 1
    df["Valid"] = first_valid & during_valid & last_valid
    return df

df = df.groupBy("InputFileName", "Phase").apply(validate_profile)
However, now I need to rewrite this in Scala. I just want to know the best way of accomplishing this.
I'm currently trying window functions to get the first and last ids of each group:
val minIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy("id")
val maxIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy(col("id").desc)
I can then add the min and max ids as separate columns and use when to get the start and end values of Switch:
df.withColumn("MinId", min("id").over(minIdWindow))
  .withColumn("MaxId", max("id").over(maxIdWindow))
  .withColumn("Valid", when(
    col("id") === col("MinId"), col("Switch")
  ).when(
    col("id") === col("MaxId"), col("Switch")
  ))
This gets me the start and end values, but I'm not sure how to check whether Switch equals 0 in between. Am I on the right track using window functions, or would you recommend an alternative solution?
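Try this. The window frame spans the whole group (unbounded preceding to unbounded following), so first and last return the Switch values of the group's first and last rows by id, and min tells you whether a 0 occurs anywhere in the group. Because the first and last values must both be 1, a minimum of 0 can only come from a row in between. The explicit frame matters: with orderBy alone, the default frame ends at the current row, so last would just return the current row's value.
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{first, last, min, when}

// Frame covers the entire group, ordered by id
val wind = Window.partitionBy("InputFileName", "Phase").orderBy("id")
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

val df1 = df.withColumn("Valid",
  when(first("Switch").over(wind) === 1
    && last("Switch").over(wind) === 1
    && min("Switch").over(wind) === 0, true)
  .otherwise(false))

df1.orderBy("id").show() // ordering for display purposes only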
Output:
+---+-----+------+-------------+-----+
| id|Phase|Switch|InputFileName|Valid|
+---+-----+------+-------------+-----+
|  1|    2|     1|        fileA| true|
|  2|    2|     1|        fileA| true|
|  3|    2|     1|        fileA| true|
|  4|    2|     0|        fileA| true|
|  5|    2|     0|        fileA| true|
|  6|    2|     1|        fileA| true|
| 11|    2|     1|        fileB| true|
| 12|    2|     1|        fileB| true|
| 13|    2|     0|        fileB| true|
| 14|    2|     0|        fileB| true|
| 15|    2|     1|        fileB| true|
| 16|    2|     1|        fileB| true|
| 21|    4|     1|        fileB|false|
| 22|    4|     1|        fileB|false|
| 23|    4|     1|        fileB|false|
| 24|    4|     1|        fileB|false|
| 25|    4|     1|        fileB|false|
| 26|    4|     0|        fileB|false|
| 31|    1|     0|        fileC|false|
| 32|    1|     0|        fileC|false|
+---+-----+------+-------------+-----+
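The output above stops at id 32 because show() prints only the first 20 rows by default. To reproduce the whole thing end to end, here is a self-contained sketch. It assumes Spark SQL is on the classpath and uses a local master; the object name ValidateSwitchSketch and the helpers sampleData and withValid are made up for illustration. It also writes the condition as a plain boolean column instead of when/otherwise, which gives the same result as long as Switch is never null.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{first, last, min}

object ValidateSwitchSketch {
  // Hypothetical helper: rebuilds the sample data from the question.
  def sampleData(spark: SparkSession): DataFrame = {
    import spark.implicits._
    Seq(
      (1, 2, 1, "fileA"), (2, 2, 1, "fileA"), (3, 2, 1, "fileA"),
      (4, 2, 0, "fileA"), (5, 2, 0, "fileA"), (6, 2, 1, "fileA"),
      (11, 2, 1, "fileB"), (12, 2, 1, "fileB"), (13, 2, 0, "fileB"),
      (14, 2, 0, "fileB"), (15, 2, 1, "fileB"), (16, 2, 1, "fileB"),
      (21, 4, 1, "fileB"), (22, 4, 1, "fileB"), (23, 4, 1, "fileB"),
      (24, 4, 1, "fileB"), (25, 4, 1, "fileB"), (26, 4, 0, "fileB"),
      (31, 1, 0, "fileC"), (32, 1, 0, "fileC"), (33, 1, 0, "fileC"),
      (34, 1, 0, "fileC"), (35, 1, 0, "fileC"), (36, 1, 0, "fileC")
    ).toDF("id", "Phase", "Switch", "InputFileName")
  }

  // Hypothetical helper: adds the Valid column using a whole-group window.
  def withValid(df: DataFrame): DataFrame = {
    val wholeGroup = Window
      .partitionBy("InputFileName", "Phase")
      .orderBy("id")
      .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

    df.withColumn("Valid",
      first("Switch").over(wholeGroup) === 1 &&
        last("Switch").over(wholeGroup) === 1 &&
        min("Switch").over(wholeGroup) === 0)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]") // assumption: local run, not a cluster
      .appName("validate-switch-sketch")
      .getOrCreate()

    // Ask for up to 30 rows so all 24 are printed instead of the default 20
    withValid(sampleData(spark)).orderBy("id").show(30)

    spark.stop()
  }
}
```

Keeping withValid separate from main makes it easy to call from a test against a small hand-built DataFrame.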