Tags: scala, dataframe, dictionary, user-defined-functions, foldleft

Scala new Map column from integer and string columns


Problem Statement:

I have a dataframe with four columns: service (String), show (String), country_1 (Integer), & country_2 (Integer). My objective is to produce a dataframe that consists of just two columns: service (String) & information (Map[Integer, List[String]])

where the map could contain multiple records of key-value pairs like this per streaming service:

{
    "34521": ["The Crown", "Bridgerton", "The Queen's Gambit"],
    "49678": ["The Crown", "Bridgerton", "The Queen's Gambit"]
}

Note that more countries may be added in the future, for example as additional input columns such as "country_3", "country_4", etc. Ideally the solution should handle these automatically rather than hardcoding the selected columns, as my attempted solution below does.

Input Dataframe:

Schema:

root
|-- service: string (nullable = true)
|-- show: string (nullable = true)
|-- country_1: integer (nullable = true)
|-- country_2: integer (nullable = true)

Dataframe:

service     |      show        |   country_1   |   country_2

Netflix      The Crown               34521           49678
Netflix      Bridgerton              34521           49678
Netflix      The Queen's Gambit      34521           49678
Peacock      The Office              34521           49678
Disney+      WandaVision             34521           49678 
Disney+      Marvel's 616            34521           49678
Disney+      The Mandalorian         34521           49678
Apple TV     Ted Lasso               34521           49678
Apple TV     The Morning Show        34521           49678

Output Dataframe:

Schema:

root
|-- service: string (nullable = true)
|-- information: map (nullable = false)
|    |-- key: integer
|    |-- value: array (valueContainsNull = true)
|    |    |-- element: string (containsNull = true)

Dataframe:

service    |  information          

Netflix    [34521 -> [The Crown, Bridgerton, The Queen's Gambit], 49678 -> [The Crown, Bridgerton, The Queen's Gambit]]
Peacock    [34521 -> [The Office], 49678 -> [The Office]]
Disney+    [34521 -> [WandaVision, Marvel's 616, The Mandalorian], 49678 -> [WandaVision, Marvel's 616, The Mandalorian]]
Apple TV   [34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso, The Morning Show]]

What I have tried already

Although the code snippet below produces my desired output, I don't want to rely on very basic SQL-type commands, since I don't think they are always optimal for fast computation on large datasets. I also don't want to select the country columns manually by exact name when building the map, because more country columns can be added later.

Is there a better way of doing this, using UDFs, foldLeft, or anything else, that helps with optimization and makes the code more concise and less messy?

import org.apache.spark.sql.functions.{collect_list, map}  // spark.implicits._ assumed in scope for $"..."

val df = spark.read.parquet("filepath/*.parquet")
val grouped = df.groupBy("service", "country_1", "country_2").agg(collect_list("show").alias("show"))
val service_information = grouped.withColumn("information", map($"country_1", $"show", $"country_2", $"show")).drop("country_1", "country_2", "show")

Solution

  • As per the country data "specs" described in the comments section (i.e. country code will be identical and non-null in all rows for any given country_X column), your code can be generalized to handle arbitrarily many country columns:

    val df = Seq(
      ("Netflix",     "The Crown",             34521,    49678),
      ("Netflix",     "Bridgerton",            34521,    49678),
      ("Netflix",     "The Queen's Gambit",    34521,    49678),
      ("Peacock",     "The Office",            34521,    49678),
      ("Disney+",     "WandaVision",           34521,    49678),
      ("Disney+",     "Marvel's 616",          34521,    49678),
      ("Disney+",     "The Mandalorian",       34521,    49678),
      ("Apple TV",    "Ted Lasso",             34521,    49678),
      ("Apple TV",    "The Morning Show",      34521,    49678)
    ).toDF("service", "show", "country_1", "country_2")
    
    val countryCols = df.columns.filter(_.startsWith("country_")).toList
    
    val grouped = df.groupBy("service", countryCols: _*).agg(collect_list("show").as("shows"))
    
    val service_information = grouped.withColumn(
        "information",
        map( countryCols.flatMap{ c => col(c) :: col("shows") :: Nil }: _* )
      ).drop("shows" :: countryCols: _*)
    
    service_information.show(false)
    // +--------+--------------------------------------------------------------------------------------------------------------+
    // |service |information                                                                                                   |
    // +--------+--------------------------------------------------------------------------------------------------------------+
    // |Disney+ |[34521 -> [WandaVision, Marvel's 616, The Mandalorian], 49678 -> [WandaVision, Marvel's 616, The Mandalorian]]|
    // |Peacock |[34521 -> [The Office], 49678 -> [The Office]]                                                                |
    // |Netflix |[34521 -> [The Crown, Bridgerton, The Queen's Gambit], 49678 -> [The Crown, Bridgerton, The Queen's Gambit]]  |
    // |Apple TV|[34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso, The Morning Show]]                              |
    // +--------+--------------------------------------------------------------------------------------------------------------+
    

    Note that the described country "specs" would require every show to be associated with the same list of countries. For instance, if you have 3 country_X columns and every row of a given country_X column is identical and non-null, then every show is tied to those 3 countries. What if a show is available in only 2 of the 3 countries?
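    If the wide country_X layout has to stay but some shows lack some countries, one possible workaround is to unpivot the country columns, drop the nulls, and rebuild the map per service. The following is only a sketch under assumptions: missing countries are assumed to be stored as null, and the sample rows and the names `sparseDf`, `perCountry`, and `sparseInformation` are made up for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Local session for a self-contained run; in spark-shell this returns the existing session.
val spark = SparkSession.builder().master("local[*]").appName("sparse-countries").getOrCreate()
import spark.implicits._

// Hypothetical sample: None stands for "show not available in that country".
val sparseDf = Seq[(String, String, Option[Int], Option[Int])](
  ("Netflix", "The Crown",  Some(34521), Some(49678)),
  ("Netflix", "Bridgerton", Some(34521), None),
  ("Peacock", "The Office", None,        Some(49678))
).toDF("service", "show", "country_1", "country_2")

// Pick up every country_X column by prefix, as in the answer above.
val sparseCountryCols = sparseDf.columns.filter(_.startsWith("country_"))

// Unpivot to one row per (service, show, country), skip missing countries,
// then collect the shows for each (service, country) pair.
val perCountry = sparseDf
  .withColumn("country", explode(array(sparseCountryCols.map(col): _*)))
  .where($"country".isNotNull)
  .groupBy("service", "country")
  .agg(collect_list("show").as("shows"))

// Pair each country with its shows in a struct, then turn the pairs into one map per service.
val sparseInformation = perCountry
  .groupBy("service")
  .agg(map_from_entries(collect_list(struct($"country", $"shows"))).as("information"))

sparseInformation.show(false)
```

    Since `collect_list` does not guarantee ordering, the order of shows within each list may vary between runs.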


    In case your data schema could be revised, a more flexible way of maintaining the associated country info would be to have a single ArrayType column for every show.

    val df = Seq(
      ("Netflix",     "The Crown",             Seq(34521, 49678)),
      ("Netflix",     "Bridgerton",            Seq(34521)),
      ("Netflix",     "The Queen's Gambit",    Seq(10001, 49678)),
      ("Peacock",     "The Office",            Seq(34521, 49678)),
      ("Disney+",     "WandaVision",           Seq(10001, 20002, 34521)),
      ("Disney+",     "Marvel's 616",          Seq(49678)),
      ("Disney+",     "The Mandalorian",       Seq(34521, 49678)),
      ("Apple TV",    "Ted Lasso",             Seq(34521, 49678)),
      ("Apple TV",    "The Morning Show",      Seq(20002, 34521))
    ).toDF("service", "show", "countries")
    
    val grouped = df.withColumn("country", explode($"countries")).
      groupBy("service", "country").agg(collect_list($"show").as("shows"))
    
    val service_information = grouped.groupBy("service").
      agg(collect_list($"country").as("c_list"), collect_list($"shows").as("s_list")).
      select($"service", map_from_arrays($"c_list", $"s_list").as("information"))
    
    service_information.show(false)
    // +--------+-----------------------------------------------------------------------------------------------------------------------------------+
    // |service |information                                                                                                                        |
    // +--------+-----------------------------------------------------------------------------------------------------------------------------------+
    // |Peacock |[34521 -> [The Office], 49678 -> [The Office]]                                                                                     |
    // |Disney+ |[20002 -> [WandaVision], 49678 -> [Marvel's 616, The Mandalorian], 34521 -> [WandaVision, The Mandalorian], 10001 -> [WandaVision]]|
    // |Apple TV|[34521 -> [Ted Lasso, The Morning Show], 49678 -> [Ted Lasso], 20002 -> [The Morning Show]]                                        |
    // |Netflix |[49678 -> [The Crown, The Queen's Gambit], 10001 -> [The Queen's Gambit], 34521 -> [The Crown, Bridgerton]]                        |
    // +--------+-----------------------------------------------------------------------------------------------------------------------------------+
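
    Once the map column exists, it can be queried or flattened again with the built-in map functions. A minimal sketch, assuming a DataFrame with the same schema as service_information; the `info` sample data and the names `inCountry` and `flattened` are hypothetical:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Local session for a self-contained run; in spark-shell this returns the existing session.
val spark = SparkSession.builder().master("local[*]").appName("map-usage").getOrCreate()
import spark.implicits._

// Hypothetical stand-in for service_information (service, information: map<int, array<string>>).
val info = Seq(
  ("Peacock",  Map(34521 -> Seq("The Office"), 49678 -> Seq("The Office"))),
  ("Apple TV", Map(34521 -> Seq("Ted Lasso", "The Morning Show"), 20002 -> Seq("The Morning Show")))
).toDF("service", "information")

// Services whose map has an entry for a given country code.
val inCountry = info
  .where(array_contains(map_keys($"information"), 20002))
  .select("service")

// Flatten the map back to one row per (service, country);
// explode on a map column yields columns named `key` and `value`.
val flattened = info
  .select($"service", explode($"information"))
  .withColumnRenamed("key", "country")
  .withColumnRenamed("value", "shows")

inCountry.show(false)
flattened.show(false)
```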