pyspark collect_set of column outside of groupby


I am trying to use collect_set to get a list of category_name strings that are NOT part of the groupby. My code is

from pyspark import SparkContext
from pyspark.sql import HiveContext
from pyspark.sql import functions as F

sc = SparkContext("local")
sqlContext = HiveContext(sc)
df = sqlContext.createDataFrame([
    ("1", "cat1", "Dept1", "product1", 7),
    ("2", "cat2", "Dept1", "product1", 100),
    ("3", "cat2", "Dept1", "product2", 3),
    ("4", "cat1", "Dept2", "product3", 5),
], ["id", "category_name", "department_id", "product_id", "value"])

df.show()
df.groupby("department_id", "product_id")\
    .agg({'value': 'sum'}) \
    .show()

#            .agg( F.collect_set("category_name"))\

The output is

+---+-------------+-------------+----------+-----+
| id|category_name|department_id|product_id|value|
+---+-------------+-------------+----------+-----+
|  1|         cat1|        Dept1|  product1|    7|
|  2|         cat2|        Dept1|  product1|  100|
|  3|         cat2|        Dept1|  product2|    3|
|  4|         cat1|        Dept2|  product3|    5|
+---+-------------+-------------+----------+-----+

+-------------+----------+----------+
|department_id|product_id|sum(value)|
+-------------+----------+----------+
|        Dept1|  product2|         3|
|        Dept1|  product1|       107|
|        Dept2|  product3|         5|
+-------------+----------+----------+

I want this output

+-------------+----------+----------+----------------------------+
|department_id|product_id|sum(value)| collect_list(category_name)|
+-------------+----------+----------+----------------------------+
|        Dept1|  product2|         3|  cat2                      |
|        Dept1|  product1|       107|  cat1, cat2                |
|        Dept2|  product3|         5|  cat1                      |
+-------------+----------+----------+----------------------------+

Attempt 1

df.groupby("department_id", "product_id")\
    .agg({'value': 'sum'}) \
    .agg(F.collect_set("category_name")) \
    .show()

I got this error:

pyspark.sql.utils.AnalysisException: "cannot resolve 'category_name' given input columns: [department_id, product_id, sum(value)];;\n'Aggregate [collect_set('category_name, 0, 0) AS collect_set(category_name)#35]\n+- Aggregate [department_id#2, product_id#3], [department_id#2, product_id#3, sum(value#4L) AS sum(value)#24L]\n +- LogicalRDD [id#0, category_name#1, department_id#2, product_id#3, value#4L]\n"

Attempt 2: I put category_name as part of the groupby

df.groupby("category_name", "department_id", "product_id")\
    .agg({'value': 'sum'}) \
    .agg(F.collect_set("category_name")) \
    .show()

It runs, but the output is not correct

+--------------------------+
|collect_set(category_name)|
+--------------------------+
|              [cat1, cat2]|
+--------------------------+

Solution

  • You can specify multiple aggregations within one agg(). The correct syntax for your case would be:

    df.groupby("department_id", "product_id")\
        .agg(F.sum('value'), F.collect_set("category_name"))\
        .show()
    #+-------------+----------+----------+--------------------------+
    #|department_id|product_id|sum(value)|collect_set(category_name)|
    #+-------------+----------+----------+--------------------------+
    #|        Dept1|  product2|         3|                    [cat2]|
    #|        Dept1|  product1|       107|              [cat1, cat2]|
    #|        Dept2|  product3|         5|                    [cat1]|
    #+-------------+----------+----------+--------------------------+
    

    Your method doesn't work because the first .agg() operates on a pyspark.sql.group.GroupedData and returns a new DataFrame. The subsequent call to agg is actually pyspark.sql.DataFrame.agg, which is

    shorthand for df.groupBy().agg()

    So the second call to agg aggregates over the entire already-aggregated DataFrame as a single group, which is not what you intended.
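
    To illustrate that behaviour, here is a minimal sketch using the same df as above (the column aliases sum_value and categories are just illustrative names, not part of the original answer). Calling agg directly on a DataFrame treats the whole frame as one group, which is exactly what produced the single [cat1, cat2] row in Attempt 2. If you need duplicates kept (the desired output header says collect_list) or friendlier column names, use F.collect_list and .alias():

    # DataFrame.agg without a groupBy aggregates over all rows as one group,
    # so collect_set gathers every category_name into a single row
    # (element order inside the set is not guaranteed).
    df.agg(F.collect_set("category_name")).show()

    # collect_list keeps duplicates; .alias() gives readable column names.
    # Result columns: department_id, product_id, sum_value, categories
    df.groupby("department_id", "product_id")\
        .agg(F.sum("value").alias("sum_value"),
             F.collect_list("category_name").alias("categories"))\
        .show()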