Tags: python, pandas, group-by, rolling-computation

How to compute a rolling window, then groupby, followed by aggregation without looping?


I have a pandas dataframe past_updates as follows:

                        streamid,low,high
time
2023-01-10 16:07:36.264,979,1.07331,1.07344
2023-01-10 16:07:36.359,1009,1.07331,1.07338
2023-01-10 16:07:36.444,781,1.07329,1.07341
2023-01-10 16:07:36.464,979,1.07331,1.07344
2023-01-10 16:07:36.470,1191,1.07331,1.0734
2023-01-10 16:07:36.480,1191,1.07333,1.07342
2023-01-10 16:07:36.493,2,1.07332,1.07337
2023-01-10 16:07:36.493,1009,1.07332,1.07338
2023-01-10 16:07:36.494,979,1.07332,1.07345
2023-01-10 16:07:36.494,786,1.07325,1.07332
2023-01-10 16:07:36.494,141,1.07332,1.07337
2023-01-10 16:07:36.496,1263,1.07332,1.07339
2023-01-10 16:07:36.496,818,1.07331,1.07338
2023-01-10 16:07:36.497,786,1.07325,1.07333
2023-01-10 16:07:36.499,844,1.07331,1.07336
2023-01-10 16:07:36.499,1009,1.07332,1.07339
2023-01-10 16:07:36.501,1028,1.07333,1.07337
2023-01-10 16:07:36.503,141,1.07333,1.07338
2023-01-10 16:07:36.504,1009,1.07333,1.0734
2023-01-10 16:07:36.509,1009,1.07333,1.07341
2023-01-10 16:07:36.509,786,1.07327,1.07335
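
For anyone who wants to reproduce this, here is a minimal sketch that rebuilds the frame (only the first few rows of the sample are shown; the remaining ones follow the same pattern):

import io
import pandas as pd

# Rebuild the sample frame from CSV text; rows beyond the third are elided
# here and would be pasted in the same format.
csv_text = """time,streamid,low,high
2023-01-10 16:07:36.264,979,1.07331,1.07344
2023-01-10 16:07:36.359,1009,1.07331,1.07338
2023-01-10 16:07:36.444,781,1.07329,1.07341
"""
past_updates = pd.read_csv(io.StringIO(csv_text),
                           parse_dates=["time"], index_col="time")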

I want to compute a rolling max of low and a rolling min of high in a 5s window, with a caveat: if a window contains multiple rows with the same streamid, only the latest row should be considered.

Conceptually, this should be easy: all I need to do is get rolling windows of 5s each, group each window by streamid, call last() on the GroupBy object to pick the latest row for each stream, and then run agg({"low": "max", "high": "min"}) on the result.
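
On a single window the pipeline would look like this (a throwaway sketch using the whole frame as a stand-in for one 5s window):

# Deduplicate by streamid, keeping the latest row per stream, then aggregate.
dedup = past_updates.groupby("streamid").last()
low, high = dedup.agg({"low": "max", "high": "min"})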

In practice, I found that I cannot do a groupby on a rolling window, since rolling applies the function to each column separately. I understand that I can pass method='table' to rolling and engine='numba', raw=True to apply so that my custom function receives the entire window at once, but I am not able to do a groupby inside the numba function.
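
For illustration, a quick probe of the default behaviour (a throwaway example, not part of the pipeline):

# With the default method="single", each window arrives as a 1-D array of a
# single column's values, so the streamid column is never visible inside the
# applied function.
per_column = past_updates[["low", "high"]].rolling("5s").apply(
    lambda window_col: window_col.max(), raw=True)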

Here is my looping solution, which is insanely slow but gives the right answer:

import pandas as pd

lows = []
highs = []
times = []
rolled = past_updates.rolling("5s")
for df in rolled:  # iterate one sub-frame per 5s window
    # Keep only the latest row per streamid, then aggregate.
    low, high = df.groupby("streamid").last().agg({"low": "max", "high": "min"})
    lows.append(low)
    highs.append(high)
    times.append(df.index[-1])  # each window is labelled by its last timestamp
out_df = pd.DataFrame({
    "low": lows, "high": highs, "time": times
}).set_index("time")

This gives the result:

                        low,high
time
2023-01-10 16:07:36.264,1.07331,1.07344
2023-01-10 16:07:36.359,1.07331,1.07338
2023-01-10 16:07:36.444,1.07331,1.07338
2023-01-10 16:07:36.464,1.07331,1.07338
2023-01-10 16:07:36.470,1.07331,1.07338
2023-01-10 16:07:36.480,1.07333,1.07338
2023-01-10 16:07:36.493,1.07333,1.07337
2023-01-10 16:07:36.493,1.07333,1.07337
2023-01-10 16:07:36.494,1.07333,1.07337
2023-01-10 16:07:36.494,1.07333,1.07332
2023-01-10 16:07:36.494,1.07333,1.07332
2023-01-10 16:07:36.496,1.07333,1.07332
2023-01-10 16:07:36.496,1.07333,1.07332
2023-01-10 16:07:36.497,1.07333,1.07333  # <-- 1.07332 is dropped due to new update from 786
2023-01-10 16:07:36.499,1.07333,1.07333
2023-01-10 16:07:36.499,1.07333,1.07333
2023-01-10 16:07:36.501,1.07333,1.07333
2023-01-10 16:07:36.503,1.07333,1.07333
2023-01-10 16:07:36.504,1.07333,1.07333
2023-01-10 16:07:36.509,1.07333,1.07333
2023-01-10 16:07:36.509,1.07333,1.07335

Unfortunately, this takes 7-8 minutes for ~600k rows. I want to be able to do this many times for different timestamps. Is there a better way, potentially one that avoids loops?


Solution

  • After a lot of iteration, I got this to work using the numba engine. Here is my implementation for anyone in a similar situation:

    import numpy as np

    def group_and_agg(subarray):
        # With method="table" and raw=True, subarray is the whole window as a
        # 2-D array: column 0 = streamid, column 1 = low, column 2 = high.
        streamupdate_l = {}
        streamupdate_h = {}
        for row in subarray:
            # Later rows overwrite earlier ones, so each dict ends up holding
            # the latest low/high per streamid.
            if not np.isnan(row[1]):
                streamupdate_l[row[0]] = row[1]
            if not np.isnan(row[2]):
                streamupdate_h[row[0]] = row[2]

        lowmax = max(streamupdate_l.values())
        highmin = min(streamupdate_h.values())
        # The function must return one value per input column (see below).
        return highmin - lowmax, lowmax, highmin

    condensed = past_updates.rolling("5s", method="table").apply(
        group_and_agg, engine="numba", raw=True)
    

    While it still uses loops, it is much faster! For the same input of 600k rows, this takes 15s to run. One quirk is that with method="table" the function must return one value per input column, so I have to return an additional quantity even though I don't need it (highmin - lowmax in my case). Furthermore, the returned dataframe condensed keeps the column names of past_updates, so it should be renamed for clarity, as sketched below. Happily, the index is correctly inferred from the rolling windows.
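
    For example, a possible cleanup step (the new column names are my own arbitrary choice, not anything mandated by pandas):

    # Relabel the three returned values, which land in the original column
    # slots; "spread" is the throwaway highmin - lowmax value.
    condensed = condensed.rename(columns={"streamid": "spread",
                                          "low": "low_max",
                                          "high": "high_min"})
    result = condensed[["low_max", "high_min"]]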