Tags: scala, apache-spark, apache-spark-sql, apache-spark-dataset, data-transform

How do I efficiently map keys from one dataset based on values from another dataset?


Assume DataFrame 1 maps each target country to a list of source countries, and DataFrame 2 holds the availability of every country. I want to find all (target, source) pairs from DataFrame 1 where the target country is available (true) and the source country is not (false):

Dataframe 1 (targetId, sourceId):
USA: China, Russia, India, Japan
China: USA, Russia, India
Russia: USA, Japan

Dataframe 2 (id, available):
USA: true
China: false
Russia: true
India: false
Japan: true

The result dataset should look like this:
(USA, China),
(USA, India)

My idea is to first explode DataFrame 1 into a new DataFrame (say, tempDF), add two new columns to it, targetAvailable and sourceAvailable, and finally filter for targetAvailable = true and sourceAvailable = false to get the desired result.
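To make the intended semantics of this plan concrete, here is a minimal sketch that runs the same three steps (explode, tag with availability, filter) on plain Scala collections over the sample data, without Spark. The names `targets`, `available`, `exploded`, `tagged` and `result` are illustrative only, and the lookup assumes every id appears in the availability table:

```scala
// Sample data from the question, modelled as plain Scala collections (no Spark).
val targets: Seq[(String, Seq[String])] = Seq(
  "USA"    -> Seq("China", "Russia", "India", "Japan"),
  "China"  -> Seq("USA", "Russia", "India"),
  "Russia" -> Seq("USA", "Japan")
)
val available: Map[String, Boolean] = Map(
  "USA" -> true, "China" -> false, "Russia" -> true,
  "India" -> false, "Japan" -> true
)

// Step 1: "explode" each target into one (targetId, sourceId) pair per source.
val exploded: Seq[(String, String)] =
  for ((target, sources) <- targets; source <- sources) yield (target, source)

// Step 2: tag each pair with the availability of both ends.
// Assumption: every id is present in `available` (Map.apply throws otherwise).
val tagged: Seq[(String, String, Boolean, Boolean)] =
  exploded.map { case (t, s) => (t, s, available(t), available(s)) }

// Step 3: keep pairs whose target is available and whose source is not.
val result: Seq[(String, String)] =
  tagged.collect { case (t, s, true, false) => (t, s) }

println(result) // List((USA,China), (USA,India))
```

In Spark, step 2 cannot be a UDF that queries the second DataFrame; the availability has to come from a join (or from a small lookup table shipped to the executors).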

Here is a snippet of my code:

 val sourceDF = sourceData.toDF("targetId", "sourceId")
 val mappingDF = mappingData.toDF("id", "available")
 val tempDF = sourceDF.select(col("targetId"), 
                explode(col("sourceId")).as("source_id_split"))

 val resultDF = tempDF.select("targetId")
         .withColumn("targetAvailable", isAvailable(tempDF.col("targetId")))
         .withColumn("sourceAvailable", isAvailable(tempDF.col("source_id_split")))


 /*resultDF.select("targetId", "sourceId").
  filter(col("targetAvailable") === "true" and col("sourceAvailable") 
  === "false").show()*/


// udf to find the availability value for the given id from the mapping table
val isAvailable = udf((searchId: String) => {
val rows = mappingDF.select("available")
          .filter(col("id") === searchId).collect()

if (rows(0)(0).toString.equals("true")) "true" else "false"  })

Calling the isAvailable UDF while computing resultDF throws a strange exception. Am I doing something wrong? Is there a better or simpler way to do this?


Solution

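  • In your UDF you reference another DataFrame, which is not possible: a DataFrame can only be used on the driver, while a UDF runs on the executors, where mappingDF cannot be evaluated. That is the cause of the "weird" exception you get.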

    You want to filter one DataFrame based on values contained in another, so what you need is a join on the id columns. In your case you actually need two joins: one for the targets and one for the sources.

    Your idea of using explode is very good, though. Here is a way to achieve what you want:

    import org.apache.spark.sql.functions.{col, explode}
    import spark.implicits._ // for toDF on local Seqs; `spark` is the SparkSession

    // generating data, please provide this code next time ;-)
    val sourceDF = Seq("USA" ->  Seq("China", "Russia", "India", "Japan"),
                       "China" -> Seq("USA", "Russia", "India"),
                       "Russia" -> Seq("USA", "Japan"))
                   .toDF("targetId", "sourceId")
    val mappingDF = Seq("USA" -> true, "China" -> false,
                        "Russia" -> true, "India" -> false,
                        "Japan" -> true)
                   .toDF("id", "available")
    
    sourceDF
        // we can filter available targets before exploding.
        // let's do it to be more efficient.
        .join(mappingDF.withColumnRenamed("id", "targetId"), Seq("targetId"))
        .where(col("available"))
        // exploding the sources
        .select(col("targetId"), explode(col("sourceId")) as "sourceId")
        // then we keep only non available sources
        .join(mappingDF.withColumnRenamed("id", "sourceId"), Seq("sourceId"))
        .where(!col("available"))
        .select("targetId", "sourceId")
        .show(false)
    

    which yields

    +--------+--------+
    |targetId|sourceId|
    +--------+--------+
    |USA     |China   |
    |USA     |India   |
    +--------+--------+
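If mappingDF is small, another option, closer to the UDF in the question, is to collect it into a plain Scala Map on the driver and broadcast it: a Map, unlike a DataFrame, can be serialized and used inside a UDF running on the executors. The following is a minimal sketch assuming the mapping table fits in driver memory; the local SparkSession, the name `availabilityBc`, and the choice to treat unknown ids as unavailable are assumptions for illustration, and this `isAvailable` returns a Boolean rather than the strings "true"/"false":

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode, udf}

// Local session for the sketch; in spark-shell `spark` already exists.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("availability-sketch")
  .getOrCreate()
import spark.implicits._

val sourceDF = Seq(
  "USA"    -> Seq("China", "Russia", "India", "Japan"),
  "China"  -> Seq("USA", "Russia", "India"),
  "Russia" -> Seq("USA", "Japan")
).toDF("targetId", "sourceId")
val mappingDF = Seq("USA" -> true, "China" -> false, "Russia" -> true,
                    "India" -> false, "Japan" -> true).toDF("id", "available")

// Assumption: mappingDF is small enough to collect to the driver.
// as[(String, Boolean)] maps the two columns to the tuple by position.
val availability: Map[String, Boolean] =
  mappingDF.as[(String, Boolean)].collect().toMap
// Ship the plain Map to every executor once, instead of once per task.
val availabilityBc = spark.sparkContext.broadcast(availability)

// The UDF only touches the broadcast Map, never another DataFrame.
// Unknown ids are treated as unavailable (hypothetical choice).
val isAvailable = udf((id: String) => availabilityBc.value.getOrElse(id, false))

val resultDF = sourceDF
  .select(col("targetId"), explode(col("sourceId")).as("sourceId"))
  .where(isAvailable(col("targetId")) && !isAvailable(col("sourceId")))

resultDF.show(false)
```

For a large mapping table the two-join answer above remains the better fit; when the mapping side is small, wrapping it in `org.apache.spark.sql.functions.broadcast(...)` inside each join is another way to ask Spark for a broadcast hash join instead of a shuffle join.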