scala, apache-spark, apache-spark-sql

Duplicate each row of a DataFrame based on an input value


I'm a beginner in Spark with Scala. I want to duplicate each row of my DataFrame based on an input value.

My input dataset looks something like this:

+------------+---------------------+---------------------+
|id          |currency             |value                |
+------------+---------------------+---------------------+
|1           |USD                  |10                   |
|1           |EUR                  |20                   |
|2           |USD                  |30                   |
|2           |EUR                  |40                   |
+------------+---------------------+---------------------+

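My input value is a sequence, for example Seq("JPY"). For each currency in the sequence, I want a copy of every row with its currency replaced by that value, appended after the original rows. The output should look like this: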

+------------+---------------------+---------------------+
|id          |currency             |value                |
+------------+---------------------+---------------------+
|1           |USD                  |10                   |
|1           |EUR                  |20                   |
|2           |USD                  |30                   |
|2           |EUR                  |40                   |
|1           |JPY                  |10                   |
|1           |JPY                  |20                   |
|2           |JPY                  |30                   |
|2           |JPY                  |40                   |
+------------+---------------------+---------------------+
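To pin down the requested semantics, here is a minimal sketch on plain Scala collections (no Spark involved). The object name ExpectedOutput and the record type Rec are hypothetical stand-ins for the DataFrame and its rows; the for-comprehension plays the role of duplicating rows per currency, and ++ plays the role of appending them.

```scala
object ExpectedOutput {
  // Hypothetical record type standing in for one DataFrame row.
  final case class Rec(id: Int, currency: String, value: Int)

  val input: Seq[Rec] =
    Seq(Rec(1, "USD", 10), Rec(1, "EUR", 20), Rec(2, "USD", 30), Rec(2, "EUR", 40))
  val extraCurrencies: Seq[String] = Seq("JPY")

  // For each extra currency, copy every input row with the currency replaced.
  val copies: Seq[Rec] = for {
    cur <- extraCurrencies
    row <- input
  } yield row.copy(currency = cur)

  // The original rows first, followed by the copies.
  val expected: Seq[Rec] = input ++ copies

  def main(args: Array[String]): Unit = expected.foreach(println)
}
```

With more than one extra currency, the copies are grouped by currency, one block of all input rows per currency.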

Could someone please guide me on how to solve this?


Solution

  • You can add the sequence as an array column, use explode to turn it into one row per currency, and then union the result with the original DataFrame:

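    import org.apache.spark.sql.functions._
    // toDF needs the implicits of an existing SparkSession named `spark`
    // (predefined in spark-shell)
    import spark.implicits._

    val inputData = Seq((1, "USD", 10), (1, "EUR", 20), (2, "USD", 30), (2, "EUR", 40))
    val inputSeq = Seq("JPY")

    val originalDf = inputData.toDF("id", "currency", "value")
    // Add the whole sequence as an array literal column on every row (typedLit: Spark 2.2+)
    val originalDfWithSequence = originalDf.withColumn("currencies_seq", typedLit(inputSeq))
    // explode emits one row per array element; it becomes the new "currency" column
    val originalDfExploded = originalDfWithSequence.select(col("id"), explode(col("currencies_seq")).alias("currency"), col("value"))
    // union matches columns by position, so both sides keep the id, currency, value order
    originalDf.union(originalDfExploded).show()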
    

    Output:

    +---+--------+-----+
    | id|currency|value|
    +---+--------+-----+
    |  1|     USD|   10|
    |  1|     EUR|   20|
    |  2|     USD|   30|
    |  2|     EUR|   40|
    |  1|     JPY|   10|
    |  1|     JPY|   20|
    |  2|     JPY|   30|
    |  2|     JPY|   40|
    +---+--------+-----+
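A different technique gives the same result: instead of typedLit plus explode, cross-join the rows with a small DataFrame of the extra currencies. The sketch below assumes Spark 2.3+ (for unionByName) and wraps the logic in a helper; the names DuplicateRows and withExtraCurrencies are hypothetical, not part of any Spark API.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object DuplicateRows {
  // Hypothetical helper: for each currency in `currencies`, append a copy of
  // every row of `df` with its "currency" column replaced.
  def withExtraCurrencies(df: DataFrame, currencies: Seq[String]): DataFrame = {
    val spark = df.sparkSession
    import spark.implicits._
    // One-column DataFrame holding the extra currencies.
    val currencyDf = currencies.toDF("currency")
    // Cartesian product: every original row paired with every extra currency.
    val copies = df.drop("currency")
      .crossJoin(currencyDf)
      .select("id", "currency", "value")
    // unionByName matches columns by name rather than by position.
    df.unionByName(copies)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("duplicate-rows").getOrCreate()
    import spark.implicits._
    val df = Seq((1, "USD", 10), (1, "EUR", 20), (2, "USD", 30), (2, "EUR", 40))
      .toDF("id", "currency", "value")
    withExtraCurrencies(df, Seq("JPY", "GBP")).show()
    spark.stop()
  }
}
```

The cross join keeps the currency list as data rather than a literal, which may read more naturally when the list is large or comes from another table; with only a handful of currencies, the typedLit and explode version is just as good.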