Tags: python, dataframe, python-polars

Multiply Columns Together Based on Condition


Is there a way for me to dynamically multiply columns together based on a value in another column in Python?

I'm using Polars if that makes a difference. For example, if calendar_year is 2018, I'd want to multiply columns 2018, 2019, 2020, and 2021 together, but if calendar_year is 2019, I'd only want to multiply columns 2019, 2020, and 2021 together.

I'd like to store the result in a new column called product. In the future, we'll have additional columns such as 2022 and 2023, so I'd love for my formula to pick up these new columns automatically, without having to go into the code base each year and add them to the product manually.

import polars as pl

df = pl.from_repr("""
┌─────┬───────────────┬───────┬───────┬───────┬───────┬───────┬─────────┐
│ id  ┆ calendar_year ┆ 2017  ┆ 2018  ┆ 2019  ┆ 2020  ┆ 2021  ┆ product │
│ --- ┆ ---           ┆ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---   ┆ ---     │
│ i64 ┆ i64           ┆ f64   ┆ f64   ┆ f64   ┆ f64   ┆ f64   ┆ f64     │
╞═════╪═══════════════╪═══════╪═══════╪═══════╪═══════╪═══════╪═════════╡
│ 123 ┆ 2018          ┆ 0.998 ┆ 0.997 ┆ 0.996 ┆ 0.995 ┆ 0.994 ┆ 0.9801  │
│ 456 ┆ 2019          ┆ 0.993 ┆ 0.992 ┆ 0.991 ┆ 0.99  ┆ 0.989 ┆ 0.9557  │
└─────┴───────────────┴───────┴───────┴───────┴───────┴───────┴─────────┘
""")

Thanks in advance for the help!


Solution

  • It looks like you want to multiply the CY factors for calendar_year and every later year, without having to update this logic each time a new year column is added.

    If that's the case, one way to avoid hard-coding the CY selections is to use unpivot and filter the results.

    (
        df
        .select(
            'id',
            'calendar_year',
            pl.col(r'^20\d\d$')  # raw string so \d is not treated as a string escape
        )
        .unpivot(
            index=['id', 'calendar_year'],
            variable_name='CY',
            value_name='CY factor',
        )
        .with_columns(pl.col('CY').cast(pl.Int64))
        .filter(pl.col('CY') >= pl.col('calendar_year'))
        .group_by('id')
        .agg(
            pl.col('CY factor').product().alias('product')
        )
    )
    
    shape: (2, 2)
    ┌─────┬──────────┐
    │ id  ┆ product  │
    │ --- ┆ ---      │
    │ i64 ┆ f64      │
    ╞═════╪══════════╡
    │ 456 ┆ 0.970298 │
    │ 123 ┆ 0.982119 │
    └─────┴──────────┘
    

    From there, you can join the result back to your original dataset (using the id column).