Search code examples
python-polars

Add column based on group_by


I'm trying to port a pandas script to polars. I have a dataset that looks like that

df = pl.from_repr("""
┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┐
│ sid        ┆ roi     ┆ endpoint  ┆ value ┆ std   ┆ voxel_count │
│ ---        ┆ ---     ┆ ---       ┆ ---   ┆ ---   ┆ ---         │
│ str        ┆ str     ┆ str       ┆ f64   ┆ f64   ┆ i64         │
╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╡
│ 4213-a3_bl ┆ AF_L    ┆ afd_along ┆ 0.4   ┆ 0.21  ┆ 57334       │
│ 4213-a3_bl ┆ AF_L    ┆ radfODF   ┆ 0.08  ┆ 0.045 ┆ 57334       │
│ 4213-a3_bl ┆ AF_R    ┆ afd_along ┆ 0.42  ┆ 0.22  ┆ 53916       │
│ 4213-a3_bl ┆ AF_R    ┆ radfODF   ┆ 0.08  ┆ 0.04  ┆ 53916       │
│ 4213-a3_bl ┆ CC_1    ┆ afd_along ┆ null  ┆ null  ┆ null        │
│ 4213-a3_bl ┆ CC_2a   ┆ afd_along ┆ 0.54  ┆ 0.3   ┆ 3264        │
│ 4213-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 602620      │
│ 4225-a3_bl ┆ CC_2a   ┆ radfODF   ┆ 0.06  ┆ 0.04  ┆ 3264        │
│ 4225-a3_bl ┆ CC_2b   ┆ afd_along ┆ 0.47  ┆ 0.24  ┆ 18833       │
│ 4225-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 718758      │
└────────────┴─────────┴───────────┴───────┴───────┴─────────────┘
""")

I want to add a column based on a group_by

df.filter(roi = "wm_mask").group_by("sid").first()
shape: (2, 6)
┌────────────┬─────────┬──────────┬───────┬──────┬─────────────┐
│ sid        ┆ roi     ┆ endpoint ┆ value ┆ std  ┆ voxel_count │
│ ---        ┆ ---     ┆ ---      ┆ ---   ┆ ---  ┆ ---         │
│ str        ┆ str     ┆ str      ┆ f64   ┆ f64  ┆ i64         │
╞════════════╪═════════╪══════════╪═══════╪══════╪═════════════╡
│ 4213-a3_bl ┆ wm_mask ┆ null     ┆ null  ┆ null ┆ 602620      │
│ 4225-a3_bl ┆ wm_mask ┆ null     ┆ null  ┆ null ┆ 718758      │
└────────────┴─────────┴──────────┴───────┴──────┴─────────────┘

Now I want to add this new voxel_count values that correspond to the right sid, which should give something like

┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┬────────────────┐
│ sid        ┆ roi     ┆ endpoint  ┆ value ┆ std   ┆ voxel_count ┆ wm_mask__count │
│ ---        ┆ ---     ┆ ---       ┆ ---   ┆ ---   ┆ ---         ┆ ---            │
│ str        ┆ str     ┆ str       ┆ f64   ┆ f64   ┆ i64         ┆ i64            │
╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╪════════════════╡
│ 4213-a3_bl ┆ AF_L    ┆ afd_along ┆ 0.4   ┆ 0.21  ┆ 57334       ┆ 602620         │
│ 4213-a3_bl ┆ AF_L    ┆ radfODF   ┆ 0.08  ┆ 0.045 ┆ 57334       ┆ 602620         │
│ 4213-a3_bl ┆ AF_R    ┆ afd_along ┆ 0.42  ┆ 0.22  ┆ 53916       ┆ 602620         │
│ 4213-a3_bl ┆ AF_R    ┆ radfODF   ┆ 0.08  ┆ 0.04  ┆ 53916       ┆ 602620         │
│ 4213-a3_bl ┆ CC_1    ┆ afd_along ┆ null  ┆ null  ┆ null        ┆ 602620         │
│ 4213-a3_bl ┆ CC_2a   ┆ afd_along ┆ 0.54  ┆ 0.3   ┆ 3264        ┆ 602620         │
│ 4213-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 602620      ┆ 602620         │
│ 4225-a3_bl ┆ CC_2a   ┆ radfODF   ┆ 0.06  ┆ 0.04  ┆ 3264        ┆ 718758         │
│ 4225-a3_bl ┆ CC_2b   ┆ afd_along ┆ 0.47  ┆ 0.24  ┆ 18833       ┆ 718758         │
│ 4225-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 718758      ┆ 718758         │
└────────────┴─────────┴───────────┴───────┴───────┴─────────────┴────────────────┘

Can you please tell me how to express that in polars?

If it can help, this is how I would do it in pandas:

pd_df = df.to_pandas()
pd_df = pd_df.set_index("sid", drop=True)
pd_df_wm_volumes = pd_df[pd_df.roi == "wm_mask"].groupby("sid", as_index=True).first()
pd_df["wm_mask__count"] = pd_df_wm_volumes["voxel_count"]
pd_df = pd_df.reset_index(drop=False)

Solution

  • We can most easily accomplish this with a left join in Polars.

    First, I'll add the following two lines to your input (so that we have some rows where roi == 'wm_mask'.

    4213-a3_bl,wm_mask,,,,602620
    4225-a3_bl,wm_mask,,,,718758
    

    So that our data looks like:

    shape: (11, 6)
    ┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┐
    │ sid        ┆ roi     ┆ endpoint  ┆ value ┆ std   ┆ voxel_count │
    │ ---        ┆ ---     ┆ ---       ┆ ---   ┆ ---   ┆ ---         │
    │ str        ┆ str     ┆ str       ┆ f64   ┆ f64   ┆ i64         │
    ╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╡
    │ 4213-a3_bl ┆ AF_L    ┆ afd_along ┆ 0.4   ┆ 0.21  ┆ 57334       │
    │ 4213-a3_bl ┆ AF_L    ┆ radfODF   ┆ 0.08  ┆ 0.045 ┆ 57334       │
    │ 4213-a3_bl ┆ AF_R    ┆ afd_along ┆ 0.42  ┆ 0.22  ┆ 53916       │
    │ 4213-a3_bl ┆ AF_R    ┆ radfODF   ┆ 0.08  ┆ 0.04  ┆ 53916       │
    │ 4213-a3_bl ┆ CC_1    ┆ afd_along ┆ null  ┆ null  ┆ null        │
    │ 4213-a3_bl ┆ CC_1    ┆ radfODF   ┆ null  ┆ null  ┆ null        │
    │ 4213-a3_bl ┆ CC_2a   ┆ afd_along ┆ 0.54  ┆ 0.3   ┆ 3264        │
    │ 4213-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 602620      │
    │ 4225-a3_bl ┆ CC_2a   ┆ radfODF   ┆ 0.06  ┆ 0.04  ┆ 3264        │
    │ 4225-a3_bl ┆ CC_2b   ┆ afd_along ┆ 0.47  ┆ 0.24  ┆ 18833       │
    │ 4225-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 718758      │
    └────────────┴─────────┴───────────┴───────┴───────┴─────────────┘
    

    First, we'll run the groupby statement to obtain our wm_mask__count values. I've changed your groupby to something that is more idiomatic of Polars.

    mask_counts = (
        df
        .filter(pl.col('roi') == 'wm_mask')
        .group_by('sid')
        .agg(
            pl.col('voxel_count').first().alias('wm_mask__count')
        )
    )
    mask_counts
    
    shape: (2, 2)
    ┌────────────┬────────────────┐
    │ sid        ┆ wm_mask__count │
    │ ---        ┆ ---            │
    │ str        ┆ i64            │
    ╞════════════╪════════════════╡
    │ 4225-a3_bl ┆ 718758         │
    │ 4213-a3_bl ┆ 602620         │
    └────────────┴────────────────┘
    

    And then we'll use a "left" join to merge the result back into the original data:

    df.join(
        mask_counts,
        on=['sid'],
        how='left',
    )
    
    shape: (11, 7)
    ┌────────────┬─────────┬───────────┬───────┬───────┬─────────────┬─────────────┐
    │ sid        ┆ roi     ┆ endpoint  ┆ value ┆ std   ┆ voxel_count ┆ wm_mask__co │
    │ ---        ┆ ---     ┆ ---       ┆ ---   ┆ ---   ┆ ---         ┆ unt         │
    │ str        ┆ str     ┆ str       ┆ f64   ┆ f64   ┆ i64         ┆ ---         │
    │            ┆         ┆           ┆       ┆       ┆             ┆ i64         │
    ╞════════════╪═════════╪═══════════╪═══════╪═══════╪═════════════╪═════════════╡
    │ 4213-a3_bl ┆ AF_L    ┆ afd_along ┆ 0.4   ┆ 0.21  ┆ 57334       ┆ 602620      │
    │ 4213-a3_bl ┆ AF_L    ┆ radfODF   ┆ 0.08  ┆ 0.045 ┆ 57334       ┆ 602620      │
    │ 4213-a3_bl ┆ AF_R    ┆ afd_along ┆ 0.42  ┆ 0.22  ┆ 53916       ┆ 602620      │
    │ 4213-a3_bl ┆ AF_R    ┆ radfODF   ┆ 0.08  ┆ 0.04  ┆ 53916       ┆ 602620      │
    │ 4213-a3_bl ┆ CC_1    ┆ afd_along ┆ null  ┆ null  ┆ null        ┆ 602620      │
    │ 4213-a3_bl ┆ CC_1    ┆ radfODF   ┆ null  ┆ null  ┆ null        ┆ 602620      │
    │ 4213-a3_bl ┆ CC_2a   ┆ afd_along ┆ 0.54  ┆ 0.3   ┆ 3264        ┆ 602620      │
    │ 4213-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 602620      ┆ 602620      │
    │ 4225-a3_bl ┆ CC_2a   ┆ radfODF   ┆ 0.06  ┆ 0.04  ┆ 3264        ┆ 718758      │
    │ 4225-a3_bl ┆ CC_2b   ┆ afd_along ┆ 0.47  ┆ 0.24  ┆ 18833       ┆ 718758      │
    │ 4225-a3_bl ┆ wm_mask ┆ null      ┆ null  ┆ null  ┆ 718758      ┆ 718758      │
    └────────────┴─────────┴───────────┴───────┴───────┴─────────────┴─────────────┘