
Polars pairwise sum of array column


I just got started with Polars (Python), so this may be an ignorant question. I have a DataFrame in which one column (series) holds a NumPy array of length 18 in each row. I would like to group by the group column and aggregate the series column with a pairwise (element-wise) sum, but I can't figure out a good way to do that in Polars. I can, of course, use map_elements and np.sum the arrays (as in the example below), but I'm hoping there is a way to optimize it.

Here is my current implementation, which achieves the desired effect, but I don't think it is optimal because it uses map_elements. Is there a Polars expression that achieves the same thing, or is this the best I can do (without learning Rust, which I will someday)?

import polars as pl
import numpy as np
data = [
{'group': 1,
  'series': np.array([ 2398,  2590,  3000,  3731,  3986,  4603,  4146,  4325,  6068,
          6028,  7486,  7759,  8323,  8961,  9598, 10236, 10873, 11511])},
{'group': 1,
  'series': np.array([ 2398,  2590,  3000,  3731,  3986,  4603,  4146,  4325,  6068,
          6028,  7486,  7759,  8323,  8961,  9598, 10236, 10873, 11511])},
 {'group': 2,
  'series': np.array([1132, 1269, 1452, 1687, 1389, 1655, 1532, 1661, 1711, 1528, 1582,
         1638, 1603, 1600, 1597, 1594, 1591, 1588])},
 {'group': 3,
  'series': np.array([ 2802,  3065,  3811,  4823,  4571,  4817,  4668,  5110,  6920,
          7131, 10154, 11138, 11699, 12840, 13981, 15123, 16264, 17405])},
]
df = pl.DataFrame(data)
# this performs the desired aggregation (pairwise sum of 'series' arrays)
# sums first two rows together (group 1), leaves others unchanged
df.group_by('group').agg(
  pl.col('series').map_elements(lambda x: np.sum(x.to_list(), axis=0))
).to_dicts()
'''
desired output

group    series
i64    object
2    [1132 1269 1452 1687 1389 1655 1532 1661 1711 1528 1582 1638 1603 1600
 1597 1594 1591 1588]
1    [ 4796  5180  6000  7462  7972  9206  8292  8650 12136 12056 14972 15518
 16646 17922 19196 20472 21746 23022]
3    [ 2802  3065  3811  4823  4571  4817  4668  5110  6920  7131 10154 11138
 11699 12840 13981 15123 16264 17405]

'''

Thank you in advance for any help.


Solution

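    # `pl` and `data` are the ones defined in the question.
    # Convert each NumPy array to a plain Python list so the column is a Polars list.
    df = pl.DataFrame([{"group": x["group"], "series": list(x["series"])} for x in data])
    
    n = 18  # length of every array in "series"
    
    (
        df
        .group_by("group")
        .agg(
            # For each position i, sum the i-th elements within the group,
            # then reassemble the n sums into one list per group.
            pl.concat_list(
                [pl.col.series.list.get(i).sum() for i in range(n)]
            )
        )
    )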
    
    shape: (3, 2)
    ┌───────┬───────────────────────┐
    │ group ┆ series                │
    │ ---   ┆ ---                   │
    │ i64   ┆ list[i64]             │
    ╞═══════╪═══════════════════════╡
    │ 1     ┆ [4796, 5180, … 23022] │
    │ 3     ┆ [2802, 3065, … 17405] │
    │ 2     ┆ [1132, 1269, … 1588]  │
    └───────┴───────────────────────┘
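
To make the per-position logic of the expression above explicit, here is a minimal plain-Python sketch of the same aggregation. It is only an illustration, not a replacement for the Polars expression: the helper name `sum_by_position` is hypothetical, and the sample rows are the first three values of the group 1 and group 2 arrays from the question, truncated to keep the example short.

```python
from collections import defaultdict


def sum_by_position(rows):
    """Group rows by 'group' and sum their 'series' lists position by position."""
    grouped = defaultdict(list)
    for row in rows:
        grouped[row["group"]].append(row["series"])
    # zip(*lists) yields one tuple per position i, holding the i-th element
    # of every list in the group; summing each tuple gives the pairwise sum.
    return {g: [sum(vals) for vals in zip(*lists)] for g, lists in grouped.items()}


# Truncated sample data (first three values of each array from the question).
rows = [
    {"group": 1, "series": [2398, 2590, 3000]},
    {"group": 1, "series": [2398, 2590, 3000]},
    {"group": 2, "series": [1132, 1269, 1452]},
]
result = sum_by_position(rows)
print(result)  # {1: [4796, 5180, 6000], 2: [1132, 1269, 1452]}
```

The Polars version does the same thing column-wise: `list.get(i)` picks position i from every row, `sum()` aggregates it within each group, and `concat_list` plays the role of the list comprehension that collects the sums.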