python dataframe python-polars

Implement frequency encoding in polars


I want to replace each category with its relative frequency of occurrence. My dataframe is lazy, and currently I cannot do this without two passes over the entire data plus one pass over a single column to get the length of the dataframe. Here is how I am doing it:

Input:

import polars as pl

df = pl.DataFrame({"a": [1, 8, 3], "b": [4, 5, None], "c": ["foo", "bar", "bar"]}).lazy()
print(df.collect())
output:
shape: (3, 3)
┌─────┬──────┬─────┐
│ a   ┆ b    ┆ c   │
│ --- ┆ ---  ┆ --- │
│ i64 ┆ i64  ┆ str │
╞═════╪══════╪═════╡
│ 1   ┆ 4    ┆ foo │
│ 8   ┆ 5    ┆ bar │
│ 3   ┆ null ┆ bar │
└─────┴──────┴─────┘

Required output:

shape: (3, 3)
┌─────┬──────┬────────────────────┐
│ a   ┆ b    ┆ c                  │
│ --- ┆ ---  ┆ ---                │
│ i64 ┆ i64  ┆ str                │
╞═════╪══════╪════════════════════╡
│ 1   ┆ 4    ┆ 0.3333333333333333 │
│ 8   ┆ 5    ┆ 0.6666666666666666 │
│ 3   ┆ null ┆ 0.6666666666666666 │
└─────┴──────┴────────────────────┘

transformation code:

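# Collect 1: a single column, only to get the row count.
l = df.select("c").collect().shape[0]
# Collect 2: per-category counts, converted to relative frequencies.
rep = df.group_by("c").len().collect().with_columns(pl.col("len")/l).lazy()
# Collect 3: map each category to its frequency via the context frame.
df_out = df.with_context(rep.select(pl.all().name.prefix("context_"))).with_columns(pl.col("c").replace(pl.col("context_c"), pl.col("context_len"))).collect()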
print(df_out)
output:
shape: (3, 3)
┌─────┬──────┬────────────────────┐
│ a   ┆ b    ┆ c                  │
│ --- ┆ ---  ┆ ---                │
│ i64 ┆ i64  ┆ str                │
╞═════╪══════╪════════════════════╡
│ 1   ┆ 4    ┆ 0.3333333333333333 │
│ 8   ┆ 5    ┆ 0.6666666666666666 │
│ 3   ┆ null ┆ 0.6666666666666666 │
└─────┴──────┴────────────────────┘

As you can see, I am collecting the full data twice, plus one more collect over a single column. Can I do better?


Solution

  • pl.len() evaluates to the number of rows (the "column length").

    You can also use it in a group context (agg/over) as a way to count the values.

    df.with_columns((pl.len().over("c") / pl.len()).alias("c")).collect()
    
    shape: (3, 3)
    ┌─────┬──────┬──────────┐
    │ a   ┆ b    ┆ c        │
    │ --- ┆ ---  ┆ ---      │
    │ i64 ┆ i64  ┆ f64      │
    ╞═════╪══════╪══════════╡
    │ 1   ┆ 4    ┆ 0.333333 │
    │ 8   ┆ 5    ┆ 0.666667 │
    │ 3   ┆ null ┆ 0.666667 │
    └─────┴──────┴──────────┘
    

    When you group by the values, each value's "frequency count" is the length of its group.

    >>> df.group_by("c").len().collect()
    shape: (2, 2)
    ┌─────┬─────┐
    │ c   ┆ len │
    │ --- ┆ --- │
    │ str ┆ u32 │
    ╞═════╪═════╡
    │ foo ┆ 1   │
    │ bar ┆ 2   │
    └─────┴─────┘