
How to split a Spark dataframe into 2 using a ratio given in terms of months and a unix epoch column?


I want to split a Spark dataframe into two using a ratio given in terms of months, based on a unix epoch column.

A sample dataframe is shown below:

unixepoch
---------
1539754800
1539754800
1539931200
1539927600
1539927600
1539931200
1539931200
1539931200
1539927600
1540014000
1540014000
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400
1540190400

Strategy of splitting:

If the total span of the data is, say, 30 months and the splitting ratio is, say, 0.6, then the expected dataframe 1 should have 30 * 0.6 = 18 months of data and the expected dataframe 2 should have 30 * 0.4 = 12 months of data.
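As a quick sketch of the ratio arithmetic (plain Scala, no Spark involved):

```scala
// With 30 months of data and a 0.6 splitting ratio, the boundary is:
val totalMonths = 30
val splittingRatio = 0.6
val trainMonths = math.round(totalMonths * splittingRatio) // 18 months for dataframe 1
val testMonths = totalMonths - trainMonths                 // 12 months for dataframe 2
```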

EDIT-1

Most of the answers apply the splitting ratio to the number of records, i.e. if the total record count is 100 and the split ratio is 0.6, then split1DF ~= 60 records and split2DF ~= 40 records. To be clear, this is not what I am looking for. Here the splitting ratio applies to months, which can be derived from the unix epoch timestamp column in the sample dataframe above. Suppose the epoch column spans 30 months; then I want the first 18 months of rows in dataframe 1 and the last 12 months of rows in dataframe 2. You can think of this as splitting a dataframe of time-series data in Spark.

EDIT-2

If the data covers July 2018 to May 2019 (10 months of data), then split1 (0.6 = first 6 months) = (July 2018 to Jan 2019) and split2 (0.4 = last 4 months) = (Feb 2019 to May 2019). There should be no randomized picking.
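The month-bucketing rule can be sketched in plain Scala with java.time (the `monthIndex` helper is hypothetical, just to illustrate the arithmetic; the accepted answer below does the equivalent with Spark SQL functions):

```scala
import java.time.{Instant, YearMonth, ZoneOffset}
import java.time.temporal.ChronoUnit

// Index of the month an epoch second falls in, relative to the earliest month.
def monthIndex(epochSec: Long, minMonth: YearMonth): Long = {
  val ym = YearMonth.from(Instant.ofEpochSecond(epochSec).atOffset(ZoneOffset.UTC))
  ChronoUnit.MONTHS.between(minMonth, ym)
}

val minMonth = YearMonth.of(2018, 7)              // assume the data starts July 2018
val trainMonths = math.round(10 * 0.6)            // 10 months total; first 6 go to split1
val epoch = 1539754800L                           // falls in October 2018 -> month index 3
assert(monthIndex(epoch, minMonth) < trainMonths) // so this row belongs to split1
```

Rows whose month index is below the boundary go to split1, the rest to split2, with no randomness involved.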


Solution

  • I split the data by months, falling back to days when the data covers only a single month.

    I prefer this method because it does not depend on a windowing function. The other answer given here uses Window without partitionBy, which degrades performance seriously because all the data shuffles to a single executor.

    1. Splitting method, given a train ratio in terms of months

        import org.apache.spark.sql.{Column, DataFrame}
        import org.apache.spark.sql.functions._
        import org.apache.spark.sql.types.DataTypes

        val EPOCH = "epoch"

        def splitTrainTest(inputDF: DataFrame,
                           trainRatio: Double): (DataFrame, DataFrame) = {
          require(trainRatio >= 0 && trainRatio <= 0.9, s"trainRatio must be between 0 and 0.9, found: $trainRatio")

          // add month/dayofmonth/year columns derived from an epoch column
          def extractDateCols(tuples: (String, Column)*): DataFrame = {
            tuples.foldLeft(inputDF) {
              case (df, (dateColPrefix, dateColumn)) =>
                df
                  .withColumn(s"${dateColPrefix}_month", month(from_unixtime(dateColumn)))
                  .withColumn(s"${dateColPrefix}_dayofmonth", dayofmonth(from_unixtime(dateColumn)))
                  .withColumn(s"${dateColPrefix}_year", year(from_unixtime(dateColumn)))
            }
          }

          val extractDF = extractDateCols((EPOCH, inputDF(EPOCH)))

          val yearCol = s"${EPOCH}_year"
          val monthCol = s"${EPOCH}_month"
          val dayCol = s"${EPOCH}_dayofmonth"
          val SPLIT = "split"

          // derive min/max(yyyy-MM): the first day of the month each row falls in
          val dateCol = to_date(date_format(
            concat_ws("-", Seq(yearCol, monthCol).map(col): _*), "yyyy-MM-01"))

          val minMaxDF = extractDF.agg(max(dateCol).as("max_date"), min(dateCol).as("min_date"))
          val min_max_date = minMaxDF.head()
          import java.sql.{Date => SqlDate}
          val minDate = min_max_date.getAs[SqlDate]("min_date")
          val maxDate = min_max_date.getAs[SqlDate]("max_date")

          println(s"Min Date Found: $minDate")
          println(s"Max Date Found: $maxDate")

          // total number of months the data spans
          val totalMonths = (maxDate.toLocalDate.getYear - minDate.toLocalDate.getYear) * 12 +
            maxDate.toLocalDate.getMonthValue - minDate.toLocalDate.getMonthValue
          println(s"Total Months of data found for is $totalMonths months")

          // month index relative to the earliest month; starts at 0
          val splitDF = extractDF.withColumn(SPLIT,
            round(months_between(dateCol, to_date(lit(minDate)))).cast(DataTypes.IntegerType))

          val (trainDF, testDF) = totalMonths match {
            // data spans more than one month: split on the month index
            case tm if tm > 0 =>
              val trainMonths = Math.round(totalMonths * trainRatio)
              println(s"Data considered for training is < $trainMonths months")
              println(s"Data considered for testing is >= $trainMonths months")
              (splitDF.filter(col(SPLIT) < trainMonths), splitDF.filter(col(SPLIT) >= trainMonths))

            // data spans a single month: split on the day index instead
            case tm if tm == 0 =>
              val splitDF1 = splitDF.withColumn(SPLIT,
                datediff(date_format(
                  concat_ws("-", Seq(yearCol, monthCol, dayCol).map(col): _*), "yyyy-MM-dd"), lit(minDate))
              )
              // total number of days the data spans
              val totalDays = splitDF1.select(max(SPLIT).as("total_days")).head.getAs[Int]("total_days")
              if (totalDays <= 1) {
                throw new RuntimeException(s"Insufficient data provided for training: data found for $totalDays day(s), " +
                  s"but more than 1 day is required")
              }
              println(s"Total Days of data found is $totalDays days")

              val trainDays = Math.round(totalDays * trainRatio)
              (splitDF1.filter(col(SPLIT) < trainDays), splitDF1.filter(col(SPLIT) >= trainDays))

            // totalMonths is never negative, but keep the match exhaustive
            case _ => throw new RuntimeException(s"Insufficient data provided for training: data found for $totalMonths " +
              s"months, but at least 1 month is required")
          }
          (trainDF.cache(), testDF.cache())
        }
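The day-based fallback can be checked with plain java.time arithmetic (the `dayIndex` helper is hypothetical, mirroring what the datediff column computes; the numbers match the single-month test in section 3):

```scala
import java.time.LocalDate
import java.time.temporal.ChronoUnit

// Day index relative to the first day of the minimum month, as datediff computes it.
def dayIndex(d: LocalDate, minDate: LocalDate): Long = ChronoUnit.DAYS.between(minDate, d)

val minDate = LocalDate.of(2020, 5, 1)
val totalDays = dayIndex(LocalDate.of(2020, 5, 26), minDate) // max split value: 25
val trainDays = math.round(totalDays * 0.6)                  // boundary index: 15
assert(dayIndex(LocalDate.of(2020, 5, 15), minDate) < trainDays)  // May 15 -> train split
assert(dayIndex(LocalDate.of(2020, 5, 16), minDate) >= trainDays) // May 16 -> test split
```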
    

    2. Test using the data from multiple months across years

     //  call methods
        val implicits = sqlContext.sparkSession.implicits
        import implicits._
        val monthData = sc.parallelize(Seq(
          1539754800,
          1539754800,
          1539931200,
          1539927600,
          1539927600,
          1539931200,
          1539931200,
          1539931200,
          1539927600,
          1540449600,
          1540449600,
          1540536000,
          1540536000,
          1540536000,
          1540424400,
          1540424400,
          1540618800,
          1540618800,
          1545979320,
          1546062120,
          1545892920,
          1545892920,
          1545892920,
          1545201720,
          1545892920,
          1545892920
        )).toDF(EPOCH)
    
        val (split1, split2) = splitTrainTest(monthData, 0.6)
        split1.show(false)
        split2.show(false)
    
        /**
          * Min Date Found: 2018-10-01
          * Max Date Found: 2018-12-01
          * Total Months of data found for is 2 months
          * Data considered for training is < 1 months
          * Data considered for testing is >= 1 months
          * +----------+-----------+----------------+----------+-----+
          * |epoch     |epoch_month|epoch_dayofmonth|epoch_year|split|
          * +----------+-----------+----------------+----------+-----+
          * |1539754800|10         |17              |2018      |0    |
          * |1539754800|10         |17              |2018      |0    |
          * |1539931200|10         |19              |2018      |0    |
          * |1539927600|10         |19              |2018      |0    |
          * |1539927600|10         |19              |2018      |0    |
          * |1539931200|10         |19              |2018      |0    |
          * |1539931200|10         |19              |2018      |0    |
          * |1539931200|10         |19              |2018      |0    |
          * |1539927600|10         |19              |2018      |0    |
          * |1540449600|10         |25              |2018      |0    |
          * |1540449600|10         |25              |2018      |0    |
          * |1540536000|10         |26              |2018      |0    |
          * |1540536000|10         |26              |2018      |0    |
          * |1540536000|10         |26              |2018      |0    |
          * |1540424400|10         |25              |2018      |0    |
          * |1540424400|10         |25              |2018      |0    |
          * |1540618800|10         |27              |2018      |0    |
          * |1540618800|10         |27              |2018      |0    |
          * +----------+-----------+----------------+----------+-----+
          *
          * +----------+-----------+----------------+----------+-----+
          * |epoch     |epoch_month|epoch_dayofmonth|epoch_year|split|
          * +----------+-----------+----------------+----------+-----+
          * |1545979320|12         |28              |2018      |2    |
          * |1546062120|12         |29              |2018      |2    |
          * |1545892920|12         |27              |2018      |2    |
          * |1545892920|12         |27              |2018      |2    |
          * |1545892920|12         |27              |2018      |2    |
          * |1545201720|12         |19              |2018      |2    |
          * |1545892920|12         |27              |2018      |2    |
          * |1545892920|12         |27              |2018      |2    |
          * +----------+-----------+----------------+----------+-----+
          */
    

    3. Test using one month of data from a year

     val oneMonthData = sc.parallelize(Seq(
          1589514575, //  Friday, May 15, 2020 3:49:35 AM
          1589600975, // Saturday, May 16, 2020 3:49:35 AM
          1589946575, // Wednesday, May 20, 2020 3:49:35 AM
          1590378575, // Monday, May 25, 2020 3:49:35 AM
          1590464975, // Tuesday, May 26, 2020 3:49:35 AM
          1590470135 // Tuesday, May 26, 2020 5:15:35 AM
        )).toDF(EPOCH)
    
        val (split3, split4) = splitTrainTest(oneMonthData, 0.6)
        split3.show(false)
        split4.show(false)
    
        /**
          * Min Date Found: 2020-05-01
          * Max Date Found: 2020-05-01
          * Total Months of data found for is 0 months
          * Total Days of data found is 25 days
          * +----------+-----------+----------------+----------+-----+
          * |epoch     |epoch_month|epoch_dayofmonth|epoch_year|split|
          * +----------+-----------+----------------+----------+-----+
          * |1589514575|5          |15              |2020      |14   |
          * +----------+-----------+----------------+----------+-----+
          *
          * +----------+-----------+----------------+----------+-----+
          * |epoch     |epoch_month|epoch_dayofmonth|epoch_year|split|
          * +----------+-----------+----------------+----------+-----+
          * |1589600975|5          |16              |2020      |15   |
          * |1589946575|5          |20              |2020      |19   |
          * |1590378575|5          |25              |2020      |24   |
          * |1590464975|5          |26              |2020      |25   |
          * |1590470135|5          |26              |2020      |25   |
          * +----------+-----------+----------------+----------+-----+
          */