
Using "Select Expr" and "Stack" to Unpivot PySpark DataFrame doesn't produce expected results


I am trying to unpivot a PySpark DataFrame, but I don't get the correct results.

Sample dataset:

# Prepare Data
data = [("Spain", 101, 201, 301), \
        ("Taiwan", 102, 202, 302), \
        ("Italy", 103, 203, 303), \
        ("China", 104, 204, 304)
  ]
 
# Create DataFrame
columns= ["Country", "2018", "2019", "2020"]
df = spark.createDataFrame(data = data, schema = columns)
df.show(truncate=False)

+-------+----+----+----+
|Country|2018|2019|2020|
+-------+----+----+----+
|Spain  |101 |201 |301 |
|Taiwan |102 |202 |302 |
|Italy  |103 |203 |303 |
|China  |104 |204 |304 |
+-------+----+----+----+

Below are the commands I have tried:

from pyspark.sql import functions as F

unpivotExpr = "stack(3, '2018', 2018, '2019', 2019, '2020', 2020) as (Year, CPI)"
unPivotDF = df.select("Country", F.expr(unpivotExpr))
unPivotDF.show()

And the results:

+-------+----+----+
|Country|Year| CPI|
+-------+----+----+
|  Spain|2018|2018|
|  Spain|2019|2019|
|  Spain|2020|2020|
| Taiwan|2018|2018|
| Taiwan|2019|2019|
| Taiwan|2020|2020|
|  Italy|2018|2018|
|  Italy|2019|2019|
|  Italy|2020|2020|
|  China|2018|2018|
|  China|2019|2019|
|  China|2020|2020|
+-------+----+----+

As you can see above, the value of column "CPI" is the same as column "Year", which is not what I expected. The expected result is below:

+-------+----+---+
|Country|Year|CPI|
+-------+----+---+
|Spain  |2018|101|
|Spain  |2019|201|
|Spain  |2020|301|
|Taiwan |2018|102|
|Taiwan |2019|202|
|Taiwan |2020|302|
|Italy  |2018|103|
|Italy  |2019|203|
|Italy  |2020|303|
|China  |2018|104|
|China  |2019|204|
|China  |2020|304|
+-------+----+---+

The value of column "CPI" should be taken from each row of the pivoted table for the corresponding country.

Any idea how to solve this issue?


Solution

  • UPDATE

    Your "stack" expression is correct - just that to work with numbers as column names (2018, 2019 etc.), enclose them in back-ticks:

    unpivotExpr = "stack(3, '2018', `2018`, '2019', `2019`, '2020', `2020`) as (Year, CPI)"
    

    ALTERNATE SOLUTION

    Create a map whose keys are the column names and whose values are the corresponding column values, then explode the map:

    import pyspark.sql.functions as F
    
    df = df.withColumn("year_cpi_map", F.create_map( \
                                          F.lit("2018"), F.col("2018"), \
                                          F.lit("2019"), F.col("2019"), \
                                          F.lit("2020"), F.col("2020") \
            )) \
            .select("Country", F.explode("year_cpi_map").alias("Year", "CPI"))
    

    Or to generalize:

    import pyspark.sql.functions as F
    import itertools
    
    df = df.withColumn("year_cpi_map", F.create_map(list(itertools.chain(*[(F.lit(c), F.col(c)) for c in df.columns if c != "Country"])))) \
           .select("Country", F.explode("year_cpi_map").alias("Year", "CPI"))
    

    Output:

    +-------+----+---+
    |Country|Year|CPI|
    +-------+----+---+
    |Spain  |2018|101|
    |Spain  |2019|201|
    |Spain  |2020|301|
    |Taiwan |2018|102|
    |Taiwan |2019|202|
    |Taiwan |2020|302|
    |Italy  |2018|103|
    |Italy  |2019|203|
    |Italy  |2020|303|
    |China  |2018|104|
    |China  |2019|204|
    |China  |2020|304|
    +-------+----+---+
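
    As a side note, if you are running Spark 3.4 or later, the DataFrame API also has a built-in unpivot (alias melt) that covers this case directly; a minimal sketch, assuming the same sample DataFrame:

    # Requires Spark 3.4+; signature: unpivot(ids, values, variableColumnName, valueColumnName)
    unPivotDF = df.unpivot("Country", ["2018", "2019", "2020"], "Year", "CPI")
    unPivotDF.show(truncate=False)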