Polars: assign existing category


I am using Polars to analyze some A/B test data (and a little bit more...). Now I had to correct for some inconsistency. df_prep is a Polars DataFrame that has a column 'Group' of type cat with levels 'A' and 'B'. Naively, I did this:

# After the A/B test period, everything is B!
df_prep = (df_prep.lazy()
           .with_columns(
               pl.when(pl.col('Datum') >= pl.col('TestEndDate'))
               .then(pl.lit('B'))
               .otherwise(pl.col('Group'))
               .alias('Group'))
           .collect())

However, the problem is now that df_prep['Group'].unique() gives

shape: (3,)
Series: 'Group' [cat]
[
    "B"
    "A"
    "B"
]

This is obviously not what I wanted. I wanted to assign the existing category "B".

How could this be achieved?

EDIT: I found one way:

df_prep = df_prep.with_columns(pl.col('Group').cast(pl.String).cast(pl.Categorical).alias('Group'))

But this doesn't seem right to me... Isn't there a more idiomatic solution?


Solution

  • This is a common problem when comparing string values to Categorical values. One way to solve this problem is to use a string cache, either globally or using a context manager.

    Without a string cache

    First, let's take a closer look at what is occurring. Let's start with this data, and look at the underlying physical representation of the Categorical variable (the integer that represents each unique category value).

    import polars as pl
    from datetime import date
    
    df_prep = pl.DataFrame(
        [
            pl.Series(
                name="Group",
                values=["A", "A", "B", "B"],
                dtype=pl.Categorical,
            ),
            pl.Series(
                name="Datum",
                values=pl.date_range(date(2022, 1, 1), date(2022, 1, 4), "1d", eager=True),
            ),
            pl.Series(name="TestEndDate", values=[date(2022, 1, 4)] * 4),
        ]
    )
    
    (
        df_prep
        .with_columns(pl.col('Group').to_physical().alias('Physical'))
    )
    
    shape: (4, 4)
    ┌───────┬────────────┬─────────────┬──────────┐
    │ Group ┆ Datum      ┆ TestEndDate ┆ Physical │
    │ ---   ┆ ---        ┆ ---         ┆ ---      │
    │ cat   ┆ date       ┆ date        ┆ u32      │
    ╞═══════╪════════════╪═════════════╪══════════╡
    │ A     ┆ 2022-01-01 ┆ 2022-01-04  ┆ 0        │
    │ A     ┆ 2022-01-02 ┆ 2022-01-04  ┆ 0        │
    │ B     ┆ 2022-01-03 ┆ 2022-01-04  ┆ 1        │
    │ B     ┆ 2022-01-04 ┆ 2022-01-04  ┆ 1        │
    └───────┴────────────┴─────────────┴──────────┘
    

    Note that A is assigned a physical value of 0; B, a value of 1.

    Now, let's run the next step (without a string cache), and see what happens:

    result = (
        df_prep.lazy()
        .with_columns(
            pl.when(pl.col("Datum") >= pl.col("TestEndDate"))
            .then(pl.lit("B"))
            .otherwise(pl.col("Group"))
            .alias("Group")
        )
        .with_columns(pl.col('Group').to_physical().alias('Physical'))
        .collect()
    )
    result
    
    shape: (4, 4)
    ┌───────┬────────────┬─────────────┬──────────┐
    │ Group ┆ Datum      ┆ TestEndDate ┆ Physical │
    │ ---   ┆ ---        ┆ ---         ┆ ---      │
    │ cat   ┆ date       ┆ date        ┆ u32      │
    ╞═══════╪════════════╪═════════════╪══════════╡
    │ A     ┆ 2022-01-01 ┆ 2022-01-04  ┆ 1        │
    │ A     ┆ 2022-01-02 ┆ 2022-01-04  ┆ 1        │
    │ B     ┆ 2022-01-03 ┆ 2022-01-04  ┆ 2        │
    │ B     ┆ 2022-01-04 ┆ 2022-01-04  ┆ 0        │
    └───────┴────────────┴─────────────┴──────────┘
    

    Notice what happened. Without a string cache, the underlying physical representations of the Categorical values have changed. Indeed, the Categorical value B now has two underlying physical representations: 2 and 0. Polars sees the two B's as distinct.
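
    One way to picture the behaviour described above is a toy model in plain Python. This is only a sketch of the idea, not Polars' actual implementation: the CategoryTable class and the "naive merge" step are assumptions made up for illustration. It shows how two label tables built independently can end up with the same string under two codes, while a single shared table keeps one code per string.

```python
class CategoryTable:
    """Toy lookup table: the position of a label in `labels` is its physical code."""

    def __init__(self):
        self.labels = []

    def encode(self, values):
        """Return one code per value, registering unseen labels as they appear."""
        codes = []
        for value in values:
            if value not in self.labels:
                self.labels.append(value)
            codes.append(self.labels.index(value))
        return codes


# Without a shared table: the literal "B" and the column are encoded independently.
literal_table = CategoryTable()
literal_codes = literal_table.encode(["B"])          # "B" gets code 0
column_table = CategoryTable()
column_codes = column_table.encode(["A", "A", "B"])  # "A" gets 0, "B" gets 1

# Naive merge: append the column's labels after the literal's and shift its codes,
# without checking whether a label already exists.
offset = len(literal_table.labels)
merged_labels = literal_table.labels + column_table.labels
merged_codes = [code + offset for code in column_codes] + literal_codes

# With a shared table: every string is looked up in the same place.
shared_table = CategoryTable()
shared_codes = shared_table.encode(["A", "A", "B"]) + shared_table.encode(["B"])

print(merged_labels, merged_codes)        # ['B', 'A', 'B'] [1, 1, 2, 0]
print(shared_table.labels, shared_codes)  # ['A', 'B'] [0, 0, 1, 1]
```

    In the toy model the merged table carries "B" twice, which mirrors the duplicated category in the result above; the shared table plays the role that a string cache plays in Polars.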

    Indeed, we see this if we use unique on this column:

    result.get_column('Group').unique()
    
    shape: (3,)
    Series: 'Group' [cat]
    [
            "B"
            "A"
            "B"
    ]
    

    Using a global string cache

    One easy way to handle this is to use a global string cache while making comparisons between strings and Categorical values, or setting values for Categorical variables using strings.

    We'll enable the global string cache with Polars' enable_string_cache function and rerun the same steps.

    pl.enable_string_cache()
    df_prep = pl.DataFrame(
        [
            pl.Series(
                name="Group",
                values=["A", "A", "B", "B"],
                dtype=pl.Categorical,
            ),
            pl.Series(
                name="Datum",
                values=pl.date_range(date(2022, 1, 1), date(2022, 1, 4), "1d", eager=True),
            ),
            pl.Series(name="TestEndDate", values=[date(2022, 1, 4)] * 4),
        ]
    )
    
    result = (
        df_prep.lazy()
        .with_columns(
            pl.when(pl.col("Datum") >= pl.col("TestEndDate"))
            .then(pl.lit("B"))
            .otherwise(pl.col("Group"))
            .alias("Group")
        )
        .with_columns(pl.col('Group').to_physical().alias('Physical'))
        .collect()
    )
    result
    
    shape: (4, 4)
    ┌───────┬────────────┬─────────────┬──────────┐
    │ Group ┆ Datum      ┆ TestEndDate ┆ Physical │
    │ ---   ┆ ---        ┆ ---         ┆ ---      │
    │ cat   ┆ date       ┆ date        ┆ u32      │
    ╞═══════╪════════════╪═════════════╪══════════╡
    │ A     ┆ 2022-01-01 ┆ 2022-01-04  ┆ 0        │
    │ A     ┆ 2022-01-02 ┆ 2022-01-04  ┆ 0        │
    │ B     ┆ 2022-01-03 ┆ 2022-01-04  ┆ 1        │
    │ B     ┆ 2022-01-04 ┆ 2022-01-04  ┆ 1        │
    └───────┴────────────┴─────────────┴──────────┘
    
    >>> result.get_column('Group').unique()
    shape: (2,)
    Series: 'Group' [cat]
    [
            "A"
            "B"
    ]
    

    Notice how the Categorical variable maintains its correct physical representation. And the results of using unique on Group are what we expect.

    Using a Context Manager

    If you don't want to keep a global string cache in effect, you can use a context manager to set a localized, temporary StringCache while you are making comparisons to strings.

    with pl.StringCache():
        df_prep = pl.DataFrame(
            [
                pl.Series(
                    name="Group",
                    values=["A", "A", "B", "B"],
                    dtype=pl.Categorical,
                ),
                pl.Series(
                    name="Datum",
                    values=pl.date_range(date(2022, 1, 1), date(2022, 1, 4), "1d", eager=True),
                ),
                pl.Series(name="TestEndDate", values=[date(2022, 1, 4)] * 4),
            ]
        )
    
        result = (
            df_prep.lazy()
            .with_columns(
                pl.when(pl.col("Datum") >= pl.col("TestEndDate"))
                .then(pl.lit("B"))
                .otherwise(pl.col("Group"))
                .alias("Group")
            )
            .with_columns(pl.col('Group').to_physical().alias('Physical'))
            .collect()
        )
        result
    
    shape: (4, 4)
    ┌───────┬────────────┬─────────────┬──────────┐
    │ Group ┆ Datum      ┆ TestEndDate ┆ Physical │
    │ ---   ┆ ---        ┆ ---         ┆ ---      │
    │ cat   ┆ date       ┆ date        ┆ u32      │
    ╞═══════╪════════════╪═════════════╪══════════╡
    │ A     ┆ 2022-01-01 ┆ 2022-01-04  ┆ 0        │
    │ A     ┆ 2022-01-02 ┆ 2022-01-04  ┆ 0        │
    │ B     ┆ 2022-01-03 ┆ 2022-01-04  ┆ 1        │
    │ B     ┆ 2022-01-04 ┆ 2022-01-04  ┆ 1        │
    └───────┴────────────┴─────────────┴──────────┘
    
    >>> result.get_column('Group').unique()
    shape: (2,)
    Series: 'Group' [cat]
    [
            "A"
            "B"
    ]
    

    Edit: Reading/Scanning external files

    You can read/scan external files with a string cache in effect. For example, below I've saved our DataFrame to tmp.parquet.

    If I use read_parquet with a string cache in effect, the Categorical variables are included in the string cache.

    (Note: in the examples below, I'll use a Context Manager -- to clearly delineate where the string cache is in effect.)

    import polars as pl
    with pl.StringCache():
        (
            pl.read_parquet('tmp.parquet')
            .with_columns(
                pl.when(pl.col("Datum") >= pl.col("TestEndDate"))
                .then(pl.lit("B"))
                .otherwise(pl.col("Group"))
                .alias("Group")
            )
            .with_columns(pl.col('Group').to_physical().alias('Physical'))
        )
    
    shape: (4, 4)
    ┌───────┬────────────┬─────────────┬──────────┐
    │ Group ┆ Datum      ┆ TestEndDate ┆ Physical │
    │ ---   ┆ ---        ┆ ---         ┆ ---      │
    │ cat   ┆ date       ┆ date        ┆ u32      │
    ╞═══════╪════════════╪═════════════╪══════════╡
    │ A     ┆ 2022-01-01 ┆ 2022-01-04  ┆ 0        │
    │ A     ┆ 2022-01-02 ┆ 2022-01-04  ┆ 0        │
    │ B     ┆ 2022-01-03 ┆ 2022-01-04  ┆ 1        │
    │ B     ┆ 2022-01-04 ┆ 2022-01-04  ┆ 1        │
    └───────┴────────────┴─────────────┴──────────┘
    

    Notice that our Categorical values are correct. (The B values have the same underlying physical representation.)

    However, if we move the read_parquet call outside the Context Manager (so that the DataFrame is created without a string cache), we have a problem.

    df_prep = pl.read_parquet('tmp.parquet')
    
    with pl.StringCache():
        (
            df_prep
            .with_columns(
                pl.when(pl.col("Datum") >= pl.col("TestEndDate"))
                .then(pl.lit("B"))
                .otherwise(pl.col("Group"))
                .alias("Group")
            )
            .with_columns(pl.col('Group').to_physical().alias('Physical'))
        )
    
    Traceback (most recent call last):
      File "<stdin>", line 7, in <module>
      File "/home/corey/.virtualenvs/StackOverflow/lib/python3.10/site-packages/polars/internals/dataframe/frame.py", line 4027, in with_column
        self.lazy()
      File "/home/corey/.virtualenvs/StackOverflow/lib/python3.10/site-packages/polars/internals/lazyframe/frame.py", line 803, in collect
        return pli.wrap_df(ldf.collect())
    exceptions.ComputeError: cannot combine categorical under a global string cache with a non cached categorical
    

    The error message says it all: a Categorical created under a string cache cannot be combined with one created outside it.

    Edit: Placing existing Categorical columns under a string cache

    One way to correct the situation above (assuming that it's already too late to re-read your DataFrame with a string cache) is to put a new string cache into effect, and then cast the values back to strings and then back to Categorical.

    Below, we'll use a shortcut to perform this for all Categorical columns in parallel - by specifying pl.Categorical in the pl.col.

    with pl.StringCache():
        (
            df_prep
            .with_columns(
                pl.col(pl.Categorical).cast(pl.String).cast(pl.Categorical)
            )
            .with_columns(
                pl.when(pl.col("Datum") >= pl.col("TestEndDate"))
                .then(pl.lit("B"))
                .otherwise(pl.col("Group"))
                .alias("Group")
            )
            .with_columns(pl.col('Group').to_physical().alias('Physical'))
        )
    
    shape: (4, 4)
    ┌───────┬────────────┬─────────────┬──────────┐
    │ Group ┆ Datum      ┆ TestEndDate ┆ Physical │
    │ ---   ┆ ---        ┆ ---         ┆ ---      │
    │ cat   ┆ date       ┆ date        ┆ u32      │
    ╞═══════╪════════════╪═════════════╪══════════╡
    │ A     ┆ 2022-01-01 ┆ 2022-01-04  ┆ 0        │
    │ A     ┆ 2022-01-02 ┆ 2022-01-04  ┆ 0        │
    │ B     ┆ 2022-01-03 ┆ 2022-01-04  ┆ 1        │
    │ B     ┆ 2022-01-04 ┆ 2022-01-04  ┆ 1        │
    └───────┴────────────┴─────────────┴──────────┘
    

    And now our code works correctly again.