Tags: apache-spark, pyspark, apache-spark-sql, pivot, rename

How to pivot and rename columns based on several grouped columns


I am having trouble using the agg function and renaming the results properly. So far I have a table in the following format.

sheet  equipment  chamber  time  value1  value2
a      E1         C1       1     11      21
a      E1         C1       2     12      22
a      E1         C1       3     13      23
b      E1         C1       1     14      24
b      E1         C1       2     15      25
b      E1         C1       3     16      26

I would like to create a statistical table like this:

sheet  E1_C1_value1_mean  E1_C1_value1_min  E1_C1_value1_max  E1_C1_value2_mean  E1_C1_value2_min  E1_C1_value2_max
a      12                 11                13                22                 21                23
b      15                 14                16                25                 24                26

I would like to groupBy "sheet", "equipment", and "chamber" and aggregate the mean, min, and max values. I also need to rename the columns by the rule: equipment + chamber + value column + aggregation function. There are multiple "equipment" and "chamber" names.


Solution

  • Since pivot in Spark only accepts a single column, you have to concatenate the columns you want to pivot on:

    import pyspark.sql.functions as func
    from pyspark.sql import SparkSession
    
    spark = SparkSession.builder.getOrCreate()
    
    # Recreate the sample data from the question
    df = spark.createDataFrame(
        [
            ('a', 'E1', 'C1', 1, 11, 21),
            ('a', 'E1', 'C1', 2, 12, 22),
            ('a', 'E1', 'C1', 3, 13, 23),
            ('b', 'E1', 'C1', 1, 14, 24),
            ('b', 'E1', 'C1', 2, 15, 25),
            ('b', 'E1', 'C1', 3, 16, 26),
        ],
        schema=['sheet', 'equipment', 'chamber', 'time', 'value1', 'value2']
    )
    
    df.printSchema()
    df.show(10, False)
    +-----+---------+-------+----+------+------+
    |sheet|equipment|chamber|time|value1|value2|
    +-----+---------+-------+----+------+------+
    |a    |E1       |C1     |1   |11    |21    |
    |a    |E1       |C1     |2   |12    |22    |
    |a    |E1       |C1     |3   |13    |23    |
    |b    |E1       |C1     |1   |14    |24    |
    |b    |E1       |C1     |2   |15    |25    |
    |b    |E1       |C1     |3   |16    |26    |
    +-----+---------+-------+----+------+------+
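
    Before grouping, you can sanity-check what the concatenated pivot key looks like (a quick preview using the same df; the select list here is just for readability):

    df.withColumn('new_col', func.concat_ws('_', func.col('equipment'), func.col('chamber')))\
        .select('sheet', 'new_col', 'time', 'value1', 'value2')\
        .show(6, False)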
    

    Assuming there are many columns you want to aggregate, you can build the aggregation expressions in a loop instead of writing them all out by hand:

    aggregation = []
    for c in df.columns[-2:]:  # the two value columns: value1, value2
        aggregation += [func.min(c).alias(f"{c}_min"),
                        func.max(c).alias(f"{c}_max"),
                        func.avg(c).alias(f"{c}_mean")]
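
    For the two value columns in the question, that loop builds the same list you would otherwise write out by hand, i.e. the equivalent of:

    aggregation = [
        func.min('value1').alias('value1_min'),
        func.max('value1').alias('value1_max'),
        func.avg('value1').alias('value1_mean'),
        func.min('value2').alias('value2_min'),
        func.max('value2').alias('value2_max'),
        func.avg('value2').alias('value2_mean'),
    ]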
    
    
    df.withColumn('new_col', func.concat_ws('_', func.col('equipment'), func.col('chamber')))\
        .groupby('sheet')\
        .pivot('new_col')\
        .agg(*aggregation)\
        .orderBy('sheet')\
        .show(100, False)
    +-----+----------------+----------------+-----------------+----------------+----------------+-----------------+
    |sheet|E1_C1_value1_min|E1_C1_value1_max|E1_C1_value1_mean|E1_C1_value2_min|E1_C1_value2_max|E1_C1_value2_mean|
    +-----+----------------+----------------+-----------------+----------------+----------------+-----------------+
    |a    |11              |13              |12.0             |21              |23              |22.0             |
    |b    |14              |16              |15.0             |24              |26              |25.0             |
    +-----+----------------+----------------+-----------------+----------------+----------------+-----------------+
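
    Note that the statistic columns come out in the order of the expressions in aggregation (min, max, mean), as the output above shows. If you need the mean/min/max order from the desired table, build the list in that order instead:

    aggregation = []
    for c in df.columns[-2:]:
        aggregation += [func.avg(c).alias(f"{c}_mean"),
                        func.min(c).alias(f"{c}_min"),
                        func.max(c).alias(f"{c}_max")]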