Tags: arrays, dataframe, scala, apache-spark, aggregate

What's the best way to group and aggregate an array of objects in a DataFrame in Scala?


An example:
`_4` is a collection of (count, date, tag) structs that I want to group and sum.

|_1 |_2   |_3|_4                                                            |
|100|Scrap|12|{[{1, 2022-12-05, A}, {1, 2022-12-05, B}]}                    |
|100|Scrap|12|{[{1, 2022-12-06, A}]}                                        |
|100|Scrap|15|{[{2, 2022-12-07, A}, {2, 2022-12-02, A}, {2, 2022-12-03, C}]}|
|100|Scrap|15|{[{5, 2022-12-05, A}, {3, 2022-12-05, A}, {5, 2022-12-05, D}]}|

The output I'm hoping for is something like this: group by the first three columns plus the third struct field (tag), while summing the first struct field (count).

|UID |Title|Cell|Data                       |
|100 |Scrap|12  |{[{2, A}, {1, B}]}         |
|100 |Scrap|15  |{[{12, A}, {2, C}, {5, D}]}|

The schema of the DataFrame looks like this:

root
 |-- _1: long (nullable = false)
 |-- _2: string (nullable = true)
 |-- _3: long (nullable = false)
 |-- _4: struct (nullable = true)
 |    |-- data: array (nullable = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- count: integer (nullable = false)
 |    |    |    |-- date: date (nullable = true)
 |    |    |    |-- tag: string (nullable = true)

Solution

  • A straightforward approach is to flatten the array content of column _4 via inline, followed by a couple of groupBy/agg steps as shown below:

    import java.sql.Date
    import org.apache.spark.sql.functions._
    import spark.implicits._  // provided automatically in spark-shell
    case class Item(count: Int, date: Date, tag: String)
    case class Items(data: Seq[Item])
    
    val df = Seq(
      (100L, "Scrap", 12L, Items(Seq(Item(1, Date.valueOf("2022-12-05"), "A"), Item(1, Date.valueOf("2022-12-05"), "B")))),
      (100L, "Scrap", 12L, Items(Seq(Item(1, Date.valueOf("2022-12-06"), "A")))),
      (100L, "Scrap", 15L, Items(Seq(Item(2, Date.valueOf("2022-12-07"), "A"), Item(2, Date.valueOf("2022-12-02"), "A"), Item(2, Date.valueOf("2022-12-03"), "C")))),
      (100L, "Scrap", 15L, Items(Seq(Item(5, Date.valueOf("2022-12-05"), "A"), Item(3, Date.valueOf("2022-12-05"), "A"), Item(5, Date.valueOf("2022-12-05"), "D"))))
    ).toDF("_1", "_2", "_3", "_4")
    
    df.
      select($"_1", $"_2", $"_3", expr("inline(_4.data)")).
      groupBy($"_1".as("UID"), $"_2".as("Title"), $"_3".as("Cell"), $"tag").agg(
        struct(sum($"count"), first($"tag")).as("TagSum")
      ).
      groupBy("UID", "Title", "Cell").agg(
        collect_list("TagSum").as("Data")
      ).
      show(false)
    /*
    +---+-----+----+-------------------------+
    |UID|Title|Cell|Data                     |
    +---+-----+----+-------------------------+
    |100|Scrap|12  |[{1, B}, {2, A}]         |
    |100|Scrap|15  |[{2, C}, {12, A}, {5, D}]|
    +---+-----+----+-------------------------+
    */
    

    The first groupBy groups by the key columns plus the tag field of the _4.data elements in order to sum count per tag; the second groupBy groups by the key columns alone and collects the per-tag sums into the wanted array.
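    One caveat worth noting: collect_list gives no ordering guarantee across partitions, so the array element order in Data can vary between runs. If a deterministic order is wanted, the per-tag struct fields can be named explicitly and the collected array sorted with sort_array, which orders an array of structs by the struct's fields left to right. A sketch, reusing the df and imports from the snippet above (note the field order inside the struct is swapped to (tag, count) so the sort key is tag):

    ```scala
    df.
      select($"_1", $"_2", $"_3", expr("inline(_4.data)")).
      groupBy($"_1".as("UID"), $"_2".as("Title"), $"_3".as("Cell"), $"tag").agg(
        sum($"count").as("count")
      ).
      groupBy("UID", "Title", "Cell").agg(
        // sort_array orders array<struct> by the first struct field, here tag
        sort_array(collect_list(struct($"tag", $"count"))).as("Data")
      ).
      show(false)
    ```

    Naming the aggregated columns with .as also avoids the auto-generated field names (sum(count), first(tag)) in the nested structs.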