Tags: apache-spark, pyspark, window-functions

Window function and conditional filters in PySpark


Is there a way to conditionally apply a filter to a window function in PySpark? For every group in col1 I want to keep only the rows that have X in col2. If a group doesn't have any X in col2, I want to keep all rows in that group.

+------+------+
| col1 | col2 |
+------+------+
| A    |      |
| A    | X    |
| A    |      |
| B    |      |
| B    |      |
| B    |      |
+------+------+
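
For reproducibility, here is a minimal sketch of building the sample DataFrame above. It assumes the blank col2 cells are nulls (an assumption; empty strings would behave the same way for this problem):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Sample data from the question; None stands in for the blank col2 cells.
    df = spark.createDataFrame(
        [('A', None), ('A', 'X'), ('A', None),
         ('B', None), ('B', None), ('B', None)],
        ['col1', 'col2'],
    )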

Solution

  • You can do this with a max window function over a window partitioned by col1: flag every row whose group contains 'X' in col2 with an identifier (1 in this case). Because when() without an otherwise() returns null, groups that don't contain 'X' get null instead. Then filter the intermediate DataFrame to get the desired result.

    from pyspark.sql import Window
    from pyspark.sql import functions as F

    # Flag each group: 1 if any row in the group has 'X' in col2.
    # when() without otherwise() yields null, so max() over groups
    # that never match is also null.
    w = Window.partitionBy(df.col1)
    df_1 = df.withColumn('x_exists', F.max(F.when(df.col2 == 'X', 1)).over(w))

    # Keep only the 'X' rows from flagged groups, and every row from null groups.
    df_2 = df_1.filter(((df_1.x_exists == 1) & (df_1.col2 == 'X')) | df_1.x_exists.isNull())
    df_2.show()
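
With the sample data above, this should keep just the (A, X) row plus all three B rows. The x_exists helper column can be dropped once the filter is done; note that row order in show() is not guaranteed and nulls may render as null or NULL depending on the Spark version:

    df_2.drop('x_exists').show()
    # Expected output (row order may vary):
    # +----+----+
    # |col1|col2|
    # +----+----+
    # |   A|   X|
    # |   B|null|
    # |   B|null|
    # |   B|null|
    # +----+----+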