I'm trying to port a pandas script to polars. I have a dataset that looks like that
df = pl.from_repr("""
┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┐
│ sid ┆ roi ┆ endpoint ┆ value ┆ std ┆ voxel_count │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ i64 │
╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╡
│ 4213-a3_bl ┆ AF_L ┆ afd_along ┆ 0.4 ┆ 0.21 ┆ 57334 │
│ 4213-a3_bl ┆ AF_L ┆ radfODF ┆ 0.08 ┆ 0.045 ┆ 57334 │
│ 4213-a3_bl ┆ AF_R ┆ afd_along ┆ 0.42 ┆ 0.22 ┆ 53916 │
│ 4213-a3_bl ┆ AF_R ┆ radfODF ┆ 0.08 ┆ 0.04 ┆ 53916 │
│ 4213-a3_bl ┆ CC_1 ┆ afd_along ┆ null ┆ null ┆ null │
│ 4213-a3_bl ┆ CC_2a ┆ afd_along ┆ 0.54 ┆ 0.3 ┆ 3264 │
│ 4213-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 602620 │
│ 4225-a3_bl ┆ CC_2a ┆ radfODF ┆ 0.06 ┆ 0.04 ┆ 3264 │
│ 4225-a3_bl ┆ CC_2b ┆ afd_along ┆ 0.47 ┆ 0.24 ┆ 18833 │
│ 4225-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 718758 │
└────────────┴─────────┴───────────┴───────┴───────┴─────────────┘
""")
I want to add a column based on a group_by
df.filter(roi = "wm_mask").group_by("sid").first()
shape: (2, 6)
┌────────────┬─────────┬──────────┬───────┬──────┬─────────────┐
│ sid ┆ roi ┆ endpoint ┆ value ┆ std ┆ voxel_count │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ i64 │
╞════════════╪═════════╪══════════╪═══════╪══════╪═════════════╡
│ 4213-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 602620 │
│ 4225-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 718758 │
└────────────┴─────────┴──────────┴───────┴──────┴─────────────┘
Now I want to add this new voxel_count
values that correspond to the right sid
, which should give something like
┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┬────────────────┐
│ sid ┆ roi ┆ endpoint ┆ value ┆ std ┆ voxel_count ┆ wm_mask__count │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ i64 ┆ i64 │
╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╪════════════════╡
│ 4213-a3_bl ┆ AF_L ┆ afd_along ┆ 0.4 ┆ 0.21 ┆ 57334 ┆ 602620 │
│ 4213-a3_bl ┆ AF_L ┆ radfODF ┆ 0.08 ┆ 0.045 ┆ 57334 ┆ 602620 │
│ 4213-a3_bl ┆ AF_R ┆ afd_along ┆ 0.42 ┆ 0.22 ┆ 53916 ┆ 602620 │
│ 4213-a3_bl ┆ AF_R ┆ radfODF ┆ 0.08 ┆ 0.04 ┆ 53916 ┆ 602620 │
│ 4213-a3_bl ┆ CC_1 ┆ afd_along ┆ null ┆ null ┆ null ┆ 602620 │
│ 4213-a3_bl ┆ CC_2a ┆ afd_along ┆ 0.54 ┆ 0.3 ┆ 3264 ┆ 602620 │
│ 4213-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 602620 ┆ 602620 │
│ 4225-a3_bl ┆ CC_2a ┆ radfODF ┆ 0.06 ┆ 0.04 ┆ 3264 ┆ 718758 │
│ 4225-a3_bl ┆ CC_2b ┆ afd_along ┆ 0.47 ┆ 0.24 ┆ 18833 ┆ 718758 │
│ 4225-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 718758 ┆ 718758 │
└────────────┴─────────┴───────────┴───────┴───────┴─────────────┴────────────────┘
Can you please tell me how to express that in polars?
If it can help, this is how I would do it in pandas:
pd_df = df.to_pandas()
pd_df = pd_df.set_index("sid", drop=True)
pd_df_wm_volumes = pd_df[pd_df.roi == "wm_mask"].groupby("sid", as_index=True).first()
pd_df["wm_mask__count"] = pd_df_wm_volumes["voxel_count"]
pd_df = pd_df.reset_index(drop=False)
We can most easily accomplish this with a left join
in Polars.
First, I'll add the following two lines to your input (so that we have some rows where roi
== 'wm_mask'.
4213-a3_bl,wm_mask,,,,602620
4225-a3_bl,wm_mask,,,,718758
So that our data looks like:
shape: (11, 6)
┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┐
│ sid ┆ roi ┆ endpoint ┆ value ┆ std ┆ voxel_count │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ i64 │
╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╡
│ 4213-a3_bl ┆ AF_L ┆ afd_along ┆ 0.4 ┆ 0.21 ┆ 57334 │
│ 4213-a3_bl ┆ AF_L ┆ radfODF ┆ 0.08 ┆ 0.045 ┆ 57334 │
│ 4213-a3_bl ┆ AF_R ┆ afd_along ┆ 0.42 ┆ 0.22 ┆ 53916 │
│ 4213-a3_bl ┆ AF_R ┆ radfODF ┆ 0.08 ┆ 0.04 ┆ 53916 │
│ 4213-a3_bl ┆ CC_1 ┆ afd_along ┆ null ┆ null ┆ null │
│ 4213-a3_bl ┆ CC_1 ┆ radfODF ┆ null ┆ null ┆ null │
│ 4213-a3_bl ┆ CC_2a ┆ afd_along ┆ 0.54 ┆ 0.3 ┆ 3264 │
│ 4213-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 602620 │
│ 4225-a3_bl ┆ CC_2a ┆ radfODF ┆ 0.06 ┆ 0.04 ┆ 3264 │
│ 4225-a3_bl ┆ CC_2b ┆ afd_along ┆ 0.47 ┆ 0.24 ┆ 18833 │
│ 4225-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 718758 │
└────────────┴─────────┴───────────┴───────┴───────┴─────────────┘
First, we'll run the groupby statement to obtain our wm_mask__count
values. I've changed your groupby to something that is more idiomatic of Polars.
mask_counts = (
df
.filter(pl.col('roi') == 'wm_mask')
.group_by('sid')
.agg(
pl.col('voxel_count').first().alias('wm_mask__count')
)
)
mask_counts
shape: (2, 2)
┌────────────┬────────────────┐
│ sid ┆ wm_mask__count │
│ --- ┆ --- │
│ str ┆ i64 │
╞════════════╪════════════════╡
│ 4225-a3_bl ┆ 718758 │
│ 4213-a3_bl ┆ 602620 │
└────────────┴────────────────┘
And then we'll use a "left" join to merge the result back into the original data:
df.join(
mask_counts,
on=['sid'],
how='left',
)
shape: (11, 7)
┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┬─────────────┐
│ sid ┆ roi ┆ endpoint ┆ value ┆ std ┆ voxel_count ┆ wm_mask__co │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ unt │
│ str ┆ str ┆ str ┆ f64 ┆ f64 ┆ i64 ┆ --- │
│ ┆ ┆ ┆ ┆ ┆ ┆ i64 │
╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╪═════════════╡
│ 4213-a3_bl ┆ AF_L ┆ afd_along ┆ 0.4 ┆ 0.21 ┆ 57334 ┆ 602620 │
│ 4213-a3_bl ┆ AF_L ┆ radfODF ┆ 0.08 ┆ 0.045 ┆ 57334 ┆ 602620 │
│ 4213-a3_bl ┆ AF_R ┆ afd_along ┆ 0.42 ┆ 0.22 ┆ 53916 ┆ 602620 │
│ 4213-a3_bl ┆ AF_R ┆ radfODF ┆ 0.08 ┆ 0.04 ┆ 53916 ┆ 602620 │
│ 4213-a3_bl ┆ CC_1 ┆ afd_along ┆ null ┆ null ┆ null ┆ 602620 │
│ 4213-a3_bl ┆ CC_1 ┆ radfODF ┆ null ┆ null ┆ null ┆ 602620 │
│ 4213-a3_bl ┆ CC_2a ┆ afd_along ┆ 0.54 ┆ 0.3 ┆ 3264 ┆ 602620 │
│ 4213-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 602620 ┆ 602620 │
│ 4225-a3_bl ┆ CC_2a ┆ radfODF ┆ 0.06 ┆ 0.04 ┆ 3264 ┆ 718758 │
│ 4225-a3_bl ┆ CC_2b ┆ afd_along ┆ 0.47 ┆ 0.24 ┆ 18833 ┆ 718758 │
│ 4225-a3_bl ┆ wm_mask ┆ null ┆ null ┆ null ┆ 718758 ┆ 718758 │
└────────────┴─────────┴───────────┴───────┴───────┴─────────────┴─────────────┘