Tags: sql, postgresql, window-functions, row-number

How to get a moving-window argmax in PostgreSQL


I'm trying to find the moving argmax of a column using window functions in PostgreSQL. Here's what I have so far:

select *,
(max(case when price = roll_max then (row_num) end) over (partition by roll_max order by s_date)) as argmax
from (
   select s_id, s_date, price, 
   row_number() over (partition by s_id order by s_date) as row_num,
   max(price) over (partition by s_id order by s_date rows 10 preceding) as roll_max
   from sample_table
) tb1
order by s_date

The above code is modified from this answer. I had to add partition by s_id because the table holds many different s_ids; its unique key is (s_id, s_date). So I need the argmax for each s_id over all available dates.

Here's the output I get for some sample data (window size 10):

+-------+--------------+---------+---------+----------+------------------------------------------+
| s_id  |    s_date    |  price  | row_num | roll_max |                  argmax                  |
+-------+--------------+---------+---------+----------+------------------------------------------+
| "ABC" | "2020-06-10" | 322.390 |       1 |  322.390 | 1                                        |
| "ABC" | "2020-06-11" | 312.150 |       2 |  322.390 | 1                                        |
| "ABC" | "2020-06-12" | 309.080 |       3 |  322.390 | 1                                        |
| "ABC" | "2020-06-15" | 308.280 |       4 |  322.390 | 1                                        |
| "ABC" | "2020-06-16" | 315.640 |       5 |  322.390 | 1                                        |
| "ABC" | "2020-06-17" | 314.390 |       6 |  322.390 | 1                                        |
| "ABC" | "2020-06-18" | 312.300 |       7 |  322.390 | 1                                        |
| "ABC" | "2020-06-19" | 314.380 |       8 |  322.390 | 1                                        |
| "ABC" | "2020-06-22" | 311.050 |       9 |  322.390 | 1                                        |
| "ABC" | "2020-06-23" | 314.500 |      10 |  322.390 | 1                                        |
| "ABC" | "2020-06-24" | 310.510 |      11 |  322.390 | 1                                        |
| "ABC" | "2020-06-25" | 307.640 |      12 |  315.640 | NULL /* how to get row_num (5) here? */  |
| "ABC" | "2020-06-26" | 306.390 |      13 |  315.640 | NULL /* how to get row_num (5) here? */  |
| "ABC" | "2020-06-29" | 304.610 |      14 |  315.640 | NULL /* how to get row_num (5) here? */  |
| "ABC" | "2020-06-30" | 310.200 |      15 |  315.640 | NULL /* how to get row_num (5) here? */  |
| "ABC" | "2020-07-01" | 311.890 |      16 |  314.500 | NULL /* how to get row_num (10) here? */ |
| "ABC" | "2020-07-02" | 315.700 |      17 |  315.700 | 17                                       |
| "ABC" | "2020-07-06" | 317.680 |      18 |  317.680 | 18                                       |
+-------+--------------+---------+---------+----------+------------------------------------------+

I understand that the query above only compares the current row's price with the rolling max and returns the row number when they match. That is not enough, as the table shows: 315.640 is the rolling max for the window ending at row 12, but that value comes from an earlier row in the window (row 5), not from the current row, so the result is NULL.

My question is: how can I get the value 5 in place of NULL in the above example, i.e. get the row_num of the actual argmax (315.640's row_num is 5) for every row? The row_num can be relative to the whole table or to each window (in this example the window size is 10).

I've looked at other similar questions, but still could not get the result I want because what I'm trying to do is a rolling argmax and not over the entire column of the table.

Can anyone suggest a solution for this? I'm also open to using UDFs. I only have a basic knowledge of aggregate UDFs, so my approach of using a temporary array to hold the last 10 values and taking the max of it did not seem very efficient (I'm not even sure whether array functions can be used like that), and I'm out of ideas at this point :/


Solution

  • Although a bit hard to read, you could do the following:

    1. Put all values for price inside this window into an array;
    2. Use array_position to find the position of the rolling max price in that array;
    3. Convert that position into a row number by adding the number of rows that precede the window frame. Since rows 10 preceding covers 11 rows (10 preceding plus the current row), that offset is row_number() - 11;
    4. Clamp the offset with GREATEST(row_number() - 11, 0) so it does not go negative at the start of a partition:
    WITH sample_table(s_id, s_date, price) AS (
        VALUES ('ABC', '2020-06-10'::date, 322.390),
               ('ABC', '2020-06-11'::date, 312.150),
               ('ABC', '2020-06-12'::date, 309.080),
               ('ABC', '2020-06-15'::date, 308.280),
               ('ABC', '2020-06-16'::date, 315.640),
               ('ABC', '2020-06-17'::date, 314.390),
               ('ABC', '2020-06-18'::date, 312.300),
               ('ABC', '2020-06-19'::date, 314.380),
               ('ABC', '2020-06-22'::date, 311.050),
               ('ABC', '2020-06-23'::date, 314.500),
               ('ABC', '2020-06-24'::date, 310.510),
               ('ABC', '2020-06-25'::date, 307.640),
               ('ABC', '2020-06-26'::date, 306.390),
               ('ABC', '2020-06-29'::date, 304.610),
               ('ABC', '2020-06-30'::date, 310.200),
               ('ABC', '2020-07-01'::date, 311.890),
               ('ABC', '2020-07-02'::date, 315.700),
               ('ABC', '2020-07-06'::date, 317.680)
    )
    SELECT s_id,
           s_date,
           price,
           row_number() over (PARTITION BY s_id ORDER BY s_date),
           max(price) over (partition by s_id order by s_date rows 10 preceding) as roll_max,
           -- rows before the frame start (the frame holds 11 rows) + position inside the frame
           GREATEST(row_number() over (PARTITION BY s_id ORDER BY s_date) - 11, 0)
               + array_position(
                           array_agg(price) over (partition by s_id order by s_date rows 10 preceding),
                           max(price) over (partition by s_id order by s_date rows 10 preceding)
               ) as argmax
    FROM sample_table
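
    When turning an array position back into a row number, it helps to know exactly how many rows the frame holds. In PostgreSQL, rows 10 preceding is shorthand for rows between 10 preceding and current row, so a full frame contains 11 rows. A minimal sketch to check this, using generate_series as stand-in data (the column name n and the alias frame_rows are made up for illustration):

    ```sql
    -- Count how many rows fall inside the frame for each row of a 15-row series.
    SELECT n,
           count(*) OVER (ORDER BY n ROWS 10 PRECEDING) AS frame_rows
    FROM generate_series(1, 15) AS g(n)
    ORDER BY n;
    -- frame_rows grows from 1 up to 11 and then stays at 11.
    ```

    The frame for a given row therefore starts 10 rows earlier (or at row 1 near the start of a partition), and position p inside the frame corresponds to row frame_start + p - 1.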
    

    Or, with a subquery, which is easier to read:

    WITH sample_table(s_id, s_date, price) AS (
        VALUES ('ABC', '2020-06-10'::date, 322.390),
               ('ABC', '2020-06-11'::date, 312.150),
               ('ABC', '2020-06-12'::date, 309.080),
               ('ABC', '2020-06-15'::date, 308.280),
               ('ABC', '2020-06-16'::date, 315.640),
               ('ABC', '2020-06-17'::date, 314.390),
               ('ABC', '2020-06-18'::date, 312.300),
               ('ABC', '2020-06-19'::date, 314.380),
               ('ABC', '2020-06-22'::date, 311.050),
               ('ABC', '2020-06-23'::date, 314.500),
               ('ABC', '2020-06-24'::date, 310.510),
               ('ABC', '2020-06-25'::date, 307.640),
               ('ABC', '2020-06-26'::date, 306.390),
               ('ABC', '2020-06-29'::date, 304.610),
               ('ABC', '2020-06-30'::date, 310.200),
               ('ABC', '2020-07-01'::date, 311.890),
               ('ABC', '2020-07-02'::date, 315.700),
               ('ABC', '2020-07-06'::date, 317.680)
    )
    SELECT s_id, s_date, price, row_number, roll_max,
           -- the frame "rows 10 preceding" holds 11 rows, so row_number - 11 rows precede it
           GREATEST(row_number - 11, 0)
               + array_position(
                   prices,
                   roll_max
               ) as argmax
    FROM (
             SELECT s_id,
                    s_date,
                    price,
                    row_number() over (PARTITION BY s_id ORDER BY s_date),
                    max(price) over (partition by s_id order by s_date rows 10 preceding)       as roll_max,
                    array_agg(price)
                    over (partition by s_id order by s_date rows 10 preceding)                  as prices
             FROM sample_table
         ) as s
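
    If the array arithmetic feels fragile, one possible alternative (not part of the approach above) is a LATERAL subquery that looks up the best row in the window directly. This is only a sketch under assumptions: it uses the first six sample rows and a shorter frame of 2 preceding rows to keep the example small, and the names numbered, cur, w and best are made up for illustration. It may also be slower than the window-function version on large tables.

    ```sql
    WITH sample_table(s_id, s_date, price) AS (
        VALUES ('ABC', '2020-06-10'::date, 322.390),
               ('ABC', '2020-06-11'::date, 312.150),
               ('ABC', '2020-06-12'::date, 309.080),
               ('ABC', '2020-06-15'::date, 308.280),
               ('ABC', '2020-06-16'::date, 315.640),
               ('ABC', '2020-06-17'::date, 314.390)
    ),
    numbered AS (
        -- number the rows once so the window can be expressed as a row_num range
        SELECT s_id, s_date, price,
               row_number() OVER (PARTITION BY s_id ORDER BY s_date) AS row_num
        FROM sample_table
    )
    SELECT cur.s_id, cur.s_date, cur.price, cur.row_num,
           best.price   AS roll_max,
           best.row_num AS argmax
    FROM numbered AS cur
    CROSS JOIN LATERAL (
        -- pick the highest-priced row among the current row and the 2 before it
        SELECT w.price, w.row_num
        FROM numbered AS w
        WHERE w.s_id = cur.s_id
          AND w.row_num BETWEEN cur.row_num - 2 AND cur.row_num
        ORDER BY w.price DESC, w.row_num   -- on ties, the earliest row wins
        LIMIT 1
    ) AS best
    ORDER BY cur.s_id, cur.s_date;
    -- argmax for rows 1..6: 1, 1, 1, 2, 5, 5
    ```

    Replacing 2 with 10 gives the same frame as rows 10 preceding. The explicit ORDER BY makes the tie rule visible; array_position also returns the first matching position, so the array version resolves ties to the earliest row as well.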