Tags: sql, pyspark, partition, days

Return rows with last updated date for different days


Assume this is my PySpark dataframe, ordered by ("ID", "updated_at"):

ID updated_at stock_date row_num
a1 2024-03-25T20:52:36 2024-03-25 1
a1 2024-03-26T11:23:48 2024-03-26 2
a1 2024-03-26T19:25:10 2024-03-26 3
b2 2024-03-24T14:12:20 2024-03-24 4
b2 2024-03-24T20:58:52 2024-03-24 5
b2 2024-03-26T22:24:14 2024-03-26 6
b2 2024-03-28T20:38:38 2024-03-28 7
c3 2024-03-28T15:15:51 2024-03-28 8
c3 2024-03-28T18:11:50 2024-03-28 9
d4 2024-03-24T12:10:15 2024-03-24 10
d4 2024-03-26T21:20:15 2024-03-26 11
d4 2024-03-28T11:55:23 2024-03-28 12
d4 2024-03-28T22:40:34 2024-03-28 13
d4 2024-03-29T11:57:20 2024-03-29 14
d4 2024-03-29T21:48:19 2024-03-29 15

I want to return, for each ID, the most recently updated row of each day, over multiple days.

In other words, if the same ID appears several times on the same day, only its last update of that day should be returned. If it appears on different days, all of those occurrences should be kept. Therefore, the result would be exactly rows 1, 3, 5, 6, 7, 9, 10, 11, 13 and 15.

I've already created a window function for the most recent update, but I can't figure out how to separate different days. My table always ends up returning only the most recently updated DAY.
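
For reference, this is roughly what such an attempt looks like (the question does not show the original code, so this reconstruction is an assumption): when the window is partitioned only by "ID", the ranking runs across all days, so keeping the top-ranked row returns a single latest row per ID instead of one per ID per day.

    from pyspark.sql.window import Window
    from pyspark.sql.functions import row_number, col

    # Hypothetical reconstruction of the described attempt: partitioning by ID
    # only means row_number ranks across every day, so rn == 1 keeps just the
    # single most recent update per ID.
    w = Window.partitionBy("ID").orderBy(col("updated_at").desc())
    df.withColumn("rn", row_number().over(w)).filter(col("rn") == 1).drop("rn")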

How can I achieve that?


Solution

  • I have one solution; I did it in two steps:

    First, I use a window function partitioned by "ID" and "stock_date" to get a row number within each partition.

    In the second step, I group the DataFrame from the first step to get the maximum value for each partition.

    from pyspark.sql.window import Window
    from pyspark.sql.functions import row_number, max
    
    data = [
        ("a1", "2024-03-25T20:52:36", "2024-03-25", 1),
        ("a1", "2024-03-26T11:23:48", "2024-03-26", 2),
        ("a1", "2024-03-26T19:25:10", "2024-03-26", 3),
        ("b2", "2024-03-24T14:12:20", "2024-03-24", 4),
        ("b2", "2024-03-24T20:58:52", "2024-03-24", 5),
        ("b2", "2024-03-26T22:24:14", "2024-03-26", 6),
        ("b2", "2024-03-28T20:38:38", "2024-03-28", 7),
        ("c3", "2024-03-28T15:15:51", "2024-03-28", 8),
        ("c3", "2024-03-28T18:11:50", "2024-03-28", 9),
        ("d4", "2024-03-24T12:10:15", "2024-03-24", 10),
        ("d4", "2024-03-26T21:20:15", "2024-03-26", 11),
        ("d4", "2024-03-28T11:55:23", "2024-03-28", 12),
        ("d4", "2024-03-28T22:40:34", "2024-03-28", 13),
        ("d4", "2024-03-29T11:57:20", "2024-03-29", 14),
        ("d4", "2024-03-29T21:48:19", "2024-03-29", 15)
    ]
    
    columns = ['ID', 'updated_at', 'stock_date', 'row_num']
    df = spark.createDataFrame(data, columns)
    
    # Step 1: number the updates within each (ID, stock_date) partition,
    # ordered by the update timestamp so the numbering is deterministic.
    windowSpec = Window.partitionBy("ID", "stock_date").orderBy("updated_at")
    groupedDf = df.withColumn("row_number", row_number().over(windowSpec))
    
    # Step 2: collapse each (ID, stock_date) group to its latest update.
    # max() on the ISO-8601 updated_at strings is chronological because they
    # sort lexicographically; the helper row_number aggregate is not needed
    # in the result, so it is dropped again.
    groupedDf.groupBy("ID", "stock_date").agg(
        max('updated_at').alias('updated_at'),
        max('row_num').alias('row_num'),
        max("row_number")
    ).drop("max(row_number)").show()
    

    output:

    +---+----------+-------------------+-------+
    | ID|stock_date|         updated_at|row_num|
    +---+----------+-------------------+-------+
    | a1|2024-03-25|2024-03-25T20:52:36|      1|
    | a1|2024-03-26|2024-03-26T19:25:10|      3|
    | b2|2024-03-24|2024-03-24T20:58:52|      5|
    | b2|2024-03-26|2024-03-26T22:24:14|      6|
    | b2|2024-03-28|2024-03-28T20:38:38|      7|
    | c3|2024-03-28|2024-03-28T18:11:50|      9|
    | d4|2024-03-24|2024-03-24T12:10:15|     10|
    | d4|2024-03-26|2024-03-26T21:20:15|     11|
    | d4|2024-03-28|2024-03-28T22:40:34|     13|
    | d4|2024-03-29|2024-03-29T21:48:19|     15|
    +---+----------+-------------------+-------+
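
    As a side note, a more compact variant (my own sketch, not part of the original answer) keeps the full original rows instead of re-aggregating: rank the rows within each (ID, stock_date) partition by updated_at descending and keep only the top-ranked row.

    from pyspark.sql.window import Window
    from pyspark.sql.functions import row_number, col

    # Rank updates within each (ID, day) partition, newest first, and keep only
    # the top-ranked row; all original columns are preserved.
    w = Window.partitionBy("ID", "stock_date").orderBy(col("updated_at").desc())
    df.withColumn("rn", row_number().over(w)) \
      .filter(col("rn") == 1) \
      .drop("rn") \
      .orderBy("row_num") \
      .show(truncate=False)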