How to repeat a specific value 'n' times based on another column's list length?


I would like to repeat a None value 'n' times, where 'n' is the length of the list in another column on the same row. Here is a simple code example to illustrate:

import polars as pl

# Create dummy LazyFrame
lf = pl.LazyFrame(
    {
        "col1": [[1, 2, 3], [1, 2], [1]],
        "col2": [["A", "B", "C"], ["C"], ["D", "E"]],
    }
)

# Print DataFrame
print(lf.collect())
shape: (3, 2)
┌───────────┬─────────────────┐
│ col1      ┆ col2            │
│ ---       ┆ ---             │
│ list[i64] ┆ list[str]       │
╞═══════════╪═════════════════╡
│ [1, 2, 3] ┆ ["A", "B", "C"] │
│ [1, 2]    ┆ ["C"]           │
│ [1]       ┆ ["D", "E"]      │
└───────────┴─────────────────┘

My attempt:

# Define condition
condition = pl.col("col2").list.len().eq(pl.col("col1").list.len())

# Apply condition (does not work as I expected)
lf = lf.with_columns(
    pl.when(condition)
    .then(pl.col("col2"))
    .otherwise(
        pl.repeat(
            value=None,
            n=pl.col("col1").list.len(),
        )
    )
    .alias("col3")
).collect()

# Output
print(lf)
shape: (3, 3)
┌───────────┬─────────────────┬─────────────────┐
│ col1      ┆ col2            ┆ col3            │
│ ---       ┆ ---             ┆ ---             │
│ list[i64] ┆ list[str]       ┆ list[str]       │
╞═══════════╪═════════════════╪═════════════════╡
│ [1, 2, 3] ┆ ["A", "B", "C"] ┆ ["A", "B", "C"] │
│ [1, 2]    ┆ ["C"]           ┆ null            │
│ [1]       ┆ ["D", "E"]      ┆ null            │
└───────────┴─────────────────┴─────────────────┘

But my expected result is:

shape: (3, 3)
┌───────────┬─────────────────┬─────────────────┐
│ col1      ┆ col2            ┆ col3            │
│ ---       ┆ ---             ┆ ---             │
│ list[i64] ┆ list[str]       ┆ list[str]       │
╞═══════════╪═════════════════╪═════════════════╡
│ [1, 2, 3] ┆ ["A", "B", "C"] ┆ ["A", "B", "C"] │
│ [1, 2]    ┆ ["C"]           ┆ [null, null]    │
│ [1]       ┆ ["D", "E"]      ┆ [null]          │
└───────────┴─────────────────┴─────────────────┘

Here, each list in col3 has the same length as the col1 list on the same row.


Solution

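  • You can use .repeat_by(). Unlike pl.repeat, which constructs a column of length n, .repeat_by() repeats each value as many times as another expression specifies and collects the repeats into a list per row: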

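    # df is the frame from the question; with_columns works on a
    # DataFrame or on a LazyFrame (call .collect() on the latter)
    df.with_columns(
        pl.when(pl.col("col1").list.len() == pl.col("col2").list.len())
        .then(pl.col("col2"))  # lengths match: keep col2 as is
        # typed null literal, repeated once per element of col1
        .otherwise(pl.lit(None, pl.String).repeat_by(pl.col("col1").list.len()))
        .alias("col3")
    )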
    
    shape: (3, 3)
    ┌───────────┬─────────────────┬─────────────────┐
    │ col1      ┆ col2            ┆ col3            │
    │ ---       ┆ ---             ┆ ---             │
    │ list[i64] ┆ list[str]       ┆ list[str]       │
    ╞═══════════╪═════════════════╪═════════════════╡
    │ [1, 2, 3] ┆ ["A", "B", "C"] ┆ ["A", "B", "C"] │
    │ [1, 2]    ┆ ["C"]           ┆ [null, null]    │
    │ [1]       ┆ ["D", "E"]      ┆ [null]          │
    └───────────┴─────────────────┴─────────────────┘