
How to aggregate on several string columns in a pyspark dataframe groupby object?


I have a pyspark dataframe with many columns of string and double type, in the following form:

+---+---+--------+-----+------+----+
|tin|ecu|DcyStart|did1 |did2  |did3|
+---+---+--------+-----+------+----+
|1  |1  |1       |34   |null  |null|
|1  |1  |2       |null |2     |null|
|1  |1  |3       |null |null  |b   |
|1  |1  |4       |null |null  |null|
|1  |2  |1       |40   |null  |null|
|1  |2  |2       |null |2     |null|
|1  |2  |3       |null |null  |f   |
|1  |2  |4       |null |null  |null|
+---+---+--------+-----+------+----+

Each of the did-columns only has values in rows with a certain value of DcyStart. I am not interested in the information in DcyStart, and I would like to remove it to reduce the table size and get rid of the null entries.

I tried grouping on tin and ecu and then aggregating all did-columns over the range of DcyStart with functions like first(), max(), and so on, but these attempts fail for one of two reasons:

  • the aggregation function cannot handle strings
  • the aggregation function only takes one column as its argument

I have tried several variations of the code below:

list_of_dids = ["did1", "did2", "did3"]
data.groupBy("tin", "ecu").first(*list_of_dids)

but it always gives me one of the two errors listed above.
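For concreteness, here is a sketch of the two failure modes (error messages paraphrased; the exact wording depends on the Spark version):

# The GroupedData shortcut aggregations only handle numeric columns,
# so string did-columns like did3 are rejected:
data.groupBy("tin", "ecu").max("did3")
# AnalysisException: "did3" is not a numeric column

# And GroupedData has no first() shortcut at all:
data.groupBy("tin", "ecu").first("did1")
# AttributeError: 'GroupedData' object has no attribute 'first'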

There are 100+ different did-columns, and some of them could possibly have values for more than one DcyStart, but if that is the case any of them would do for the "aggregation".

What I would like to achieve is this:

+---+---+----+----+----+
|tin|ecu|did1|did2|did3|
+---+---+----+----+----+
|1  |1  |34  |2   |b   |
|1  |2  |40  |2   |f   |
+---+---+----+----+----+

How the hell do I solve this? o_O


Solution

  • Try groupBy().agg(); agg() accepts multiple aggregation expressions, so all the did-columns can be collapsed in one run:

    from pyspark.sql.functions import max

    list_of_dids = ["did1", "did2", "did3"]
    df.groupBy("tin", "ecu").agg(*[max(c) for c in list_of_dids]).show()
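For completeness, here is a self-contained sketch of the whole transformation, assuming an active SparkSession (row order in the output may differ):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import max

    spark = SparkSession.builder.getOrCreate()

    # Rebuild the example table from the question
    data = spark.createDataFrame(
        [
            (1, 1, 1, 34, None, None),
            (1, 1, 2, None, 2, None),
            (1, 1, 3, None, None, "b"),
            (1, 1, 4, None, None, None),
            (1, 2, 1, 40, None, None),
            (1, 2, 2, None, 2, None),
            (1, 2, 3, None, None, "f"),
            (1, 2, 4, None, None, None),
        ],
        ["tin", "ecu", "DcyStart", "did1", "did2", "did3"],
    )

    list_of_dids = ["did1", "did2", "did3"]

    # max() skips nulls (and orders strings lexicographically), so each
    # group keeps its single non-null value per did-column; alias() keeps
    # the original column names instead of "max(did1)" etc.
    result = data.groupBy("tin", "ecu").agg(
        *[max(c).alias(c) for c in list_of_dids]
    )
    result.show()
    # +---+---+----+----+----+
    # |tin|ecu|did1|did2|did3|
    # +---+---+----+----+----+
    # |  1|  1|  34|   2|   b|
    # |  1|  2|  40|   2|   f|
    # +---+---+----+----+----+

If a did-column really does hold several different values within one group and the lexicographic maximum is not what you want, first(c, ignorenulls=True) from pyspark.sql.functions picks an arbitrary non-null value instead.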