python, pandas, python-polars

Convert Pandas pivot_table function into Polars pivot Function


I'm trying to convert some pandas code to Polars, and I'm stuck on the pandas pivot_table function. The pandas code below works, but I can't get the same behavior from the Polars pivot function. The Polars pivot function requires a column parameter and uses that column's values as headers, rather than using the column labels as headers. I'd like the same output as below, but produced with Polars instead of pandas.

import pandas as pd

df = pd.DataFrame({"obj" : ["ring", "shoe", "ring"], "price":["65", "42", "65"], "value":["53", "55", "54"], "date":["2022-02-07", "2022-01-07", "2022-03-07"]})

table = pd.pivot_table(df, values=['price','value','date'],index=['obj'], aggfunc={'price': pd.Series.nunique,'value':pd.Series.nunique,'date':pd.Series.nunique})

print(table)

Outputs the following:

        date    price     value  
obj  
ring    2       1         2  
shoe    1       1         1

Solution

  • In Polars, we would not use a pivot table for this. Instead, we would use the group_by and agg functions. Using your data, it would be:

    import polars as pl
    df = pl.from_pandas(df)
    
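    # maintain_order=True keeps groups in order of first appearance (ring, then shoe)
    df.group_by("obj", maintain_order=True).agg(pl.all().n_unique())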
    
    shape: (2, 4)
    ┌──────┬───────┬───────┬──────┐
    │ obj  ┆ price ┆ value ┆ date │
    │ ---  ┆ ---   ┆ ---   ┆ ---  │
    │ str  ┆ u32   ┆ u32   ┆ u32  │
    ╞══════╪═══════╪═══════╪══════╡
    │ ring ┆ 1     ┆ 2     ┆ 2    │
    │ shoe ┆ 1     ┆ 1     ┆ 1    │
    └──────┴───────┴───────┴──────┘
    

    pivot and unpivot

    Where we would use the pivot function in Polars is to summarize a dataset in 'long' format to a dataset in 'wide' format. As an example, let's convert your original dataset to 'long' format using the unpivot function.

    df2 = df.unpivot(index="obj")
    print(df2)
    
    shape: (9, 3)
    ┌──────┬──────────┬────────────┐
    │ obj  ┆ variable ┆ value      │
    │ ---  ┆ ---      ┆ ---        │
    │ str  ┆ str      ┆ str        │
    ╞══════╪══════════╪════════════╡
    │ ring ┆ price    ┆ 65         │
    │ shoe ┆ price    ┆ 42         │
    │ ring ┆ price    ┆ 65         │
    │ ring ┆ value    ┆ 53         │
    │ shoe ┆ value    ┆ 55         │
    │ ring ┆ value    ┆ 54         │
    │ ring ┆ date     ┆ 2022-02-07 │
    │ shoe ┆ date     ┆ 2022-01-07 │
    │ ring ┆ date     ┆ 2022-03-07 │
    └──────┴──────────┴────────────┘
    

    Now let's use pivot to summarize this 'long' format dataset back into 'wide' format, simply counting the number of values for each obj and variable. Note that this counts rows, not unique values, so ring shows 2 for price.

    # sort_columns=True orders the new columns by name (date, price, value)
    df2.pivot(on='variable', index='obj', aggregate_function=pl.len(), sort_columns=True)
    
    shape: (2, 4)
    ┌──────┬──────┬───────┬───────┐
    │ obj  ┆ date ┆ price ┆ value │
    │ ---  ┆ ---  ┆ ---   ┆ ---   │
    │ str  ┆ u32  ┆ u32   ┆ u32   │
    ╞══════╪══════╪═══════╪═══════╡
    │ ring ┆ 2    ┆ 2     ┆ 2     │
    │ shoe ┆ 1    ┆ 1     ┆ 1     │
    └──────┴──────┴───────┴───────┘
    

    Does this help clarify the use of the pivot functionality?