Here is my dataset schema:
request_type | request_body
-------------+-------------
           1 | body A
           2 | body B
           3 | ...
          .. | ..
          32 | body XXX
I need to select 100 random rows for each of the request types 1 to 4, plus 100 random rows for all the remaining types (5 to 32) combined. My brute force solution was to run 5 Spark SQL queries and then union the results. Are there any better / smarter options to do it?
The sampleBy function does something similar, but it only lets you provide the expected fraction for each request_type, not the expected number of rows (see the sketch below). We therefore implement some kind of stratified sampling manually.
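For comparison, a sampleBy call would look roughly like this (df is assumed to hold the data and the fractions are placeholders):

// Built-in stratified sampling: one fraction per stratum, approximate results.
// Strata missing from the map are sampled with fraction 0, i.e. dropped.
val byFraction = df.stat.sampleBy(
  "request_type",
  Map(1 -> 0.01, 2 -> 0.01), // hypothetical per-type fractions
  42L                        // seed
)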
Step 1: Create a test dataframe that contains a random number of rows (between 5000 and 10000) for each request_type.
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import scala.collection.mutable.ListBuffer

val spark = ... // your SparkSession
import spark.implicits._

val m = 5000
val random = new scala.util.Random
val data: ListBuffer[Int] = ListBuffer()
for (i <- 1 to 32)
  data ++= List.fill(m + random.nextInt(m))(i) // 5000 to 9999 rows per type
val df = data.toDF("request_type")
Step 2: Compute the fraction of rows to keep for each bucket.

val counts: Array[Row] = df.groupBy("request_type").count().collect()
val fractions: Map[Int, Double] = counts.map(e => (e.getInt(0), e.getLong(1)))
  .toList
  .groupBy(e => if (e._1 <= 4) e._1 else 5) // 1 to 4 are separate buckets, the rest goes into bucket 5
  .mapValues(_.map(_._2))
  .map { case (key, value) => (key, 100.0 / value.sum) } // keep ~100 rows from each bucket
fractions now contains a map with the fraction of rows that should be kept for each bucket.
For example:
Map(5 -> 4.696290869471292E-4, 1 -> 0.012753475322025252, 2 -> 0.01278281989006775, 3 -> 0.014253135689851768, 4 -> 0.010255358424776945)
This means that (in this random example) we want to keep about 1.3 % of all rows with request_type = 1, about 0.05 % of all rows with request_type >= 5, etc.
If the number of rows for each request_type is known beforehand, we can skip this step and set the fractions map directly.
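For example, with (made-up) known counts per bucket, step 2 collapses to building the map directly:

// Hypothetical known counts per bucket; we want ~100 rows from each.
val knownCounts = Map(1 -> 7800L, 2 -> 7822L, 3 -> 7016L, 4 -> 9751L, 5 -> 212934L)
val fractions: Map[Int, Double] = knownCounts.map { case (bucket, n) => (bucket, 100.0 / n) }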
Step 3: Add a column with random values between 0 and 1 to the dataframe and keep only the rows where this value is smaller than or equal to the fraction from step 2.
// Build a column that maps each request_type to its bucket's fraction;
// types without their own key (6 to 32) fall through to the default fractions(5).
val fcol: Column = fractions.foldLeft(lit(fractions(5))) {
  case (acc, (key, value)) => when('request_type.equalTo(key), value).otherwise(acc)
}
val sampledDf = df.withColumn("frac", fcol)
  .withColumn("r", rand())
  .filter('r <= 'frac)
sampledDf.groupBy("request_type").count().orderBy("request_type").show()
prints
+------------+-----+
|request_type|count|
+------------+-----+
| 1| 94|
| 2| 109|
| 3| 100|
| 4| 112|
| 5| 2|
| 6| 5|
| 7| 4|
| 8| 2|
| 9| 1|
| 10| 5|
...
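Note that the counts are only approximately 100 per bucket, because the rand() filter is probabilistic. If exactly 100 rows per bucket are required, one alternative (not part of the approach above, and more expensive because it shuffles and sorts each bucket) is to rank the rows randomly within each bucket via a window function and keep the first 100:

import org.apache.spark.sql.expressions.Window

// Collapse request types into buckets (1 to 4 stay separate, 5 to 32 become bucket 5),
// rank the rows of each bucket in random order and keep the first 100.
val w = Window.partitionBy("bucket").orderBy(rand())
val exactDf = df
  .withColumn("bucket", when('request_type <= 4, 'request_type).otherwise(lit(5)))
  .withColumn("rn", row_number().over(w))
  .filter('rn <= 100)
  .drop("bucket", "rn")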