Is it possible to pivot on several different columns at once in PySpark? I have a dataframe like this:
from pyspark.sql import functions as sf
import pandas as pd
# Sample data: nine (id, col1, col2) rows, assembled with pandas and then
# converted to a Spark DataFrame.
sdf = spark.createDataFrame(
    pd.DataFrame(
        data=[
            [1, 'str1', 'str4'],
            [1, 'str1', 'str4'],
            [1, 'str2', 'str4'],
            [1, 'str2', 'str5'],
            [1, 'str3', 'str5'],
            [2, 'str2', 'str4'],
            [2, 'str2', 'str4'],
            [2, 'str3', 'str4'],
            [2, 'str3', 'str5'],
        ],
        columns=['id', 'col1', 'col2'],
    )
)
# +----+------+------+
# | id | col1 | col2 |
# +----+------+------+
# | 1 | str1 | str4 |
# | 1 | str1 | str4 |
# | 1 | str2 | str4 |
# | 1 | str2 | str5 |
# | 1 | str3 | str5 |
# | 2 | str2 | str4 |
# | 2 | str2 | str4 |
# | 2 | str3 | str4 |
# | 2 | str3 | str5 |
# +----+------+------+
I want to pivot it on multiple columns ("col1", "col2", ...) to have a dataframe that looks like this:
+----+-----------+-----------+-----------+-----------+-----------+
| id | col1_str1 | col1_str2 | col1_str3 | col2_str4 | col2_str5 |
+----+-----------+-----------+-----------+-----------+-----------+
| 1 | 2 | 2 | 1 | 3 | 2 |
| 2 | 0 | 2 | 2 | 3 | 1 |
+----+-----------+-----------+-----------+-----------+-----------+
I've found a solution that works:
# Pivot each column on its own: one count column per distinct value,
# one row per id.
sdf_pivot_col1 = sdf.groupby('id').pivot('col1').agg(sf.count('id'))
sdf_pivot_col2 = sdf.groupby('id').pivot('col2').agg(sf.count('id'))

# Left-join both pivots onto the distinct ids. The DataFrame is kept in
# sdf_result and shown in a separate statement: DataFrame.show() only prints
# and returns None, so chaining it onto the assignment would bind None
# instead of the joined frame.
sdf_result = (
    sdf
    .select('id').distinct()
    .join(sdf_pivot_col1, on='id', how='left')
    .join(sdf_pivot_col2, on='id', how='left')
)
sdf_result.show()
# +---+----+----+----+----+----+
# | id|str1|str2|str3|str4|str5|
# +---+----+----+----+----+----+
# | 1| 2| 2| 1| 3| 2|
# | 2|null| 2| 2| 3| 1|
# +---+----+----+----+----+----+
But I'm looking for a more compact solution.
Using the link from @mrjoseph, I came up with the following solution. It works and it's cleaner, but I still don't like the joins...
def pivot_udf(df, *cols, id_col='id'):
    """Count the values of several columns per id, pivoted into one wide frame.

    Each column in ``cols`` is pivoted separately. Its values are first
    prefixed with the column name (``col1`` + ``str1`` -> ``col1_str1``) so
    the resulting column names say which source column they came from. The
    per-column pivots are then left-joined onto the distinct ids.

    Args:
        df: Spark DataFrame containing ``id_col`` and every column in ``cols``.
        *cols: Names of the columns to pivot.
        id_col: Name of the grouping/join key column. Defaults to ``'id'``.

    Returns:
        A DataFrame with one row per distinct id and one ``<col>_<value>``
        count column per pivoted value; value/id combinations that never
        occur are null.
    """
    result = df.select(id_col).drop_duplicates()
    for c in cols:
        pivoted = (
            df
            # Prefix the value with its column name before pivoting.
            .withColumn('combcol', sf.concat(sf.lit('{}_'.format(c)), df[c]))
            .groupby(id_col)
            .pivot('combcol')
            .agg(sf.count(c))
        )
        result = result.join(pivoted, on=id_col, how='left')
    return result
# Pivot 'col1' and 'col2' of sdf in a single call and display the result.
pivot_udf(sdf, 'col1','col2').show()
+---+---------+---------+---------+---------+---------+
| id|col1_str1|col1_str2|col1_str3|col2_str4|col2_str5|
+---+---------+---------+---------+---------+---------+
| 1| 2| 2| 1| 3| 2|
| 2| null| 2| 2| 3| 1|
+---+---------+---------+---------+---------+---------+