Tags: apache-spark, pyspark, user-defined-functions

Pyspark udf to get random value - returns constant


I am trying to populate a Spark column with random string values drawn from a list with given probabilities. From what I have read, a nested function is needed. The code below works, EXCEPT that it returns the same sampled value for every row: the column is all A, or all B, or all C. The function must be getting pickled with its state. How do I fix it to generate a fresh random draw per row?

from pyspark.sql import functions as F
import random

def sim_strings(lst_choices, lst_probs):
    str_sampled = random.choices(lst_choices, weights=lst_probs)[0]

    def f(x):
        return str_sampled

    return F.udf(f)


lst_choices_ = ['A', 'B', 'C']
lst_probs_ = [0.5, 0.45, 0.05]

df.withColumn('newcol', sim_strings(lst_choices = lst_choices_, lst_probs = lst_probs_)(F.col('existingcol'))).select('newcol').show(100)
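The constant column can be reproduced in plain Python, with no Spark involved. This is a minimal sketch of the closure behavior only (the Spark machinery is left out): the outer function draws once, and the inner function just replays that captured value on every call.

```python
import random

def sim_strings_broken(lst_choices, lst_probs):
    # random.choices runs exactly once, here, when the factory is called
    str_sampled = random.choices(lst_choices, weights=lst_probs)[0]

    def f(x):
        # every call replays the value captured above
        return str_sampled

    return f

f = sim_strings_broken(['A', 'B', 'C'], [0.5, 0.45, 0.05])
draws = {f(i) for i in range(100)}
print(draws)  # a single-element set: the one pre-drawn value
```

One hundred calls produce a set with exactly one element, which is the same constant-column symptom seen in the DataFrame.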

Solution

  • Right now you are calling random.choices only once, when the UDF is built, and your f function just returns that captured value. Not sure if this is exactly what you want, but I moved the call inside f, and now random.choices runs for every row:

    from pyspark.sql import functions as F
    import random

    def sim_strings(lst_choices, lst_probs):

        def f(x):
            # drawn per call, i.e. per row
            return random.choices(lst_choices, weights=lst_probs)[0]

        # F.udf defaults to StringType, which matches the string choices
        return F.udf(f)
    

    Looks like results are as expected:

    +------+
    |newcol|
    +------+
    |     B|
    |     B|
    |     A|
    |     A|
    |     B|
    |     B|
    |     A|
    |     A|
    |     A|
    |     C|
    |     A|
    |     A|
    |     B|
    |     B|
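If Python UDF overhead matters, the same weighted sampling can be done without a UDF at all, by bucketing a uniform draw against cumulative probabilities. The bucketing logic below is plain, testable Python; the Spark translation in the comments is a hedged sketch (it assumes a DataFrame named `df` and uses `F.rand()` materialized into a helper column so each row is drawn once):

```python
from itertools import accumulate

def bucket(r, lst_choices, cum_probs):
    # map a uniform draw r in [0, 1) to the first bucket whose
    # cumulative probability exceeds it
    for choice, hi in zip(lst_choices, cum_probs):
        if r < hi:
            return choice
    return lst_choices[-1]  # guard against float round-off near 1.0

cum = list(accumulate([0.5, 0.45, 0.05]))  # thresholds ~[0.5, 0.95, 1.0]

# The same mapping as native Spark expressions (sketch; `df` is assumed):
# from pyspark.sql import functions as F
# expr = F.lit('C')
# for choice, hi in [('B', cum[1]), ('A', cum[0])]:
#     expr = F.when(F.col('_r') < hi, choice).otherwise(expr)
# df = df.withColumn('_r', F.rand()).withColumn('newcol', expr).drop('_r')
```

Because everything stays as Catalyst expressions, this avoids serializing rows out to Python workers, which is usually noticeably faster than a Python UDF on large DataFrames.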