
Choose from multinomial distribution


I have a series of values and a probability with which I want each of those values sampled. Is there a PySpark method to sample from that distribution for each row? I know how to hard-code this with a random number generator, but I want the method to be flexible for any number of assignment values and probabilities:

assignment_values = ["foo", "buzz", "boo"]
value_probabilities = [0.3, 0.3, 0.4]

Hard-coded method with random number generator:

from pyspark.sql import Row
from pyspark.sql import functions as F

data = [
    {"person": 1, "company": "5g"},
    {"person": 2, "company": "9s"},
    {"person": 3, "company": "1m"},
    {"person": 4, "company": "3l"},
    {"person": 5, "company": "2k"},
    {"person": 6, "company": "7c"},
    {"person": 7, "company": "3m"},
    {"person": 8, "company": "2p"},
    {"person": 9, "company": "4s"},
    {"person": 10, "company": "8y"},
]
df = spark.createDataFrame(Row(**x) for x in data)

(
    df
    .withColumn("rand", F.rand())
    .withColumn(
        "assignment", 
        F.when(F.col("rand") < F.lit(0.3), "foo")
        .when(F.col("rand") < F.lit(0.6), "buzz")
        .otherwise("boo")
    )
    .show()
)
+-------+------+-------------------+----------+
|company|person|               rand|assignment|
+-------+------+-------------------+----------+
|     5g|     1| 0.8020603266148111|       boo|
|     9s|     2| 0.1297179045352752|       foo|
|     1m|     3|0.05170251723736685|       foo|
|     3l|     4|0.07978240998283603|       foo|
|     2k|     5| 0.5931269297050258|      buzz|
|     7c|     6|0.44673560271164037|      buzz|
|     3m|     7| 0.1398027427612647|       foo|
|     2p|     8| 0.8281404801171598|       boo|
|     4s|     9|0.15568513681001817|       foo|
|     8y|    10| 0.6173220502731542|       boo|
+-------+------+-------------------+----------+
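
For reference, the closest I've come to making this flexible is building the when-chain from the two lists with cumulative probabilities (a rough sketch, reusing the df above), but I was hoping for something built in:

from itertools import accumulate

assignment_values = ["foo", "buzz", "boo"]
value_probabilities = [0.3, 0.3, 0.4]

# cumulative thresholds: [0.3, 0.6, 1.0]
cumulative = list(accumulate(value_probabilities))

# chain the when(...) clauses from the lists instead of hard-coding them
assignment = F.when(F.col("rand") < cumulative[0], assignment_values[0])
for value, threshold in zip(assignment_values[1:], cumulative[1:]):
    assignment = assignment.when(F.col("rand") < threshold, value)
assignment = assignment.otherwise(assignment_values[-1])

df.withColumn("rand", F.rand()).withColumn("assignment", assignment).show()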

Solution

  • I think randomSplit may serve you. It randomly splits your one dataframe into several and returns them in a list.

    df.randomSplit([0.3, 0.3, 0.4])
    

    You can also provide a seed to it (the 5 in the example below).

    You can then join the dfs back together using reduce:

    from pyspark.sql import functions as F
    from functools import reduce
    
    df = spark.createDataFrame(
        [(1, "5g"),
         (2, "9s"),
         (3, "1m"),
         (4, "3l"),
         (5, "2k"),
         (6, "7c"),
         (7, "3m"),
         (8, "2p"),
         (9, "4s"),
         (10, "8y")],
        ['person', 'company'])
    
    assignment_values = ["foo", "buzz", "boo"]
    value_probabilities = [0.3, 0.3, 0.4]
    
    dfs = df.randomSplit(value_probabilities, 5)
    dfs = [df.withColumn('assignment', F.lit(assignment_values[i])) for i, df in enumerate(dfs)]
    df = reduce(lambda a, b: a.union(b), dfs)
    
    df.show()
    # +------+-------+----------+
    # |person|company|assignment|
    # +------+-------+----------+
    # |     1|     5g|       foo|
    # |     2|     9s|       foo|
    # |     6|     7c|       foo|
    # |     4|     3l|      buzz|
    # |     5|     2k|      buzz|
    # |     8|     2p|      buzz|
    # |     3|     1m|       boo|
    # |     7|     3m|       boo|
    # |     9|     4s|       boo|
    # |    10|     8y|       boo|
    # +------+-------+----------+
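
    A quick way to sanity-check the resulting proportions (a small follow-up sketch, reusing the reduced df from above):

    df.groupBy('assignment').count().show()
    # For the run above: foo 3, buzz 3, boo 4 -- roughly the requested 0.3/0.3/0.4,
    # though the exact counts depend on the data and the seed.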