PySpark: sample size issue when using .sampleBy() to perform stratified sampling


I am using PySpark to implement stratified sampling. The column to stratify on is called cat, and it has three levels: A, B, C. In addition, I only want to sample 10% of the population. In the following code, fraction represents the percentage of each category level in the original dataset.

fraction = {
     "A": 0.90,
     "B": 0.09,
     "C": 0.01
}
per = 0.1
sample_size = round(df.count() * per, 0)
scale = per/sum(fraction.values())
fraction.update((k, v * scale) for (k, v) in fraction.items())

df_sample = df.sampleBy(col = 'cat', fractions = fraction, seed = 666)

The code above is adapted from the following link:

Size of sample with sampleBy in pyspark 2.4.0

However, in the result the proportions are not the same as in the original data. Instead, they are quite far off. The result is as follows:

A: 0.989
B: 0.0105
C: 1.393E-4

Any clue? Thank you for your help.


Solution

  • What is the total number of records in your data?

    Since, as you mentioned, fraction represents the percentage of each category level in the original dataset and you want to keep the same proportions after sampling, you can actually skip the scale calculation:

    from pyspark.sql import SparkSession, functions as func

    spark = SparkSession.builder.getOrCreate()

    # recreate a dataset with the same 90% / 9% / 1% category split
    df = spark.createDataFrame(
        [('A', ) for _ in range(900)] + [('B', ) for _ in range(90)] + [('C', ) for _ in range(10)],
        schema=['cat']
    )
    df.groupBy('cat').count().orderBy('cat').show(100, False)
    +---+-----+
    |cat|count|
    +---+-----+
    |A  |900  |
    |B  |90   |
    |C  |10   |
    +---+-----+
    

    Just keep every category at 0.1 in the fractions dict:

    df.sampleBy(func.col('cat'), fractions={'A': 0.1, 'B': 0.1, 'C': 0.1}, seed=666)\
        .groupBy('cat').count().orderBy('cat').show(100, False)
    +---+-----+
    |cat|count|
    +---+-----+
    |A  |95   |
    |B  |12   |
    |C  |1    |
    +---+-----+
    

    Although the ratio is not exactly the same as in the original data, it shouldn't show such a big difference as yours does:

    >>> 95 / (95 + 12 + 1), 12 / (95 + 12 + 1), 1 / (95 + 12 + 1)
    (0.8796296296296297, 0.1111111111111111, 0.009259259259259259)
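
    For reference, the skew you observed comes from multiplying each category's share by the overall sampling rate: that turns the sampleBy fractions into 0.09, 0.009 and 0.001, so the rare categories are sampled far less often than the common one. Here is a minimal sketch of the expected proportions under your scaled fractions (plain Python, assuming the 900/90/10 population from the example above):

    # scaled fractions exactly as computed in the question
    fraction = {"A": 0.90, "B": 0.09, "C": 0.01}
    per = 0.1
    scale = per / sum(fraction.values())                  # 0.1
    scaled = {k: v * scale for k, v in fraction.items()}  # {'A': 0.09, 'B': 0.009, 'C': 0.001}

    # expected number of sampled rows per category for a 900/90/10 population
    population = {"A": 900, "B": 90, "C": 10}
    expected = {k: population[k] * scaled[k] for k in fraction}  # {'A': 81.0, 'B': 0.81, 'C': 0.01}

    total = sum(expected.values())
    print({k: round(v / total, 5) for k, v in expected.items()})
    # roughly {'A': 0.98998, 'B': 0.0099, 'C': 0.00012} -- matching the skew you reported

    With a flat 0.1 fraction for every category, each stratum's expected count is 10% of its original size, so the original proportions are preserved.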