Tags: scala, apache-spark

Performing a groupBy on a dataframe while limiting the number of rows


I have a dataframe that contains an "id" column and a "publication" column. The "id" column contains duplicates, and represents a researcher. The "publication" column contains some information about an academic work the researcher published.

I want to transform this dataframe by collecting the publications into an array, reducing the number of rows. I can do this using groupBy and collect_list, after which the "id" column contains only unique values.


    myDataframe
      .groupBy("id")
      .agg(collect_list("publication").as("publications"))
      .select("id", "publications")

However, for my purposes, this is too much data for one row. I want to limit the number of publications that are collected, and split the data up across multiple rows.

Let's say my dataframe looks like this, where an id of 1 appears in 10 rows:

| id  |  publication   |
| ----| -------------- |
| 1   | "foobar"       |
| 1   | "foobar"       |
| 1   | "foobar"       |
| 1   | "foobar"       |
| 1   | "foobar"       |
| 1   | "foobar"       |
| 2   | "foobar"       |
| 1   | "foobar"       |
| 1   | "foobar"       |
| 1   | "foobar"       |
| 1   | "foobar"       |

I want to groupBy id and collect publication into a list, but limit each list to a maximum of 5 publications, splitting a group across multiple rows as needed:

| id  |  publications  |
| ----| -------------- |
| 1   | ["foobar",...] |
| 1   | ["foobar",...] |
| 2   | ["foobar"]     |

How would I accomplish this in Spark with Scala?


Solution

  • If you want a fixed number of publications per row, you first have to calculate an intermediate bucket number per publication per researcher. You can determine the bucket number by integer division of the publication's zero-based row number within its group by 5 (or however many publications you want per list). You can then group on id and bucket number. Here's an example I ran in spark-shell:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    val testDF = Seq(
      (1, "pub1"),
      (1, "pub2"),
      (1, "pub3"),
      (1, "pub4"),
      (1, "pub5"),
      (1, "pub6"),
      (1, "pub7"),
      (1, "pub8"),
      (2, "pub9"),
      (2, "pub10"),
      (2, "pub11"),
      (2, "pub12"),
      (2, "pub13")
    ).toDF("id", "publication")
    
    testDF
      // zero-based position within each id group; every row in a partition
      // shares the same id, so this ordering is arbitrary: use a real
      // ordering column (e.g. a date) if you need reproducible buckets
      .withColumn("rn", row_number().over(Window.partitionBy("id").orderBy("id")) - 1)
      .withColumn("bucket", floor(col("rn") / 5))  // 5 rows per bucket
      .groupBy("id", "bucket")
      .agg(collect_list("publication").as("publications"))
      .select("id", "publications")
      .show(false)
    

    Output:

    +---+----------------------------------+
    |id |publications                      |
    +---+----------------------------------+
    |1  |[pub1, pub2, pub3, pub4, pub5]    |
    |1  |[pub6, pub7, pub8]                |
    |2  |[pub9, pub10, pub11, pub12, pub13]|
    +---+----------------------------------+
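
    If you'd rather not use a window function, the same chunking can be done with array functions: collect everything per id first, then explode fixed-size slices of the array. This is a minimal sketch rather than the approach above; it assumes Spark 3.1+ (for the Column-based slice overload) and reuses the same testDF:

    import org.apache.spark.sql.functions._

    testDF
      .groupBy("id")
      .agg(collect_list("publication").as("all_pubs"))
      // bucket indices 0 .. floor((n - 1) / 5), one per chunk of 5
      .withColumn("publications", explode(transform(
        sequence(lit(0L), floor((size(col("all_pubs")) - 1) / 5)),
        b => slice(col("all_pubs"), (b * 5 + 1).cast("int"), lit(5))  // slice is 1-based
      )))
      .select("id", "publications")
      .show(false)

    Note that this pulls all of a researcher's publications into a single intermediate row, so the window-based version above is the safer choice when some groups are very large.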