Tags: scala, apache-spark, sampling

Spark: How to perform undersampling on LabeledPoint?


I've got some unbalanced data in my LabeledPoint. What I want to do is select all positives and n times more negatives (chosen randomly). For example, if I have 100 positives and 30000 negatives, I want to create a new RDD with all 100 positives and 300 negatives (n = 3).

And in a real scenario, I don't know how many positives and negatives I have at the beginning.


Solution

  • Presumably your data is an RDD[LabeledPoint]. You can do something like the following:

    // takeSample expects an Int and returns an Array, so the count
    // must be converted and the sample parallelized back into an RDD
    // before the union (sc is the SparkContext, e.g. in spark-shell).
    val pos = rdd.filter(_.label == 1)
    val numPos = pos.count()
    val negSample = rdd.filter(_.label == 0).takeSample(false, (numPos * 3).toInt)
    val undersampled = pos.union(sc.parallelize(negSample))
    

    You can find the docs for takeSample, filter, and union in the Spark RDD API documentation.
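
    One caveat: takeSample collects the sampled negatives into driver memory, which is fine for a few hundred rows but not for large samples. Below is a sketch of an alternative that stays fully distributed, using RDD.sample with a computed fraction; the negative count it keeps is approximate rather than exact, and the helper name undersample is just for illustration:

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    // Keep all positives and roughly n negatives per positive.
    // sample never collects data to the driver, but the number of
    // negatives retained is approximate (per-element Bernoulli sampling).
    def undersample(rdd: RDD[LabeledPoint], n: Int): RDD[LabeledPoint] = {
      val pos = rdd.filter(_.label == 1)
      val neg = rdd.filter(_.label == 0)
      // Fraction of negatives to retain: n * (#pos / #neg),
      // capped at 1.0 in case there are fewer than n * #pos negatives.
      val fraction = math.min(1.0, (n.toDouble * pos.count()) / neg.count())
      pos.union(neg.sample(withReplacement = false, fraction))
    }

    If you need an exact count while staying distributed, you can also zipWithIndex the shuffled negatives and filter by index, but for modest sample sizes the takeSample approach above is simpler.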