Tags: python, dataframe, apache-spark, pyspark, apache-spark-sql

How do I group data into multiple groups using Python and assign values?


I want to create a function in Python to which I can pass a dataframe as the first argument, a percentage as the second argument, and a list of values as the third argument.

Function:

def foo(df, perc, *args):

So, if I call this function as foo(population_df, 20, 'A', 'B', 'C'), it should add a column to the dataframe and mark the first 20 percent of records as A, the next 20 percent as B, and the next 20 percent as C, leaving the remaining records null, as shown in the example below.

Input DF:

    +------+------+
    |  Name|Gender|
    +------+------+
    |   Tom|     M|
    |  John|     M|
    | Saily|     F|
    |Noorie|     F|
    | Steve|     M|
    +------+------+

Output DF:

    +------+------+----+
    |  Name|Gender| Col|
    +------+------+----+
    |   Tom|     M|   A|
    |  John|     M|   B|
    | Saily|     F|   C|
    |Noorie|     F|null|
    | Steve|     M|null|
    +------+------+----+

Similarly, if I pass 12 percent, it should mark the first 12 percent of records as A, the next 12 percent as B, and the next 12 percent as C, with the remaining records marked as null.
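
To make the bucketing rule concrete, here is a minimal pure-Python sketch of the arithmetic (the names bucket_label, rank, etc. are illustrative, not part of the question): each record's percent position determines which perc-sized bucket it falls into, and records past the last labelled bucket stay null.

    import math

    def bucket_label(rank, record_count, perc, values):
        # percent position of this record within the ordered data
        percent = 100 * rank / record_count
        # 1-based index of the perc-sized bucket this record falls into
        idx = math.ceil(percent / perc)
        # records beyond the last labelled bucket get None (null)
        return values[idx - 1] if idx <= len(values) else None

    # e.g. with 50 records, perc=12 and values ['A', 'B', 'C']:
    # ranks 1-6 -> 'A', ranks 7-12 -> 'B', ranks 13-18 -> 'C', ranks 19-50 -> None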


Solution

  • You need a column to order the dataframe by in order to differentiate the first 20% from the second 20%, and so on, so I added the ID column. I didn't use the percent_rank transformation because its output didn't make sense for this bucketing, but if you want to try it out, f.percent_rank().over(Window.orderBy(f.col(ordering_column))) can replace f.rank().over(Window.orderBy(f.col(ordering_column)))/record_count; see the sketch below.
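
    A hedged sketch of that percent_rank variant, assuming the same imports and ordering_column as the full solution below. f.percent_rank() returns (rank - 1) / (n - 1) in the range [0, 1] rather than a value in (0, 100], so the division by perc in the function would need rescaling, which is likely why its raw output looked wrong:

        # percent_rank() yields (rank - 1) / (n - 1) in [0, 1]; scale it to a
        # percentage, and note that the first row gets 0 rather than 100/n
        percent_expr = 100 * f.percent_rank().over(Window.orderBy(f.col(ordering_column)))

    The full solution: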

    import pyspark.sql.functions as f
    from pyspark.sql import SparkSession
    from pyspark.sql.window import Window
    
    spark = SparkSession.builder.appName("pyspark_playground").getOrCreate()
    
    def func(df, perc, ordering_column, column_values):
        record_count = df.count()
        df = (
            df
            # percent position of each record: rank 1..n scaled into (0, 100]
            .withColumn('percent_rank', 100*f.rank().over(Window.orderBy(f.col(ordering_column)))/record_count)
            # array literal of the labels; f.lit() accepts a list in Spark 3.4+
            # (on older versions use f.array(*[f.lit(v) for v in column_values]))
            .withColumn('col_values', f.lit(column_values))
            # 1-based bucket index into the labels; element_at returns null
            # when the index exceeds the array length (with ANSI mode off)
            .withColumn('Col', f.element_at(f.col('col_values'), f.round(f.col('percent_rank')/perc).cast('int')))
            .drop('percent_rank', 'col_values')
        )
        
        return df
    
    df = spark.createDataFrame([
        (1, 'Tom', 'M'),
        (2, 'John', 'M'),
        (3, 'Saily', 'F'),
        (4, 'Noorie', 'F'),
        (5, 'Steve', 'M')
    ], ['ID', 'Name', 'Gender'])
    
    output = func(df, 20, 'ID', ['A', 'B', 'C'])
    output.show()
    

    And the output is:

    +---+------+------+----+
    | ID|  Name|Gender| Col|
    +---+------+------+----+
    |  1|   Tom|     M|   A|
    |  2|  John|     M|   B|
    |  3| Saily|     F|   C|
    |  4|Noorie|     F|NULL|
    |  5| Steve|     M|NULL|
    +---+------+------+----+
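
    As a side note on why the trailing records come out as NULL: element_at returns null when the index exceeds the array length, as long as ANSI mode is off (spark.sql.ansi.enabled=false, the default in Spark 3.x; Spark 4.0 turns ANSI mode on by default, in which case this raises an out-of-bounds error instead). A quick way to check the behavior on your own cluster:

        # NULL with ANSI mode off; an out-of-bounds error with ANSI mode on
        spark.sql("SELECT element_at(array('A', 'B', 'C'), 4) AS v").show()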