python-polars

How to precalculate expensive expressions in Polars (in group_by aggregations and in general)?


I'm having a hard time with the fact that in a group_by I can't efficiently capture a group's sub-DataFrame with an Expr, perform an expensive operation on it once, and then return several different aggregations. I can sort of do it (see the example below), but my solution is unreadable and seems to carry unnecessary overhead because of all the intermediate lists. Is there a proper way, or a completely different approach?

Take a look at this example:

import polars as pl
import numpy as np

df = pl.DataFrame(np.random.randint(0, 10, size=(1_000_000, 3)))  # columns auto-named column_0, column_1, column_2
expensive = pl.col('column_1').cum_prod().ewm_std(span=10).alias('expensive')

%%timeit
(
    df
    .group_by('column_0')
    .agg(
        expensive.sum().alias('sum'),
        expensive.median().alias('median'),
        *[expensive.max().alias(f'max{x}') for x in range(10)],
    )
)

417 ms ± 38.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
(
    df
    .group_by('column_0')
    .agg(expensive)
    .select(
        pl.col('expensive').list.eval(pl.element().sum()).list.first().alias('sum'),
        pl.col('expensive').list.eval(pl.element().median()).list.first().alias('median'),
        *[pl.col('expensive').list.eval(pl.element().max()).list.first().alias(f'max{x}') for x in range(10)]
    )
)

95.5 ms ± 9.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

We can see that precomputing the expensive part is beneficial, presumably because in the first version the expensive expression is re-evaluated for every one of the twelve aggregations rather than once per group. But actually doing the precomputation involves this .list.eval(pl.element().<aggfunc>()).list.first() chain, which bothers me in terms of both readability and flexibility. Try as I might, I can't see a better solution.
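For what it's worth, the repetition can at least be hidden behind a small helper; this is purely a readability wrapper around the exact same pattern above, and the agg_list_scalar name is mine, not a Polars API:

def agg_list_scalar(name, agg):
    # Apply a scalar aggregation inside each list of a list column and
    # unwrap the single-element result; agg maps an element expression
    # to an aggregation, e.g. lambda e: e.sum().
    return pl.col(name).list.eval(agg(pl.element())).list.first()

(
    df
    .group_by('column_0')
    .agg(expensive)
    .select(
        agg_list_scalar('expensive', lambda e: e.sum()).alias('sum'),
        agg_list_scalar('expensive', lambda e: e.median()).alias('median'),
        *[agg_list_scalar('expensive', lambda e: e.max()).alias(f'max{x}') for x in range(10)]
    )
)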

I'm not sure whether the problem is specific to group_by; if your solution involves select contexts as well, please share that too.


Solution

  • Use explode instead of list.eval like this:

    %%timeit
    (
        df
        .group_by('column_0')
        .agg(expensive)
        .explode('expensive')
        .group_by('column_0')
        .agg(
            pl.col('expensive').sum().alias('sum'),
            pl.col('expensive').median().alias('median'),
            *[pl.col('expensive').max().alias(f'max{x}') for x in range(10)]
        )
    )
    

    On my machine, the run times were:

    Your first example: 320 ms ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

    Your second: 80.8 ms ± 1.01 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

    Mine: 63 ms ± 507 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
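    As an aside, the same pipeline works unchanged on a LazyFrame, which lets the query optimizer see the whole plan at once; a minimal, untimed sketch using the same df and expensive as above:

    (
        df.lazy()
        .group_by('column_0')
        .agg(expensive)
        .explode('expensive')
        .group_by('column_0')
        .agg(
            pl.col('expensive').sum().alias('sum'),
            pl.col('expensive').median().alias('median'),
            *[pl.col('expensive').max().alias(f'max{x}') for x in range(10)]
        )
        .collect()  # triggers execution of the optimized plan
    )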

    Another method, which turns out to be slightly slower than the above, is to compute the expensive expression as a window function, which skips the explode:

    %%timeit
    (
        df
        .select('column_0', expensive.over('column_0'))
        .group_by('column_0')
        .agg(
            pl.col('expensive').sum().alias('sum'),
            pl.col('expensive').median().alias('median'),
            *[pl.col('expensive').max().alias(f'max{x}') for x in range(10)]
        )
    )
    

    This last one returned in 69.7 ms ± 911 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
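    Whichever variant you pick, it's worth checking that the refactor still agrees with your original aggregation. A minimal sketch using polars.testing (group_by output order is not deterministic, hence check_row_order=False):

    from polars.testing import assert_frame_equal

    baseline = (
        df
        .group_by('column_0')
        .agg(
            expensive.sum().alias('sum'),
            expensive.median().alias('median'),
            *[expensive.max().alias(f'max{x}') for x in range(10)]
        )
    )
    windowed = (
        df
        .select('column_0', expensive.over('column_0'))
        .group_by('column_0')
        .agg(
            pl.col('expensive').sum().alias('sum'),
            pl.col('expensive').median().alias('median'),
            *[pl.col('expensive').max().alias(f'max{x}') for x in range(10)]
        )
    )
    # Raises an AssertionError if the two results differ.
    assert_frame_equal(baseline, windowed, check_row_order=False)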