Tags: scala, csv, for-loop, select, apache-spark-sql

Scala - split a CSV file into multiple CSV files based on column keys in another CSV file


I am trying to create multiple CSV files from one large CSV file, based on another CSV file that holds the keys (partial column names) used to split the large file.

  1. key.csv
model_name
Car
bike
Bus
Auto
  2. Input.csv
ID car_type car_name car_PRESENT bike_type bike_name bus_type bus_name auto_type
1 Honda city YES yamaha fz school ford Manual
2 TATA punch YES hero xtreme public Ashok Gas

So I want to read key.csv and split Input.csv based on the values in key.csv, like this:

car.csv

ID car_type car_name
1 Honda city
2 TATA punch

bike.csv

ID bike_type bike_name
1 yamaha fz
2 hero xtreme

and the same goes for bus.csv and auto.csv.
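
Put differently, for each key in key.csv the columns to keep are ID plus the columns whose names start with that key, skipping the *_PRESENT flag. A small driver-side sketch of that mapping (plain Scala, column names taken from the samples above):

val keys = Seq("Car", "bike", "Bus", "Auto")
val header = Seq("ID", "car_type", "car_name", "car_PRESENT",
  "bike_type", "bike_name", "bus_type", "bus_name", "auto_type")

// For each key: keep ID plus the key's own columns, dropping the *_PRESENT flag
val columnsPerKey = keys.map { k =>
  val wanted = header.filter(c => c.toLowerCase.startsWith(k.toLowerCase + "_") && !c.endsWith("_PRESENT"))
  k -> ("ID" +: wanted)
}

columnsPerKey.foreach { case (k, cols) => println(s"$k -> ${cols.mkString(", ")}") }
// Car  -> ID, car_type, car_name
// bike -> ID, bike_type, bike_name
// Bus  -> ID, bus_type, bus_name
// Auto -> ID, auto_type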

To get this result I tried the following:

import spark.implicits._
import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType}
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
val input_file = "/input.csv"
val mod_in_path = "/key.csv"

val df_input_mod=spark.read.format("csv").option("header","true").option("delimiter","|").load(mod_in_path)

val model_names = df_input_mod.select("model_name")

val df_input=spark.read.format("csv").option("header","true").option("delimiter","|").load(input_file)
val all_cols = df_input.columns

val party_col = all_cols.filter(_.contains("id"))

for (mname <- model_names) {
  println(mname)

  var mname_col = all_cols.filter(_.contains(mname.mkString(""))).filter(!_.contains("PRESENT")).mkString(",")
  println(mname_col)

  var final_col = party_col.mkString("").concat(",").concat(mname_col)
  println(final_col)

  var colName = Seq(final_col)
  var columnsAll = colName.map(m => col(m))

  // var final_val = df_input.select(final_col.split(",").map(_.toString).map(col): _*)
  var final_val = df_input.select(columnsAll: _*)
  final_val.repartition(1).write.mode("overwrite").option("delimiter", "|").option("header", true).csv("/output/" + mname)

  println("output file created for " + mname)
}

I get the error below when using map inside the loop.

ERROR Executor: Exception in task 0.0 in stage 2.0 (TID 2)
org.apache.spark.SparkException:  Dataset transformations and actions can only be invoked by the driver, not inside of other Dataset transformations; for example, dataset1.map(x => dataset2.values.count() * x) is invalid because the values transformation and count action cannot be performed inside of the dataset1.map transformation. For more information, see SPARK-28702.
        at org.apache.spark.sql.errors.QueryExecutionErrors$.transformationsAndActionsNotInvokedByDriverError(QueryExecutionErrors.scala:1967)
        at org.apache.spark.sql.Dataset.sparkSession$lzycompute(Dataset.scala:198)
        at org.apache.spark.sql.Dataset.sparkSession(Dataset.scala:196)
        at org.apache.spark.sql.Dataset.withPlan(Dataset.scala:3887)
        at org.apache.spark.sql.Dataset.select(Dataset.scala:1519)
        at $line27.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.$anonfun$res0$1(<console>:40)
        at $line27.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.$anonfun$res0$1$adapted(<console>:32)
        at scala.collection.Iterator.foreach(Iterator.scala:943)

How can I break the large CSV into many smaller ones?


Solution

  • If keyDf is not a big DataFrame, you can just collect it and iterate over the keys:

    import spark.implicits._
    
    val keyDf = spark.sparkContext.parallelize(Seq("Car", "bike", "Bus", "Auto")).toDF("model_name")
    
    val data = Seq(
      (1, "Honda", "city", "YES", "yamaha", "fz", "school", "ford", "Manual"),
      (2, "TATA", "punch", "YES", "hero", "xtreme", "public", "Ashok", "Gas")
    )
    val InputDf = spark.sparkContext.parallelize(data).toDF("ID", "car_type", "car_name", "car_PRESENT", "bike_type", "bike_name", "bus_type", "bus_name", "auto_type")
    
    // Collect the distinct keys to the driver, then loop over them driver-side
    keyDf.distinct().collect().map(row => row.getString(0).toLowerCase()).foreach(r => {
      // Only split out keys whose <key>_type and <key>_name columns exist in the input
      if (List(s"${r}_type", s"${r}_name").forall(InputDf.columns.map(_.toLowerCase()).contains)) {
        val df = InputDf.select("ID", s"${r}_type", s"${r}_name")
        df.show(false)
        df.write.csv(s"path/.../$r.csv")
      }
    })
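
If the outputs should also match the write options from the original attempt (single file per key, pipe delimiter, header row), the write call inside the foreach could be adjusted along these lines. This is just a sketch reusing the df and r values from the loop above; the /output/ path comes from the original attempt and is only a placeholder:

    df.repartition(1)
      .write
      .mode("overwrite")
      .option("delimiter", "|")
      .option("header", true)
      .csv(s"/output/$r")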