
Per id, filter based on conditions and keep next row


I have a pyspark dataframe in the general form of:

+---+----------+---------+--------+
| id|      date| value   |id_index|
+---+----------+---------+--------+
|  1|2023-05-14|107.27966|       1|
|  1|2023-05-14| 78.23225|       2|
|  1|2023-05-14|226.91467|       3|
|  1|2023-05-15|107.27966|       1|
|  1|2023-05-15| 78.23225|       2|
|  1|2023-05-15|226.91467|       3|
|  2|2023-05-14|249.05295|       1|
|  2|2023-05-14|      2.0|       2|
|  2|2023-05-14|      0.0|       3|
|  2|2023-05-15|249.05295|       1|
|  2|2023-05-15|      2.0|       2|
|  2|2023-05-15|      0.0|       3|
+---+----------+---------+--------+

For each unique id, I want to keep the row with the lowest id_index at the earliest date, the next id_index at the next date, and so on. For example:

+---+----------+---------+--------+
| id|      date| value   |id_index|
+---+----------+---------+--------+
|  1|2023-05-14|107.27966|       1|
|  1|2023-05-15| 78.23225|       2|
|   |       ...|      ...|     ...|
|  2|2023-05-14|249.05295|       1|
|  2|2023-05-15|      2.0|       2|
|   |       ...|      ...|     ...|
+---+----------+---------+--------+

This dataframe is the result of an explode operation that transformed a column of lists into separate rows. I tried creating additional conditional indices in the hope that I could use dropDuplicates, but with no success.

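For context, here is a minimal sketch of how a dataframe of this shape can arise from posexplode (the input column name values and the literal data are assumptions for illustration, not taken from the original code):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

# One list of values per (id, date); posexplode emits a 0-based position
# alongside each element, which becomes the 1-based id_index.
raw = spark.createDataFrame(
    [(1, '2023-05-14', [107.27966, 78.23225, 226.91467]),
     (1, '2023-05-15', [107.27966, 78.23225, 226.91467]),
     (2, '2023-05-14', [249.05295, 2.0, 0.0]),
     (2, '2023-05-15', [249.05295, 2.0, 0.0])],
    ['id', 'date', 'values'])

df = raw.select('id', 'date', F.posexplode('values').alias('pos', 'value')) \
        .withColumn('id_index', F.col('pos') + 1).drop('pos')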

Solution

  • Create a window specification, assign a group number to each unique combination of id and date, then filter the dataframe to the rows where the group number equals id_index.

    from pyspark.sql import Window, functions as F

    # Running count of id_index == 1 rows per id, ordered by date; the
    # default RANGE frame includes date ties, so every row of the n-th
    # distinct date gets ngroup = n.
    W = Window.partitionBy('id').orderBy('date')
    df1 = df.withColumn('ngroup', F.sum((F.col('id_index') == 1).cast('int')).over(W))
    df1 = df1.filter('id_index = ngroup')
    

    df1.show()
    +---+----------+---------+--------+------+
    | id|      date|    value|id_index|ngroup|
    +---+----------+---------+--------+------+
    |  1|2023-05-14|107.27966|       1|     1|
    |  1|2023-05-15| 78.23225|       2|     2|
    |  2|2023-05-14|249.05295|       1|     1|
    |  2|2023-05-15|      2.0|       2|     2|
    +---+----------+---------+--------+------+
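
  • An equivalent sketch, assuming only that the dates within each id define the groups, replaces the running sum with dense_rank (reusing W and F from above):

    # dense_rank numbers each distinct date 1, 2, ... within its id,
    # so the filter keeps the same one row per (id, date)
    df1 = df.withColumn('ngroup', F.dense_rank().over(W)) \
            .filter('id_index = ngroup')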