Tags: apache-spark, sampling

Get even sample of data using spark


Here is my dataset schema:

request_type | request_body
1            | body A
2            | body B
3            | ...
...          | ...
32           | body XXX
  • I need to get 500 records in total.
  • There are 32 request types.
  • The request types should be bucketed this way:
    • types 1, 2, 3 and 4 get 20% each; types 5..32 together get the remaining 20%
    • there should be 100 records each of request_type 1, 2, 3 and 4 (400 in total)
    • the last bucket should contain 100 records with request_type from 5 to 32

My brute-force solution was to run 5 Spark SQL queries and union the results, roughly as sketched below. Are there any better / smarter options to do it?
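
For reference, the brute-force version looks roughly like this (a simplified sketch; it assumes spark.implicits._ is in scope and uses orderBy(rand()) just so each query picks random rows):

    // One query per single-type bucket, one for the 5..32 bucket, then union
    val perType = (1 to 4).map { t =>
      df.filter('request_type === t).orderBy(rand()).limit(100)
    }
    val rest = df.filter('request_type >= 5).orderBy(rand()).limit(100)
    val sample = (perType :+ rest).reduce(_ union _)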


Solution

  • The sampleBy function does something similar, but you can only provide the expected fraction for each request_type, not the expected number of rows (a bare call is sketched below for comparison). Therefore we implement some kind of stratified sampling manually.
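
    For comparison, a bare sampleBy call would look like this (the flat 1 % fraction is only a placeholder, not derived from the data):

    val approx = df.stat.sampleBy(
      "request_type",
      fractions = (1 to 32).map(t => t -> 0.01).toMap, // fraction per value, no row counts
      seed = 42L
    )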

    1. Generate test data

    Create a dataframe that contains a random number of rows (between 5000 and 10000) for each request_type.

    import org.apache.spark.sql._
    import org.apache.spark.sql.functions._
    import scala.collection.mutable.ListBuffer
    
    val spark = ...
    import spark.implicits._
    val m = 5000
    val random = new scala.util.Random
    val data: ListBuffer[Int] = ListBuffer()
    for (i <- 1 to 32) // between m and 2*m - 1 rows per request_type
      data ++= List.fill(m + random.nextInt(m))(i)
    val df = data.toDF("request_type")
    

    2. Calculate the required fraction of rows for each bucket

    val counts: Array[Row] = df.groupBy("request_type").count().collect()
    val fractions: Map[Int, Double] = counts.map(e => (e.getInt(0), e.getLong(1)))
      .toList
      .groupBy(e => if (e._1 <= 4) e._1 else 5) // 1 to 4 are separate buckets, the rest goes into bucket 5
      .mapValues(_.map(_._2))
      .map { case (key, value) => (key, 100.0 / value.sum) } // take 100 rows from each bucket
    

    fractions now contains a map with the fraction of rows that should be kept for each bucket.

    For example:

    Map(5 -> 4.696290869471292E-4, 1 -> 0.012753475322025252, 2 -> 0.01278281989006775, 3 -> 0.014253135689851768, 4 -> 0.010255358424776945)
    

    This means that (in this random example) we want to keep about 1.3 % of all rows with request_type = 1 (100 out of roughly 7841 rows), about 0.05 % of all rows with request_type >= 5, and so on.

    If the number of rows for each request_type is known beforehand, we can skip this step and set the fractions map directly.
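
    For instance, if we knew that types 1 to 4 had exactly 8000 rows each and types 5 to 32 together had 200000 rows (made-up numbers, purely for illustration), the map could be written by hand:

    val fractions: Map[Int, Double] = Map(
      1 -> 100.0 / 8000,   // 0.0125
      2 -> 100.0 / 8000,
      3 -> 100.0 / 8000,
      4 -> 100.0 / 8000,
      5 -> 100.0 / 200000  // 5.0e-4, shared default for all types >= 5
    )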

    3. Select random rows

    Now add a column with random values between 0 and 1 to the dataframe and keep only the rows where this value is less than or equal to the fraction from step 2.

    // Fold the fractions map into a nested when/otherwise expression;
    // fractions(5) is the default for every request_type without its own entry (>= 5)
    val fcol: Column = fractions.foldLeft(lit(fractions(5))) {
      case (acc, (key, value)) => when('request_type.equalTo(key), value).otherwise(acc)
    }
    val sampledDf = df.withColumn("frac", fcol)
      .withColumn("r", rand())
      .filter('r <= 'frac)
    

    4. Count the rows for each request_type in the sampled dataframe

    sampledDf.groupBy("request_type").count().orderBy("request_type").show()
    

    prints

    +------------+-----+
    |request_type|count|
    +------------+-----+
    |           1|   94|
    |           2|  109|
    |           3|  100|
    |           4|  112|
    |           5|    2|
    |           6|    5|
    |           7|    4|
    |           8|    2|
    |           9|    1|
    |          10|    5|
    ...
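
    The counts only approximate the 100-per-bucket targets because every row is kept independently with probability frac. If exact counts are required, one alternative (a sketch, not part of the solution above) is to rank the rows randomly within each bucket using a window function and keep the first 100:

    import org.apache.spark.sql.expressions.Window

    // Exact-count variant: shuffle rows within each bucket, keep the first 100
    val bucketed = df.withColumn("bucket",
      when('request_type <= 4, col("request_type")).otherwise(5))
    val w = Window.partitionBy("bucket").orderBy(rand())
    val exactDf = bucketed
      .withColumn("rank", row_number().over(w))
      .filter('rank <= 100)
      .drop("bucket", "rank")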