Tags: apache-spark, pyspark, apache-spark-sql

How to randomize different numbers for a subgroup of rows in PySpark


I have a PySpark DataFrame. I need to assign a value chosen at random from a list to every row that matches a given condition. I did:

import random
import pyspark.sql.functions as f

df = df.withColumn('rand_col', f.when(f.col('condition_col') == condition, random.choice(my_list)))

but the effect is that a single random value is chosen once and assigned to every matching row.

How can I draw a random value independently for each row?


Solution

  • random.choice(my_list) is evaluated once on the driver when the column expression is built, so the same literal ends up in every row. To draw a value per row instead, you can:

    • use rand and floor from pyspark.sql.functions to create a random index column into my_list
    • create an array column in which the my_list values are repeated on every row
    • index into that array column using f.col with the random index

    It would look something like this:

    import pyspark.sql.functions as f
    
    my_list = [1, 2, 30]
    df = spark.createDataFrame(
        [
            (1, 0),
            (2, 1),
            (3, 1),
            (4, 0),
            (5, 1),
            (6, 1),
            (7, 0),
        ],
        ["id", "condition"]
    )
    
    # rand() is evaluated per row, so floor(rand() * len(my_list)) yields an
    # independent random index in [0, len(my_list)) for each matching row
    df = df.withColumn('rand_index', f.when(f.col('condition') == 1, f.floor(f.rand() * len(my_list))))\
           .withColumn('my_list', f.array([f.lit(x) for x in my_list]))\
           .withColumn('rand_value', f.when(f.col('condition') == 1, f.col("my_list")[f.col("rand_index")]))
    
    df.show()
    +---+---------+----------+----------+----------+
    | id|condition|rand_index|   my_list|rand_value|
    +---+---------+----------+----------+----------+
    |  1|        0|      null|[1, 2, 30]|      null|
    |  2|        1|         0|[1, 2, 30]|         1|
    |  3|        1|         2|[1, 2, 30]|        30|
    |  4|        0|      null|[1, 2, 30]|      null|
    |  5|        1|         1|[1, 2, 30]|         2|
    |  6|        1|         2|[1, 2, 30]|        30|
    |  7|        0|      null|[1, 2, 30]|      null|
    +---+---------+----------+----------+----------+
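
    If you'd rather skip the explicit index column, a shorter variant is to shuffle the array on every row and take its first element. This is a minimal sketch, assuming Spark 2.4+ for f.shuffle; rand_value2 is a hypothetical column name, not part of the original answer:

    df = df.withColumn(
        'rand_value2',
        f.when(
            f.col('condition') == 1,
            # shuffle() generates a new random permutation per row,
            # so element [0] is a uniform random pick from my_list
            f.shuffle(f.array([f.lit(x) for x in my_list]))[0]
        )
    )

    Both rand and shuffle are non-deterministic, so the generated values can change if the DataFrame is recomputed; cache or checkpoint the result if you need the values to stay stable.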