Polars group by and pivot: converting code from PySpark


I'm currently converting some code from PySpark to Polars and need some help.

In PySpark I group by col1 and col2, then pivot on a column called VariableName using the Value column. How would I do this in Polars?

pivotDF = df.groupBy("col1","col2").pivot("VariableName").max("Value")

Solution

  • Let's start with this data:

    import polars as pl
    from pyspark.sql import SparkSession
    
    df = pl.DataFrame(
        {
            "col1": ["A", "B"] * 12,
            "col2": ["x", "y", "z"] * 8,
            "VariableName": ["one", "two", "three", "four"] * 6,
            "Value": pl.int_range(0, 24, eager=True),
        }
    )
    df
    
    shape: (24, 4)
    ┌──────┬──────┬──────────────┬───────┐
    │ col1 ┆ col2 ┆ VariableName ┆ Value │
    │ ---  ┆ ---  ┆ ---          ┆ ---   │
    │ str  ┆ str  ┆ str          ┆ i64   │
    ╞══════╪══════╪══════════════╪═══════╡
    │ A    ┆ x    ┆ one          ┆ 0     │
    │ B    ┆ y    ┆ two          ┆ 1     │
    │ A    ┆ z    ┆ three        ┆ 2     │
    │ B    ┆ x    ┆ four         ┆ 3     │
    │ A    ┆ y    ┆ one          ┆ 4     │
    │ …    ┆ …    ┆ …            ┆ …     │
    │ B    ┆ y    ┆ four         ┆ 19    │
    │ A    ┆ z    ┆ one          ┆ 20    │
    │ B    ┆ x    ┆ two          ┆ 21    │
    │ A    ┆ y    ┆ three        ┆ 22    │
    │ B    ┆ z    ┆ four         ┆ 23    │
    └──────┴──────┴──────────────┴───────┘
    

    Running your query in PySpark yields:

    spark = SparkSession.builder.getOrCreate()
    (
        spark
        .createDataFrame(df.to_pandas())
        .groupBy("col1", "col2")
        .pivot("VariableName")
        .max("Value")
        .sort(["col1", "col2"])
        .show()
    )
    
    +----+----+----+----+-----+----+
    |col1|col2|four| one|three| two|
    +----+----+----+----+-----+----+
    |   A|   x|null|  12|   18|null|
    |   A|   y|null|  16|   22|null|
    |   A|   z|null|  20|   14|null|
    |   B|   x|  15|null| null|  21|
    |   B|   y|  19|null| null|  13|
    |   B|   z|  23|null| null|  17|
    +----+----+----+----+-----+----+
    

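    In Polars, the equivalent is DataFrame.pivot. The index argument takes the columns you grouped by in Spark, on names the column whose values become the new column headers, and aggregate_function supplies the aggregation (here, max). Because values is not passed, Polars uses every remaining column, which in this frame is just Value. The on keyword is the Polars 1.0+ spelling; earlier releases called this argument columns. Note that the pivoted columns come out in order of first appearance (one, two, three, four) rather than alphabetically as in the Spark output.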

    (
        df.pivot(
            on="VariableName", 
            index=["col1", "col2"], 
            aggregate_function=pl.element().max()
        )
        .sort("col1", "col2")
    )
    
    shape: (6, 6)
    ┌──────┬──────┬──────┬──────┬───────┬──────┐
    │ col1 ┆ col2 ┆ one  ┆ two  ┆ three ┆ four │
    │ ---  ┆ ---  ┆ ---  ┆ ---  ┆ ---   ┆ ---  │
    │ str  ┆ str  ┆ i64  ┆ i64  ┆ i64   ┆ i64  │
    ╞══════╪══════╪══════╪══════╪═══════╪══════╡
    │ A    ┆ x    ┆ 12   ┆ null ┆ 18    ┆ null │
    │ A    ┆ y    ┆ 16   ┆ null ┆ 22    ┆ null │
    │ A    ┆ z    ┆ 20   ┆ null ┆ 14    ┆ null │
    │ B    ┆ x    ┆ null ┆ 21   ┆ null  ┆ 15   │
    │ B    ┆ y    ┆ null ┆ 13   ┆ null  ┆ 19   │
    │ B    ┆ z    ┆ null ┆ 17   ┆ null  ┆ 23   │
    └──────┴──────┴──────┴──────┴───────┴──────┘