
pandas categorical remove categories from multiple columns


I have many categorical columns like:

import pandas as pd

df = pd.DataFrame(
    data={
        "id": [1, 2, 3, 4],
        "category1": [" ",
                      "data",
                      "more data",
                      "         "],
        "category2": ["   ", "more data", " ", "and more"],
    }
)
df["category1"] = df["category1"].astype("category")
df["category2"] = df["category2"].astype("category")

I want to remove any levels of the categorical columns that contain only whitespace, while keeping the columns categorical (in other words, I can't use .str). I have tried:

cat_cols = df.select_dtypes("category").columns
for c in cat_cols:
    levels = [level for level in df[c].cat.categories.values.tolist()
              if level.isspace()]
    df[c] = df[c].cat.remove_categories(levels)

This works, so I tried making it faster and neater with a list comprehension, like so:

df[cat_cols] = [df[c].cat.remove_categories(
                [level for level in df[c].cat.categories.values.tolist()
                if level.isspace()])
                for c in cat_cols]

At that point I get "ValueError: Columns must be same length as key".

Note: I don't want to use the inplace parameter in the list comprehension because it is going to be deprecated for pd.Categorical.

I feel like I might be missing something basic here, but how do I do this with a list comprehension and without inplace?


Solution

  • You can use a dictionary comprehension with the DataFrame constructor:

    df[cat_cols] = pd.DataFrame({c: df[c].cat.remove_categories(
                    [level for level in df[c].cat.categories.values.tolist()
                    if level.isspace()])
                    for c in cat_cols})
    
    print(df)
       id  category1  category2
    0   1        NaN        NaN
    1   2       data  more data
    2   3  more data        NaN
    3   4        NaN   and more
        
    

    Or use concat with axis=1, so that each cleaned Series becomes a column:

    df[cat_cols] = pd.concat([df[c].cat.remove_categories(
                    [level for level in df[c].cat.categories.values.tolist()
                    if level.isspace()])
                    for c in cat_cols], axis=1)
    
    print(df)
       id  category1  category2
    0   1        NaN        NaN
    1   2       data  more data
    2   3  more data        NaN
    3   4        NaN   and more
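
As for why the plain list fails: pandas most likely reads a list of Series row-wise (2 rows of 4 values here), so the number of columns no longer matches the 2 keys. Another option that sidesteps this is DataFrame.assign, which sets one column at a time on a copy, so each column stays categorical. This is a minimal sketch, assuming the column names are strings (needed for the ** unpacking); blank_levels and cleaned are names made up for this example:

```python
import pandas as pd

df = pd.DataFrame(
    {
        "id": [1, 2, 3, 4],
        "category1": [" ", "data", "more data", "         "],
        "category2": ["   ", "more data", " ", "and more"],
    }
)
cat_cols = ["category1", "category2"]
# Convert only the chosen columns to the categorical dtype.
df = df.astype({c: "category" for c in cat_cols})


def blank_levels(s):
    """Return the categories of a categorical Series that are whitespace only."""
    return [level for level in s.cat.categories if level.isspace()]


# assign() returns a new DataFrame and sets each column separately,
# so no inplace argument is needed and the dtype stays categorical.
cleaned = df.assign(
    **{c: df[c].cat.remove_categories(blank_levels(df[c])) for c in cat_cols}
)

print(cleaned.dtypes)
print(cleaned)
```

Because assign works on a copy, df itself still has the whitespace levels; reassign with df = df.assign(...) if you want to overwrite it.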