
Grouping Rows in Polars


I'm trying to figure out how to aggregate rows into batches (i.e., a sliding window over the current row and the rows before it) such that, given a table like

  id   x    y   
 ---- ---- ---- 
   1   x1   y1  
   2   x2   y2  
   3   x3   y3  
   4   x4   y4  
   5   x5   y5  

The result would be

  id   x    y                  Grouped x3                 
 ---- ---- ---- ----------------------------------------- 
   1   x1   y1   None                                     
   2   x2   y2   None                                     
   3   x3   y3   [[1, x1, y1], [2, x2, y2], [3, x3, y3]]  
   4   x4   y4   [[2, x2, y2], [3, x3, y3], [4, x4, y4]]  
   5   x5   y5   [[3, x3, y3], [4, x4, y4], [5, x5, y5]]  
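To pin down the semantics, here is a minimal plain-Python sketch (no Polars) of the sliding window described above. Each row gets the three most recent rows, itself included, and rows without enough predecessors get None. The helper name `sliding_windows` and the placeholder string values are illustrative assumptions, not part of any library.

```python
def sliding_windows(rows, size=3):
    """Hypothetical helper: for each position, return the `size` most recent
    rows ending there, or None when fewer than `size` rows are available."""
    out = []
    for i in range(len(rows)):
        if i + 1 < size:
            out.append(None)  # not enough rows yet for a full window
        else:
            out.append(rows[i + 1 - size : i + 1])  # overlapping slice
    return out


rows = [[i, f"x{i}", f"y{i}"] for i in range(1, 6)]
for row, window in zip(rows, sliding_windows(rows)):
    print(row[0], window)
```

A tumbling window would instead split the rows into non-overlapping chunks (rows[0:3], rows[3:6], ...); the expected output above is the sliding variant, where consecutive windows overlap.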

I currently have a column that packs each row's values into a list, built with concat_list:

  id   x    y       List      
 ---- ---- ---- ------------- 
   1   x1   y1   [1, x1, y1]  
   2   x2   y2   [2, x2, y2]  
   3   x3   y3   [3, x3, y3]  
   4   x4   y4   [4, x4, y4]  
   5   x5   y5   [5, x5, y5] 

I just don't know how to aggregate them across rows like the example above.

Thanks so much!


Solution

  • One way to do this with your data is to use DataFrame.rolling:

    Let's start with your data (using integers in place of x1, y1, ...) and pack each row's values into a list with concat_list.

    import polars as pl
    
    df = pl.DataFrame(
        {
            "id": [1, 2, 3, 4, 5],
            "x": [10, 20, 30, 40, 50],
            "y": [100, 200, 300, 400, 500],
        }
    ).with_columns(pl.concat_list(pl.all()).alias('cat_list'))
    df
    
    shape: (5, 4)
    ┌─────┬─────┬─────┬──────────────┐
    │ id  ┆ x   ┆ y   ┆ cat_list     │
    │ --- ┆ --- ┆ --- ┆ ---          │
    │ i64 ┆ i64 ┆ i64 ┆ list[i64]    │
    ╞═════╪═════╪═════╪══════════════╡
    │ 1   ┆ 10  ┆ 100 ┆ [1, 10, 100] │
    │ 2   ┆ 20  ┆ 200 ┆ [2, 20, 200] │
    │ 3   ┆ 30  ┆ 300 ┆ [3, 30, 300] │
    │ 4   ┆ 40  ┆ 400 ┆ [4, 40, 400] │
    │ 5   ┆ 50  ┆ 500 ┆ [5, 50, 500] │
    └─────┴─────┴─────┴──────────────┘
    

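    From here, we can call rolling with index_column='id' and period='3i', a window spanning three units of the integer id column. With closed='right' (the default), the window for each row covers ids in (id - 3, id], i.e. the current row and the two before it. Inside agg, a column expression without an aggregation such as sum or mean collects the group's values into a list; since cat_list is already a list, the result here is a list of lists.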

    group_result = (
        df
        .rolling(index_column='id', period='3i', closed='right')
        .agg(pl.col('cat_list').alias('result'))
    )
    group_result
    
    shape: (5, 2)
    ┌─────┬─────────────────────────────────┐
    │ id  ┆ result                          │
    │ --- ┆ ---                             │
    │ i64 ┆ list[list[i64]]                 │
    ╞═════╪═════════════════════════════════╡
    │ 1   ┆ [[1, 10, 100]]                  │
    │ 2   ┆ [[1, 10, 100], [2, 20, 200]]    │
    │ 3   ┆ [[1, 10, 100], [2, 20, 200], [… │
    │ 4   ┆ [[2, 20, 200], [3, 30, 300], [… │
    │ 5   ┆ [[3, 30, 300], [4, 40, 400], [… │
    └─────┴─────────────────────────────────┘
    

    Because rolling returns exactly one row per input row, in the same order as the (sorted) id column, you can attach the result column directly to your existing DataFrame.

    df = df.select(
        pl.all(),
        group_result.get_column("result"),
    )
    df
    
    shape: (5, 5)
    ┌─────┬─────┬─────┬──────────────┬─────────────────────────────────┐
    │ id  ┆ x   ┆ y   ┆ cat_list     ┆ result                          │
    │ --- ┆ --- ┆ --- ┆ ---          ┆ ---                             │
    │ i64 ┆ i64 ┆ i64 ┆ list[i64]    ┆ list[list[i64]]                 │
    ╞═════╪═════╪═════╪══════════════╪═════════════════════════════════╡
    │ 1   ┆ 10  ┆ 100 ┆ [1, 10, 100] ┆ [[1, 10, 100]]                  │
    │ 2   ┆ 20  ┆ 200 ┆ [2, 20, 200] ┆ [[1, 10, 100], [2, 20, 200]]    │
    │ 3   ┆ 30  ┆ 300 ┆ [3, 30, 300] ┆ [[1, 10, 100], [2, 20, 200], [… │
    │ 4   ┆ 40  ┆ 400 ┆ [4, 40, 400] ┆ [[2, 20, 200], [3, 30, 300], [… │
    │ 5   ┆ 50  ┆ 500 ┆ [5, 50, 500] ┆ [[3, 30, 300], [4, 40, 400], [… │
    └─────┴─────┴─────┴──────────────┴─────────────────────────────────┘
    

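    If you need to enforce a minimum window size, so that rows with fewer than three entries are None as in your expected output, you can combine when/then/otherwise with list.len() to replace the incomplete windows with null.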

    df.with_columns(
        pl.when(pl.col('result').list.len() < 3)
        .then(None)
        .otherwise(pl.col('result'))
        .name.keep()
    )
    
    shape: (5, 5)
    ┌─────┬─────┬─────┬──────────────┬─────────────────────────────────┐
    │ id  ┆ x   ┆ y   ┆ cat_list     ┆ result                          │
    │ --- ┆ --- ┆ --- ┆ ---          ┆ ---                             │
    │ i64 ┆ i64 ┆ i64 ┆ list[i64]    ┆ list[list[i64]]                 │
    ╞═════╪═════╪═════╪══════════════╪═════════════════════════════════╡
    │ 1   ┆ 10  ┆ 100 ┆ [1, 10, 100] ┆ null                            │
    │ 2   ┆ 20  ┆ 200 ┆ [2, 20, 200] ┆ null                            │
    │ 3   ┆ 30  ┆ 300 ┆ [3, 30, 300] ┆ [[1, 10, 100], [2, 20, 200], [… │
    │ 4   ┆ 40  ┆ 400 ┆ [4, 40, 400] ┆ [[2, 20, 200], [3, 30, 300], [… │
    │ 5   ┆ 50  ┆ 500 ┆ [5, 50, 500] ┆ [[3, 30, 300], [4, 40, 400], [… │
    └─────┴─────┴─────┴──────────────┴─────────────────────────────────┘