Tags: apache-spark, apache-spark-sql, window-functions

Check start, middle and end of groups in Spark


I have a Spark dataframe that looks like this:

+---+-----------+-------------------------+---------------+
| id| Phase     | Switch                  | InputFileName |
+---+-----------+-------------------------+---------------+
|  1|          2|                        1|          fileA|
|  2|          2|                        1|          fileA|
|  3|          2|                        1|          fileA|
|  4|          2|                        0|          fileA|
|  5|          2|                        0|          fileA|
|  6|          2|                        1|          fileA|
| 11|          2|                        1|          fileB|
| 12|          2|                        1|          fileB|
| 13|          2|                        0|          fileB|
| 14|          2|                        0|          fileB|
| 15|          2|                        1|          fileB|
| 16|          2|                        1|          fileB|
| 21|          4|                        1|          fileB|
| 22|          4|                        1|          fileB|
| 23|          4|                        1|          fileB|
| 24|          4|                        1|          fileB|
| 25|          4|                        1|          fileB|
| 26|          4|                        0|          fileB|
| 31|          1|                        0|          fileC|
| 32|          1|                        0|          fileC|
| 33|          1|                        0|          fileC|
| 34|          1|                        0|          fileC|
| 35|          1|                        0|          fileC|
| 36|          1|                        0|          fileC|
+---+-----------+-------------------------+---------------+

For each group (a combination of InputFileName and Phase) I need to run a validation function that checks that Switch equals 1 in the first and last rows of the group (ordered by id) and drops to 0 at least once in between. The function should add the validation result as a new column. The expected output is below (the blank lines only separate the groups):
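The rule itself can be stated as a plain function over one group's Switch values. This is a minimal pure-Scala sketch with no Spark involved; it assumes the values are already sorted by id, and the names `GroupRule` and `isValidGroup` are hypothetical:

```scala
object GroupRule {
  // Hypothetical helper: `switches` holds one group's Switch values, ordered by id.
  def isValidGroup(switches: Seq[Int]): Boolean =
    switches.nonEmpty &&
      switches.head == 1 &&                      // starts at 1
      switches.last == 1 &&                      // ends at 1
      switches.drop(1).dropRight(1).contains(0)  // drops to 0 strictly in between

  def main(args: Array[String]): Unit = {
    println(isValidGroup(Seq(1, 1, 1, 0, 0, 1))) // fileA, Phase 2: true
    println(isValidGroup(Seq(1, 1, 1, 1, 1, 0))) // fileB, Phase 4: false
    println(isValidGroup(Seq(0, 0, 0, 0, 0, 0))) // fileC, Phase 1: false
  }
}
```

The `nonEmpty` guard only matters outside Spark; a group produced by a `groupBy` always has at least one row.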

+---+-----------+-------------------------+---------------+--------+
| id| Phase     | Switch                  | InputFileName | Valid  |
+---+-----------+-------------------------+---------------+--------+
|  1|          2|                        1|          fileA|   true |
|  2|          2|                        1|          fileA|   true |
|  3|          2|                        1|          fileA|   true |
|  4|          2|                        0|          fileA|   true |
|  5|          2|                        0|          fileA|   true |
|  6|          2|                        1|          fileA|   true |

| 11|          2|                        1|          fileB|   true |
| 12|          2|                        1|          fileB|   true |
| 13|          2|                        0|          fileB|   true |
| 14|          2|                        0|          fileB|   true |
| 15|          2|                        1|          fileB|   true |
| 16|          2|                        1|          fileB|   true |

| 21|          4|                        1|          fileB|   false|
| 22|          4|                        1|          fileB|   false|
| 23|          4|                        1|          fileB|   false|
| 24|          4|                        1|          fileB|   false|
| 25|          4|                        1|          fileB|   false|
| 26|          4|                        0|          fileB|   false|

| 31|          1|                        0|          fileC|   false|
| 32|          1|                        0|          fileC|   false|
| 33|          1|                        0|          fileC|   false|
| 34|          1|                        0|          fileC|   false|
| 35|          1|                        0|          fileC|   false|
| 36|          1|                        0|          fileC|   false|
+---+-----------+-------------------------+---------------+--------+

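I have previously solved this using PySpark and a pandas UDF:

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# schema: the output schema, i.e. the input columns plus a boolean "Valid" column
@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def validate_profile(df: pd.DataFrame) -> pd.DataFrame:
    df = df.sort_values("id")  # row order within a group is not guaranteed
    first_valid = df["Switch"].iloc[0] == 1                # starts at 1
    during_valid = (df["Switch"].iloc[1:-1] == 0).any()    # a 0 somewhere in between
    last_valid = df["Switch"].iloc[-1] == 1                # ends at 1
    df["Valid"] = bool(first_valid and during_valid and last_valid)
    return df

df = df.groupBy("InputFileName", "Phase").apply(validate_profile)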
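For comparison, the shape of that grouped-map approach (group by key, sort each group by id, tag every row of the group with the same flag) can be mirrored on plain Scala collections. This is only a sketch of the structure, not Spark code; the case classes and the `switchValue` field name are assumptions made for illustration:

```scala
object GroupedValidation {
  // Row layout assumed from the sample table (Switch renamed to switchValue).
  final case class Row(id: Int, phase: Int, switchValue: Int, inputFileName: String)
  final case class ValidatedRow(row: Row, valid: Boolean)

  // Same shape as groupBy(...).apply(validate_profile), on plain Scala collections.
  def validate(rows: Seq[Row]): Seq[ValidatedRow] =
    rows
      .groupBy(r => (r.inputFileName, r.phase)) // one group per (InputFileName, Phase)
      .values
      .flatMap { group =>
        val sorted = group.sortBy(_.id)          // do not rely on input order
        val s = sorted.map(_.switchValue)
        val valid = s.head == 1 && s.last == 1 && s.slice(1, s.length - 1).contains(0)
        sorted.map(r => ValidatedRow(r, valid))  // every row of the group gets the same flag
      }
      .toSeq
      .sortBy(_.row.id)                          // stable output order for display
}
```

Sorting inside each group is the important detail: without it, "first" and "last" depend on whatever order the rows happen to arrive in.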

Now I need to rewrite this in Scala, and I'd like to know the best way to do it.

I'm currently trying window functions to get the first and last ids of each group:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val minIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy("id")
val maxIdWindow = Window.partitionBy("InputFileName", "Phase").orderBy(col("id").desc)
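One detail matters with ordered windows like these: when a window has an `orderBy`, Spark's default frame grows from the start of the partition up to the current row. A running `min` of ascending ids (or `max` of descending ids) is therefore already the group-wide value on every row, but `last` would just return the current row's value. The pure-Scala sketch below imitates both frames on a single group to show the difference; it does not use Spark, and the function names are hypothetical:

```scala
object FrameDemo {
  // Growing frame (default with orderBy): each row sees the partition start up to itself,
  // so `last` over that frame is simply the current row's value.
  def runningLast(xs: Seq[Int]): Seq[Int] =
    xs.indices.map(i => xs.take(i + 1).last)

  // Running min over ids sorted ascending: equals the group's first id on every row,
  // which is why a MinId column computed this way comes out right.
  def runningMin(xs: Seq[Int]): Seq[Int] =
    xs.scanLeft(Int.MaxValue)(_ min _).tail

  // Whole-partition frame (rowsBetween unboundedPreceding..unboundedFollowing):
  // every row sees the group's true last value.
  def groupLast(xs: Seq[Int]): Seq[Int] =
    xs.map(_ => xs.last)

  def main(args: Array[String]): Unit = {
    val ids = Seq(1, 2, 3, 4, 5, 6)
    val switches = Seq(1, 1, 1, 0, 0, 1) // fileA, Phase 2, ordered by id
    println(runningMin(ids))
    println(runningLast(switches))
    println(groupLast(switches))
  }
}
```

The sketch assumes ids are unique within a group; with duplicate ordering keys Spark's default range frame also includes tied rows.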

I can then add the min and max ids as separate columns and use when to get the start and end values of Switch:

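// Rows matching neither condition get null, since there is no .otherwise(...)
df.withColumn("MinId", min("id").over(minIdWindow))    // first id of the group
    .withColumn("MaxId", max("id").over(maxIdWindow))  // last id of the group
    .withColumn("Valid", when(
        col("id") === col("MinId"), col("Switch")      // Switch at the start
    ).when(
        col("id") === col("MaxId"), col("Switch")      // Switch at the end
    ))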
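The missing middle check fits the same MinId/MaxId idea: look for any row whose id lies strictly between MinId and MaxId and whose Switch is 0. Below is a pure-Scala sketch of that per-group logic (no Spark), assuming ids are unique and the group is non-empty; `MinMaxIdCheck` is a hypothetical name. In Spark, the `exists` step could plausibly become an aggregate over a window without `orderBy`, for example `max(when(<condition>, 1))`, which stays null when no row matches.

```scala
object MinMaxIdCheck {
  // (id, Switch) pairs for one (InputFileName, Phase) group; input order does not matter.
  def isValid(group: Seq[(Int, Int)]): Boolean = {
    val minId = group.map(_._1).min  // plays the role of the MinId column
    val maxId = group.map(_._1).max  // plays the role of the MaxId column
    val switchAt = group.toMap       // id -> Switch, assuming ids are unique
    val startOk = switchAt(minId) == 1
    val endOk = switchAt(maxId) == 1
    // The missing piece: some row strictly between MinId and MaxId has Switch == 0.
    val middleOk = group.exists { case (id, sw) => id > minId && id < maxId && sw == 0 }
    startOk && endOk && middleOk
  }
}
```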

This gets me the start and end values, but I'm not sure how to check if Switch equals 0 in between. Am I on the right track using window functions? Or would you recommend an alternative solution?


Solution

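  • Try this: use a single window that spans the whole (InputFileName, Phase) group and compare the first, last and minimum values of Switch.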

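    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    // Frame covers the whole group, so first/last see the group's real first and last rows
    val wind = Window.partitionBy("InputFileName", "Phase").orderBy("id")
      .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
    
    // Valid when Switch is 1 at both ends and drops to 0 somewhere;
    // since both ends are 1, that 0 must be in between
    val df1 = df.withColumn("Valid", 
      when(first("Switch").over(wind) === 1 
        && last("Switch").over(wind) === 1 
        && min("Switch").over(wind) === 0, true)
        .otherwise(false))
    df1.orderBy("id").show() //Ordering for display purpose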
    

    Output (show() prints only the first 20 rows by default, so ids 33 to 36 are not shown):

    +---+-----+------+-------------+-----+
    | id|Phase|Switch|InputFileName|Valid|
    +---+-----+------+-------------+-----+
    |  1|    2|     1|        fileA| true|
    |  2|    2|     1|        fileA| true|
    |  3|    2|     1|        fileA| true|
    |  4|    2|     0|        fileA| true|
    |  5|    2|     0|        fileA| true|
    |  6|    2|     1|        fileA| true|
    | 11|    2|     1|        fileB| true|
    | 12|    2|     1|        fileB| true|
    | 13|    2|     0|        fileB| true|
    | 14|    2|     0|        fileB| true|
    | 15|    2|     1|        fileB| true|
    | 16|    2|     1|        fileB| true|
    | 21|    4|     1|        fileB|false|
    | 22|    4|     1|        fileB|false|
    | 23|    4|     1|        fileB|false|
    | 24|    4|     1|        fileB|false|
    | 25|    4|     1|        fileB|false|
    | 26|    4|     0|        fileB|false|
    | 31|    1|     0|        fileC|false|
    | 32|    1|     0|        fileC|false|
    +---+-----+------+-------------+-----+
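The answer replaces "a 0 strictly in between" with `min("Switch") === 0` over the whole group. The pure-Scala model below (no Spark; object and helper names are hypothetical) applies that first/last/min test to the question's sample data and also compares it exhaustively against the literal rule on short 0/1 sequences:

```scala
object AnswerCheck {
  final case class Row(id: Int, phase: Int, switchValue: Int, inputFileName: String)

  // Builds one group's rows with consecutive ids (helper for the sample data).
  private def group(file: String, phase: Int, firstId: Int, switches: Int*): Seq[Row] =
    switches.zipWithIndex.map { case (s, k) => Row(firstId + k, phase, s, file) }

  // Sample data from the question.
  val rows: Seq[Row] =
    group("fileA", 2, 1, 1, 1, 1, 0, 0, 1) ++
      group("fileB", 2, 11, 1, 1, 0, 0, 1, 1) ++
      group("fileB", 4, 21, 1, 1, 1, 1, 1, 0) ++
      group("fileC", 1, 31, 0, 0, 0, 0, 0, 0)

  // The answer's test: first == 1, last == 1, min == 0 over the whole group.
  def answerRule(s: Seq[Int]): Boolean = s.head == 1 && s.last == 1 && s.min == 0

  // The question's literal rule: 1 at both ends, a 0 strictly in between.
  def literalRule(s: Seq[Int]): Boolean =
    s.head == 1 && s.last == 1 && s.slice(1, s.length - 1).contains(0)

  val validByGroup: Map[(String, Int), Boolean] =
    rows.groupBy(r => (r.inputFileName, r.phase)).map { case (key, g) =>
      key -> answerRule(g.sortBy(_.id).map(_.switchValue))
    }

  // Compares both rules on every non-empty 0/1 sequence up to maxLen.
  def rulesAgree(maxLen: Int): Boolean =
    (1 to maxLen).forall { n =>
      (0 until (1 << n)).forall { bits =>
        val s = (0 until n).map(i => (bits >> i) & 1)
        answerRule(s) == literalRule(s)
      }
    }

  def main(args: Array[String]): Unit = {
    validByGroup.toSeq.sortBy(_._1).foreach(println)
    println(rulesAgree(6))
  }
}
```

Running `main` should report true for (fileA, 2) and (fileB, 2) and false for the other two groups, matching the expected Valid column, including the fileC rows cut off by `show()`. The rules agree because once both ends are 1, any 0 in the group has to sit in between.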