I have a PySpark dataframe:
userid | sku | action |
---|---|---|
123 | 2345 | 2 |
123 | 2345 | 0 |
123 | 5422 | 0 |
123 | 7622 | 0 |
231 | 4322 | 2 |
231 | 4322 | 0 |
231 | 8342 | 0 |
231 | 5342 | 0 |
The output should look like this:
userid | sku_pos | sku_neg |
---|---|---|
123 | 2345 | 5422 |
123 | 2345 | 7622 |
231 | 4322 | 8342 |
231 | 4322 | 5342 |
For each distinct "userid", a "sku" whose "action" values are all 0 should go to column "sku_neg", while a "sku" with at least one "action" > 0 should go to column "sku_pos".
A couple of aggregations are needed: first, group by "userid" and "sku" and flag each pair as positive or negative based on its maximum "action"; then group by "userid", pivot on that flag and collect the skus into lists. Finally, explode the lists.
Input:
from pyspark.sql import functions as F
df = spark.createDataFrame(
[('123', '2345', 2),
('123', '2345', 0),
('123', '5422', 0),
('123', '7622', 0),
('231', '4322', 2),
('231', '4322', 0),
('231', '8342', 0),
('231', '5342', 0)],
['userid', 'sku', 'action'])
Script:
# Flag every (userid, sku) pair: 'p' if any of its actions is > 0, else 'n'
df = df.groupBy('userid', 'sku').agg(
    F.when(F.max('action') > 0, 'p').otherwise('n').alias('_flag')
)
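# At this point each (userid, sku) pair carries a single flag, e.g.
# (123, 2345, 'p'), (123, 5422, 'n'), (123, 7622, 'n'); row order may vary.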
# Pivot on the flag so each userid gets a 'p' list and an 'n' list of skus,
# then explode both lists to pair every positive sku with every negative sku
df = (df
    .groupBy('userid').pivot('_flag', ['p', 'n']).agg(F.collect_list('sku'))
    .withColumn('sku_pos', F.explode('p'))
    .withColumn('sku_neg', F.explode('n'))
    .drop('p', 'n')
)
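# Before the explodes, 'p' and 'n' held sku lists per userid, e.g.
# p=[2345], n=[5422, 7622] for userid 123 (element order may vary).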
df.show()
# +------+-------+-------+
# |userid|sku_pos|sku_neg|
# +------+-------+-------+
# | 231| 4322| 5342|
# | 231| 4322| 8342|
# | 123| 2345| 7622|
# | 123| 2345| 5422|
# +------+-------+-------+
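One caveat, in case it matters for the real data: F.explode drops a userid entirely if its 'p' or 'n' list comes back empty or missing (e.g. a user with no positive sku at all). If such users should be kept with a null in the corresponding column instead, a minimal sketch of the second aggregation using F.explode_outer in place of F.explode would be:
# Same pivot as above, but empty/missing lists yield a null row
# instead of dropping the userid
df = (df
    .groupBy('userid').pivot('_flag', ['p', 'n']).agg(F.collect_list('sku'))
    .withColumn('sku_pos', F.explode_outer('p'))
    .withColumn('sku_neg', F.explode_outer('n'))
    .drop('p', 'n')
)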