Tags: python, python-polars

Cumulatively aggregate a polars list[struct[]]


I need to perform a complex dataframe conversion like this:

original_dataframe = pl.DataFrame({
    'index': ['A', 'B', 'C', 'D', 'E', 'F', 'G'],
    'content': [
        {'key': 3, 'val': 20}, {'key': 4, 'val': 50}, {'key': 3, 'val': 8},
        {'key': 5, 'val': 70}, {'key': 4, 'val': -60}, {'key': 2, 'val': 30},
        {'key': 4, 'val': 5},
    ],
})
┌───────┬───────────┐
│ index ┆ content   │
│ ---   ┆ ---       │
│ str   ┆ struct[2] │
╞═══════╪═══════════╡
│ A     ┆ {3,20}    │
│ B     ┆ {4,50}    │
│ C     ┆ {3,8}     │
│ D     ┆ {5,70}    │
│ E     ┆ {4,-60}   │
│ F     ┆ {2,30}    │
│ G     ┆ {4,5}     │
└───────┴───────────┘
       ||
       \/ 
┌───────┬──────────────────────────┐
│ index ┆ content                  │
│ ---   ┆ ---                      │
│ str   ┆ list[struct[2]]          │
╞═══════╪══════════════════════════╡
│ A     ┆ [{3,20}]                 │
│ B     ┆ [{3,20}, {4,50}]         │
│ C     ┆ [{3,28}, {4,50}]         │
│ D     ┆ [{3,28}, {4,50}, {5,70}] │
│ E     ┆ [{3,28}, {5,70}]         │
└───────┴──────────────────────────┘

This conversion combines:

  1. cumulatively add each row's struct to the list, row by row;
  2. if a struct with the same 'key' field already exists in the list, merge the two structs by summing their 'val' fields;
  3. if a struct's 'val' field is <= 0 after aggregation, drop it from the list;
  4. sort each list by the struct 'key' field;
  5. also drop a struct if its 'val' field or 'key' field is null.

The conversion can be done in an ugly way by using iter_rows() and to_list() to iterate over the dataframe rows with intermediate Python lists and dicts, but that is slow. How can it be solved with polars functions only, fast and elegant?

PS: Thanks to @jqurious' reminder, there is an additional requirement, so I have updated the question.

pl.DataFrame({
    'index': ['A', 'B', 'C', 'D', 'E', 'F'],
    'content': [
        {'key': 3, 'val': 20}, {'key': 4, 'val': 50}, {'key': 3, 'val': 8},
        {'key': 2, 'val': 30}, {'key': 4, 'val': -60}, {'key': 4, 'val': 5},
    ],
})

┌───────┬───────────┐
│ index ┆ content   │
│ ---   ┆ ---       │
│ str   ┆ struct[2] │
╞═══════╪═══════════╡
│ A     ┆ {3,20}    │
│ B     ┆ {4,50}    │
│ C     ┆ {3,8}     │
│ D     ┆ {2,30}    │
│ E     ┆ {4,-60}   │
│ F     ┆ {4,5}     │
└───────┴───────────┘
        ||
        \/ 
┌───────┬──────────────────────────┐
│ index ┆ content                  │
│ ---   ┆ ---                      │
│ str   ┆ list[struct[2]]          │
╞═══════╪══════════════════════════╡
│ A     ┆ [{3,20}]                 │
│ B     ┆ [{3,20}, {4,50}]         │
│ C     ┆ [{3,28}, {4,50}]         │
│ D     ┆ [{2,30}, {3,28}, {4,50}] │
│ E     ┆ [{2,30}, {3,28}]         │
│ F     ┆ [{2,30}, {3,28}, {4,5}]  │
└───────┴──────────────────────────┘

The updated requirement is:

  1. if the struct 'val' field is <= 0 after the cumulative sum, drop it from the corresponding row's list immediately; and if the same struct 'key' field appears again in a following row with a 'val' field > 0, cumulative aggregation should start over from that value.

Solution

  • One possible 'pure polars' solution could be:

    (
        (
            df
            .unnest("content")
            .pivot(on="key", index="index", values="val")
            .with_columns(pl.all().exclude("index").fill_null(0).cum_sum())
        )
        .unpivot(index="index", variable_name="key", value_name="val")
        .filter(pl.col("val") > 0)
        .select(pl.col("index"), pl.struct("key","val"))
        .group_by("index", maintain_order=True).agg("key")
    )
    
    ┌───────┬────────────────────────────────┐
    │ index ┆ key                            │
    │ ---   ┆ ---                            │
    │ str   ┆ list[struct[2]]                │
    ╞═══════╪════════════════════════════════╡
    │ A     ┆ [{"3",20}]                     │
    │ B     ┆ [{"3",20}, {"4",50}]           │
    │ C     ┆ [{"3",28}, {"4",50}]           │
    │ D     ┆ [{"3",28}, {"4",50}, {"5",70}] │
    │ E     ┆ [{"3",28}, {"5",70}]           │
    └───────┴────────────────────────────────┘
    

    Explanation of transformations:

    1. We DataFrame.pivot() the DataFrame, to create columns out of all possible "key" values:
    (
        df
        .unnest("content")
        .pivot(on="key", index="index", values="val")
    )
    
    ┌───────┬──────┬──────┬──────┐
    │ index ┆ 3    ┆ 4    ┆ 5    │
    │ ---   ┆ ---  ┆ ---  ┆ ---  │
    │ str   ┆ i64  ┆ i64  ┆ i64  │
    ╞═══════╪══════╪══════╪══════╡
    │ A     ┆ 20   ┆ null ┆ null │
    │ B     ┆ null ┆ 50   ┆ null │
    │ C     ┆ 8    ┆ null ┆ null │
    │ D     ┆ null ┆ null ┆ 70   │
    │ E     ┆ null ┆ -60  ┆ null │
    └───────┴──────┴──────┴──────┘
    
    2. Now, we assign the final numbers for each "index" via Expr.fill_null() and Expr.cum_sum():
    (
        ...
        .with_columns(pl.all().exclude("index").fill_null(0).cum_sum())
    )
    
    ┌───────┬─────┬─────┬─────┐
    │ index ┆ 3   ┆ 4   ┆ 5   │
    │ ---   ┆ --- ┆ --- ┆ --- │
    │ str   ┆ i64 ┆ i64 ┆ i64 │
    ╞═══════╪═════╪═════╪═════╡
    │ A     ┆ 20  ┆ 0   ┆ 0   │
    │ B     ┆ 20  ┆ 50  ┆ 0   │
    │ C     ┆ 28  ┆ 50  ┆ 0   │
    │ D     ┆ 28  ┆ 50  ┆ 70  │
    │ E     ┆ 28  ┆ -10 ┆ 70  │
    └───────┴─────┴─────┴─────┘
    
    3. Now, use DataFrame.unpivot() to convert the columns back to rows:
    (
        (
            ...
        )
        .unpivot(index="index", variable_name="key", value_name="val")
    )
    
    ┌───────┬─────┬─────┐
    │ index ┆ key ┆ val │
    │ ---   ┆ --- ┆ --- │
    │ str   ┆ str ┆ i64 │
    ╞═══════╪═════╪═════╡
    │ A     ┆ 3   ┆ 20  │
    │ B     ┆ 3   ┆ 20  │
    │ C     ┆ 3   ┆ 28  │
    │ D     ┆ 3   ┆ 28  │
    │ E     ┆ 3   ┆ 28  │
    │ …     ┆ …   ┆ …   │
    │ A     ┆ 5   ┆ 0   │
    │ B     ┆ 5   ┆ 0   │
    │ C     ┆ 5   ┆ 0   │
    │ D     ┆ 5   ┆ 70  │
    │ E     ┆ 5   ┆ 70  │
    └───────┴─────┴─────┘
    
    4. Now, a fairly straightforward step: DataFrame.filter() to keep the values greater than 0, convert the "key" and "val" columns into a struct(), and then DataFrame.group_by() into the result lists:
    (
        (
            ...
        )
        ...
        .filter(pl.col("val") > 0)
        .select(pl.col("index"), pl.struct("key","val"))
        .group_by("index", maintain_order=True).agg("key")
    )
    
    ┌───────┬────────────────────────────────┐
    │ index ┆ key                            │
    │ ---   ┆ ---                            │
    │ str   ┆ list[struct[2]]                │
    ╞═══════╪════════════════════════════════╡
    │ A     ┆ [{"3",20}]                     │
    │ B     ┆ [{"3",20}, {"4",50}]           │
    │ C     ┆ [{"3",28}, {"4",50}]           │
    │ D     ┆ [{"3",28}, {"4",50}, {"5",70}] │
    │ E     ┆ [{"3",28}, {"5",70}]           │
    └───────┴────────────────────────────────┘
    

    UPDATE

    As @jqurious pointed out in the comments, the answer above is incomplete. The example input provided by the OP didn't cover the case where, after some key goes negative and is removed from the list, we should not keep using the "current" cumulative value but restart with the next "positive" one.

    For example, if the OP's input contained another row F, {4, 5}, then the answer should have [{3,28}, {4,5}, {5,70}].

    So, to accommodate this logic, we need to add another step. The idea is to identify the rows where a value becomes negative and subtract that amount from the cum_sum().

    2.1 Identify the rows where a value has just become non-positive, using when() and Expr.shift():

    values = pl.all().exclude("index")
    
    (
        ...
        .with_columns(
            pl.when(values <= 0, values.shift(1) > 0).then(values).otherwise(0)
        )
    )
    
    ┌───────┬─────┬─────┬─────┐
    │ index ┆ 3   ┆ 4   ┆ 5   │
    │ ---   ┆ --- ┆ --- ┆ --- │
    │ str   ┆ i64 ┆ i64 ┆ i64 │
    ╞═══════╪═════╪═════╪═════╡
    │ A     ┆ 0   ┆ 0   ┆ 0   │
    │ B     ┆ 0   ┆ 0   ┆ 0   │
    │ C     ┆ 0   ┆ 0   ┆ 0   │
    │ D     ┆ 0   ┆ 0   ┆ 0   │
    │ E     ┆ 0   ┆ -10 ┆ 0   │
    │ F     ┆ 0   ┆ 0   ┆ 0   │
    └───────┴─────┴─────┴─────┘
    

    2.2 Now, use cum_sum() to carry that correction forward and subtract it from our current values:

    (
        ...
        .with_columns(
            values - 
            pl.when(values <= 0, values.shift(1) > 0).then(values).otherwise(0).cum_sum()
        )
    )
    
    ┌───────┬─────┬─────┬─────┐
    │ index ┆ 3   ┆ 4   ┆ 5   │
    │ ---   ┆ --- ┆ --- ┆ --- │
    │ str   ┆ i64 ┆ i64 ┆ i64 │
    ╞═══════╪═════╪═════╪═════╡
    │ A     ┆ 20  ┆ 0   ┆ 0   │
    │ B     ┆ 20  ┆ 50  ┆ 0   │
    │ C     ┆ 28  ┆ 50  ┆ 0   │
    │ D     ┆ 28  ┆ 50  ┆ 70  │
    │ E     ┆ 28  ┆ 0   ┆ 70  │
    │ F     ┆ 28  ┆ 5   ┆ 70  │
    └───────┴─────┴─────┴─────┘
    

    So the final solution becomes:

    values = pl.all().exclude("index")
    
    (
        df
        .unnest("content")
        .pivot(on="key", index="index", values="val")
        .with_columns(values.fill_null(0).cum_sum())
        .with_columns(
            values - 
            pl.when(values <= 0, values.shift(1) > 0).then(values).otherwise(0).cum_sum()
        )
        .unpivot(index="index", variable_name="key", value_name="val")
        .filter(pl.col("val") > 0)
        .select(pl.col("index"), pl.struct("key","val"))
        .group_by("index", maintain_order=True).agg("key")
    )
    
    ┌───────┬────────────────────────────────┐
    │ index ┆ key                            │
    │ ---   ┆ ---                            │
    │ str   ┆ list[struct[2]]                │
    ╞═══════╪════════════════════════════════╡
    │ A     ┆ [{"3",20}]                     │
    │ B     ┆ [{"3",20}, {"4",50}]           │
    │ C     ┆ [{"3",28}, {"4",50}]           │
    │ D     ┆ [{"3",28}, {"4",50}, {"5",70}] │
    │ E     ┆ [{"3",28}, {"5",70}]           │
    │ F     ┆ [{"3",28}, {"4",5}, {"5",70}]  │
    └───────┴────────────────────────────────┘
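    Two details are worth noting: pivot() turns the key values into string column names (which is why the keys above print as quoted strings), and the unpivot order follows the pivot column order (first appearance), so the lists are not guaranteed to be sorted by 'key'. A sketch of the full pipeline against the updated example input, with a `.cast` and a `.sort` added to restore integer keys and satisfy requirement 4 of the question (those two lines are my additions, not part of the answer above):

    ```python
    import polars as pl

    df = pl.DataFrame({
        "index": ["A", "B", "C", "D", "E", "F"],
        "content": [
            {"key": 3, "val": 20}, {"key": 4, "val": 50}, {"key": 3, "val": 8},
            {"key": 2, "val": 30}, {"key": 4, "val": -60}, {"key": 4, "val": 5},
        ],
    })

    values = pl.all().exclude("index")

    out = (
        df
        .unnest("content")
        .pivot(on="key", index="index", values="val")
        .with_columns(values.fill_null(0).cum_sum())
        .with_columns(
            values
            - pl.when(values <= 0, values.shift(1) > 0).then(values).otherwise(0).cum_sum()
        )
        .unpivot(index="index", variable_name="key", value_name="val")
        .with_columns(pl.col("key").cast(pl.Int64))  # pivot column names are strings
        .filter(pl.col("val") > 0)
        .sort("index", "key")  # ensure each list comes out sorted by "key"
        .select("index", pl.struct("key", "val").alias("content"))
        .group_by("index", maintain_order=True).agg("content")
    )
    ```

    With those additions the output matches the OP's updated expected result exactly, e.g. row D becomes [{2,30}, {3,28}, {4,50}] and row F becomes [{2,30}, {3,28}, {4,5}].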