How can I convert the pandas code below to polars efficiently? The df below is only a sample; the original DataFrame is much larger, with more than 5M rows. I created a pivot DataFrame in polars, but I'm not sure how to filter it to get the same result as the pandas df_final.
import polars as pl
import pandas as pd
# Sample input. Rows are grouped visually: 1-4 test1, 5-8 test2, 9-12 test3,
# 13-14 test4/test5, 15-19 additional rows.
df = pd.DataFrame(
    {
        "col1": list(range(1, 20)),
        "col2": [
            "test1", "test1", "test1", "test1",
            "test2", "test2", "test2", "test2",
            "test3", "test3", "test3", "test3",
            "test4", "test5",
            "test1", "test1", "test1", "test3", "test4",
        ],
        # NOTE(review): the last three rows use 'tl' (lowercase L), possibly a
        # typo for 't1'; col3 is not used by the pivot below either way.
        "col3": ["t1"] * 16 + ["tl"] * 3,
        "col4": [
            "input1", "input2", "input3", "input4",
            "input1", "input2", "input3", "input4",
            "input1", "input2", "input3", "input5",
            "input2", "input6",
            "input1", "input1", "input2", "input2", "input2",
        ],
        "col5": [
            "result1", "result2", "result3", "result4",
            "result1", "result2", "result3", "result4",
            "result1", "result2", "result3", "result4",
            "result2", "result1",
            "result2", "result6", "result1", "result1", "result1",
        ],
        "col6": [
            10, 20, 30, 40,
            10, 20, 30, 40,
            10, 20, 30, 50,
            20, 100,
            10, 10, 20, 20, 20,
        ],
        "col7": [
            100.2, 101.2, 102.3, 101.4,
            100.0, 103.0, 104.0, 105.0,
            102.0, 87.0, 107.0, 110.2,
            120.0, 88.0,
            106.2, 101.1, 100, 90.2, 110,
        ],
    }
)

# One row per (col4, col5, col6), one column per col2 value, cell = max col7.
group_cols = ["col4", "col5", "col6"]
p_df = df.pivot_table(values="col7", index=group_cols, columns=["col2"], aggfunc="max")

# Keep a row if, in at least one column, it is the largest value within its
# col4 group (index level 0) ...
is_group_max = p_df.groupby(level=0).rank(ascending=False).eq(1.0).any(axis=1)
# ... and at least one of its columns is greater than 100.
above_100 = p_df.gt(100).any(axis=1)

df_final = p_df[is_group_max & above_100]
print(df_final)
# Start of polars code
df_p = pl.from_pandas(df)
# Pivot: one row per (col4, col5, col6), one column per col2 value.
# aggregate_function='max' mirrors pandas' pivot_table(aggfunc='max'); with the
# default (None) polars raises an error as soon as an (index, col2) combination
# occurs more than once, which is likely in the full-size data.
# Unlike pandas, polars keeps rows and columns in first-appearance order
# instead of sorting them.
p_df = df_p.pivot(
    on='col2',
    index=['col4', 'col5', 'col6'],
    values='col7',
    aggregate_function='max',
)
Any input is much appreciated. Thanks in advance.
Starting from this sample (the first 16 rows of the question's data):
# Sample data: the first 16 rows of the question's DataFrame (the three later
# rows whose col3 is 'tl' are not included, so results can differ slightly).
df = pl.DataFrame(
    {
        "col1": list(range(1, 17)),
        "col2": [
            "test1", "test1", "test1", "test1",
            "test2", "test2", "test2", "test2",
            "test3", "test3", "test3", "test3",
            "test4", "test5", "test1", "test1",
        ],
        "col3": ["t1"] * 16,
        "col4": [
            "input1", "input2", "input3", "input4",
            "input1", "input2", "input3", "input4",
            "input1", "input2", "input3", "input5",
            "input2", "input6", "input1", "input1",
        ],
        "col5": [
            "result1", "result2", "result3", "result4",
            "result1", "result2", "result3", "result4",
            "result1", "result2", "result3", "result4",
            "result2", "result1", "result2", "result6",
        ],
        "col6": [10, 20, 30, 40, 10, 20, 30, 40, 10, 20, 30, 50, 20, 100, 10, 10],
        "col7": [
            100.2, 101.2, 102.3, 101.4,
            100.0, 103.0, 104.0, 105.0,
            102.0, 87.0, 107.0, 110.2,
            120.0, 88.0, 106.2, 101.1,
        ],
    },
    schema={
        "col1": pl.Int64,
        "col2": pl.Utf8,
        "col3": pl.Utf8,
        "col4": pl.Utf8,
        "col5": pl.Utf8,
        "col6": pl.Int64,
        "col7": pl.Float64,
    },
)
idx = ['col4', 'col5', 'col6']
(
    df
    # aggregate_function='max' mirrors pandas' pivot_table(aggfunc='max');
    # with the default (None) polars raises an error as soon as an
    # idx/col2 combination occurs more than once.
    .pivot(on='col2', index=idx, values='col7', aggregate_function='max')
    .filter(
        # In at least one test column the row is the maximum of its col4
        # group (rank defaults to method='average', as in pandas) ...
        pl.any_horizontal(pl.exclude(idx).rank(descending=True).over('col4') == 1)
        # ... and at least one test column is greater than 100.
        & pl.any_horizontal(pl.col('^test.+$') > 100)
    )
    # polars keeps first-appearance order; sort to match pandas' sorted index.
    .sort(idx)
)
shape: (6, 8)
┌────────┬─────────┬──────┬───────┬───────┬───────┬───────┬───────┐
│ col4 ┆ col5 ┆ col6 ┆ test1 ┆ test2 ┆ test3 ┆ test4 ┆ test5 │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════╪═════════╪══════╪═══════╪═══════╪═══════╪═══════╪═══════╡
│ input1 ┆ result1 ┆ 10 ┆ 100.2 ┆ 100.0 ┆ 102.0 ┆ null ┆ null │
│ input1 ┆ result2 ┆ 10 ┆ 106.2 ┆ null ┆ null ┆ null ┆ null │
│ input2 ┆ result2 ┆ 20 ┆ 101.2 ┆ 103.0 ┆ 87.0 ┆ 120.0 ┆ null │
│ input3 ┆ result3 ┆ 30 ┆ 102.3 ┆ 104.0 ┆ 107.0 ┆ null ┆ null │
│ input4 ┆ result4 ┆ 40 ┆ 101.4 ┆ 105.0 ┆ null ┆ null ┆ null │
│ input5 ┆ result4 ┆ 50 ┆ null ┆ null ┆ 110.2 ┆ null ┆ null │
└────────┴─────────┴──────┴───────┴───────┴───────┴───────┴───────┘
In this example, pl.exclude(idx) and pl.col('^test.+$') select the same columns and are interchangeable; I used both only to show that there are (at least) two ways to refer to the pivoted columns.
If you would rather filter before the pivot, you can do it with window functions:
idx = ['col4', 'col5', 'col6']
(
    df
    # zz == 1 marks the largest col7 within each (col4, col2) pair, i.e. the
    # cell that becomes the column maximum of its col4 group after the pivot
    # (this matches exactly when each idx/col2 combination occurs once).
    .with_columns(zz=pl.col('col7').rank(descending=True).over('col4', 'col2'))
    .filter(
        # Keep every row of an idx group if any of its rows is such a maximum ...
        (pl.col('zz') == 1).any().over(idx),
        # ... and any of its rows has col7 > 100.
        # Aggregating with .any() inside .over() computes one flag per group
        # and broadcasts it back to the group's rows, instead of materialising
        # a per-row list (mapping_strategy='join') and then calling .list.any().
        (pl.col('col7') > 100).any().over(idx),
    )
    # Same aggregation as pandas' pivot_table(aggfunc='max'); the helper zz
    # column is not an index/on/values column, so the pivot drops it.
    .pivot(on='col2', index=idx, values='col7', aggregate_function='max')
    .sort(idx)
)
shape: (6, 7)
┌────────┬─────────┬──────┬───────┬───────┬───────┬───────┐
│ col4 ┆ col5 ┆ col6 ┆ test1 ┆ test2 ┆ test3 ┆ test4 │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════╪═════════╪══════╪═══════╪═══════╪═══════╪═══════╡
│ input1 ┆ result1 ┆ 10 ┆ 100.2 ┆ 100.0 ┆ 102.0 ┆ null │
│ input1 ┆ result2 ┆ 10 ┆ 106.2 ┆ null ┆ null ┆ null │
│ input2 ┆ result2 ┆ 20 ┆ 101.2 ┆ 103.0 ┆ 87.0 ┆ 120.0 │
│ input3 ┆ result3 ┆ 30 ┆ 102.3 ┆ 104.0 ┆ 107.0 ┆ null │
│ input4 ┆ result4 ┆ 40 ┆ 101.4 ┆ 105.0 ┆ null ┆ null │
│ input5 ┆ result4 ┆ 50 ┆ null ┆ null ┆ 110.2 ┆ null │
└────────┴─────────┴──────┴───────┴───────┴───────┴───────┘
Filtering before the pivot also drops the test5 column: its only row (input6) is removed by the filter, so the column would otherwise be all null.
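That said, I suspect the condition you actually want is stricter: a single column must be both the group maximum and greater than 100. Note that this is not what the pandas code does. Pandas allows the two conditions to be met in different columns, so on the question's full 19-row data this version drops the (input2, result1, 20) row that pandas keeps: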
idx = ['col4', 'col5', 'col6']
(
    df
    # aggregate_function='max' mirrors pandas' pivot_table(aggfunc='max');
    # with the default (None) polars raises on duplicate idx/col2 combinations.
    .pivot(on='col2', index=idx, values='col7', aggregate_function='max')
    .filter(
        # Both conditions are evaluated per test column and combined column by
        # column, so the SAME column must be the col4-group maximum AND exceed
        # 100. (The pandas code only requires each condition in some column.)
        pl.any_horizontal(
            (pl.col('^test.+$').rank(descending=True).over('col4') == 1)
            & (pl.col('^test.+$') > 100)
        )
    )
    .sort(idx)
)
shape: (6, 8)
┌────────┬─────────┬──────┬───────┬───────┬───────┬───────┬───────┐
│ col4 ┆ col5 ┆ col6 ┆ test1 ┆ test2 ┆ test3 ┆ test4 ┆ test5 │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ i64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 ┆ f64 │
╞════════╪═════════╪══════╪═══════╪═══════╪═══════╪═══════╪═══════╡
│ input1 ┆ result1 ┆ 10 ┆ 100.2 ┆ 100.0 ┆ 102.0 ┆ null ┆ null │
│ input1 ┆ result2 ┆ 10 ┆ 106.2 ┆ null ┆ null ┆ null ┆ null │
│ input2 ┆ result2 ┆ 20 ┆ 101.2 ┆ 103.0 ┆ 87.0 ┆ 120.0 ┆ null │
│ input3 ┆ result3 ┆ 30 ┆ 102.3 ┆ 104.0 ┆ 107.0 ┆ null ┆ null │
│ input4 ┆ result4 ┆ 40 ┆ 101.4 ┆ 105.0 ┆ null ┆ null ┆ null │
│ input5 ┆ result4 ┆ 50 ┆ null ┆ null ┆ 110.2 ┆ null ┆ null │
└────────┴─────────┴──────┴───────┴───────┴───────┴───────┴───────┘