
Flattening rows in Spark along with existing columns


I have a dataset as below.

id1   k1, k2, k3, k4
id2   k1, k2
id3   k2, k3
id4   k4

For each "k", I want to count the number of rows it appears in and list the IDs of those rows.

Expected output:

k1  2    id1, id2
k2  3    id1, id2, id3
k3  2    id1, id3
k4  2    id1, id4
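To pin down what the expected output means, here is a minimal sketch of the same computation on plain Scala collections (no Spark). The value rows is a hypothetical in-memory copy of the sample data above. The flatMap step plays the role of Spark's explode, and groupBy plays the role of the Spark groupBy.

```scala
// Plain Scala collections model of the aggregation (no Spark involved).
// Hypothetical in-memory copy of the sample dataset from the question.
val rows: Seq[(String, Seq[String])] = Seq(
  "id1" -> Seq("k1", "k2", "k3", "k4"),
  "id2" -> Seq("k1", "k2"),
  "id3" -> Seq("k2", "k3"),
  "id4" -> Seq("k4")
)

// "Explode": emit one (key, id) pair per key occurrence.
val pairs: Seq[(String, String)] =
  rows.flatMap { case (id, keys) => keys.map(k => (k, id)) }

// "Group by key": collect the ids for each key, then count them.
val summary: Seq[(String, Int, Seq[String])] =
  pairs
    .groupBy(_._1)
    .toSeq
    .sortBy(_._1)
    .map { case (key, ps) =>
      val ids = ps.map(_._2)
      (key, ids.size, ids)
    }

summary.foreach { case (k, n, ids) =>
  println(s"$k  $n    ${ids.mkString(", ")}")
}
```

Run as a script, this prints the four lines of the expected output above. groupBy keeps the ids within each group in input order, which is why they come out as id1, id2 and so on.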

I used explode followed by a groupBy on the keys. With the code below I get the output shown after it.

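import org.apache.spark.sql.functions._
import sparkSession.implicits._

// Each input line is "id<TAB>k1, k2, ..."
val newlines = sparkSession.read.textFile(s3Path)
  .map { line =>
    val split = line.split("\t")
    (split(0), split(1).split(", "))
  }

// The tuple columns are _1 (id) and _2 (array of keys); there is no _3
val myDF = newlines.withColumn("Key", explode($"_2"))
  .groupBy("Key")
  .agg(count("Key"))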

k1  2    
k2  3   
k3  2 
k4  2

Is there a way to add the IDs as well?


Solution

  • You can use Spark's built-in functions split and explode, then aggregate with collect_list and concat_ws inside agg.

    Example:

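    // Paste into spark-shell using :paste (spark is the shell's SparkSession)
    import org.apache.spark.sql.functions._
    import spark.implicits._   // already in scope in spark-shell

    val df = Seq(("id1", "k1,k2,k3,k4"),
                 ("id2", "k1,k2"),
                 ("id3", "k2,k3"),
                 ("id4", "k4"))
      .toDF("a", "b")

    df.selectExpr("a", "explode(split(b, ',')) as ex")   // one row per (id, key) pair
      .groupBy($"ex")
      .agg(concat_ws(",", collect_list($"a")).alias("b"),  // ids per key; order not guaranteed
           count("*").alias("cnt"))                       // rows per key
      .orderBy($"ex")
      .show()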
    

    Result:

    +---+-----------+---+
    | ex|          b|cnt|
    +---+-----------+---+
    | k1|    id1,id2|  2|
    | k2|id1,id2,id3|  3|
    | k3|    id1,id3|  2|
    | k4|    id1,id4|  2|
    +---+-----------+---+
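The same approach can be applied directly to the input format from the question. The sketch below is an adaptation under stated assumptions, not the answerer's code: it assumes Spark 2.4 or later (for array_sort), a local SparkSession, and a hypothetical in-memory Dataset lines standing in for sparkSession.read.textFile(s3Path), where the id and the keys are separated by a tab (as the question's split("\t") implies) and the keys by ", ".

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

// Local session for the sketch; in a real job, reuse the existing SparkSession.
val spark = SparkSession.builder().master("local[*]").appName("explode-ids").getOrCreate()
import spark.implicits._

// Hypothetical stand-in for sparkSession.read.textFile(s3Path):
// a Dataset[String] with a single column named "value".
val lines = Seq("id1\tk1, k2, k3, k4", "id2\tk1, k2", "id3\tk2, k3", "id4\tk4").toDS()

val result = lines
  .select(split($"value", "\t").as("parts"))             // [id, "k1, k2, ..."]
  .select($"parts".getItem(0).as("id"),
          explode(split($"parts".getItem(1), ", ")).as("key"))  // one row per (id, key)
  .groupBy($"key")
  .agg(count("*").as("cnt"),
       // array_sort makes the id order deterministic; collect_list alone does not
       concat_ws(", ", array_sort(collect_list($"id"))).as("ids"))
  .orderBy($"key")

result.show(false)
```

array_sort is there only because collect_list makes no ordering promise after a shuffle; drop it if the order of the ids does not matter.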