apache-spark, dataframe, pyspark, apache-spark-sql, hivecontext

How to divide a numerical column into ranges and assign a label to each range in Apache Spark?


I have the following Spark DataFrame:

id weekly_sale
1    40000
2    120000
3    135000
4    211000
5    215000
6    331000
7    337000

I need to see which of the following intervals each value in the weekly_sale column falls into:

under 100000
between 100000 and 200000
between 200000 and 300000
more than 300000

So my desired output would look like this:

id weekly_sale  label
1    40000       under 100000    
2    120000      between 100000 and 200000
3    135000      between 100000 and 200000
4    211000      between 200000 and 300000
5    215000      between 200000 and 300000
6    331000      more than 300000
7    337000      more than 300000

Any PySpark, Spark SQL, or HiveContext implementation would help me.


Solution

  • Assuming ranges and labels are defined as follows:

    splits = [float("-inf"), 100000.0, 200000.0, 300000.0, float("inf")]
    labels = [
        "under 100000", "between 100000 and 200000", 
        "between 200000 and 300000", "more than 300000"]
    
    df = sc.parallelize([
        (1, 40000.0), (2, 120000.0), (3, 135000.0),
        (4, 211000.0), (5, 215000.0), (6, 331000.0),
        (7, 337000.0)
    ]).toDF(["id", "weekly_sale"])
    

    One possible approach is to use Bucketizer:

    from pyspark.ml.feature import Bucketizer
    from pyspark.sql.functions import array, col, lit
    
    bucketizer = Bucketizer(
        splits=splits, inputCol="weekly_sale", outputCol="split"
    )
    
    with_split = bucketizer.transform(df)
    

    and attach labels later:

    label_array = array(*(lit(label) for label in labels))
    
    with_split.withColumn(
        "label", label_array.getItem(col("split").cast("integer"))
    ).show(10, False)
    
    ## +---+-----------+-----+-------------------------+
    ## |id |weekly_sale|split|label                    |
    ## +---+-----------+-----+-------------------------+
    ## |1  |40000.0    |0.0  |under 100000             |
    ## |2  |120000.0   |1.0  |between 100000 and 200000|
    ## |3  |135000.0   |1.0  |between 100000 and 200000|
    ## |4  |211000.0   |2.0  |between 200000 and 300000|
    ## |5  |215000.0   |2.0  |between 200000 and 300000|
    ## |6  |331000.0   |3.0  |more than 300000         |
    ## |7  |337000.0   |3.0  |more than 300000         |
    ## +---+-----------+-----+-------------------------+
    

    There are, of course, different ways to achieve the same goal. For example, you can create a lookup table and join against it:

    from toolz import sliding_window
    from pyspark.sql.functions import broadcast
    
    # Build (lower, upper, label) triples from consecutive pairs of splits
    mapping = [
        (lower, upper, label) for ((lower, upper), label)
        in zip(sliding_window(2, splits), labels)
    ]
    
    lookup_df = sc.parallelize(mapping).toDF(["lower", "upper", "label"])
    
    df.join(
        broadcast(lookup_df),
        (col("weekly_sale") >= col("lower")) & (col("weekly_sale") < col("upper"))
    ).drop("lower").drop("upper")
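
    As a quick sanity check (row order after a join is not guaranteed, so the result is sorted by id here; labeled is just an illustrative name for the join result, reusing df, lookup_df, broadcast and col from above):

    labeled = df.join(
        broadcast(lookup_df),
        (col("weekly_sale") >= col("lower")) & (col("weekly_sale") < col("upper"))
    ).drop("lower").drop("upper")
    
    labeled.orderBy("id").show(10, False)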
    

    or generate a lookup expression:

    from functools import reduce
    from pyspark.sql.functions import when
    
    def in_range(c):
        # Fold each (lower, upper, label) triple into a nested
        # when/otherwise expression over column c
        def in_range_(acc, x):
            lower, upper, label = x
            return when(
                (c >= lit(lower)) & (c < lit(upper)), lit(label)
            ).otherwise(acc)
        return in_range_
    
    label = reduce(in_range(col("weekly_sale")), mapping, lit(None))
    
    df.withColumn("label", label)
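
    Since the question also asks about Spark SQL and HiveContext, the same bucketing can be expressed as a plain CASE WHEN over a registered temporary table. A minimal sketch, assuming a HiveContext (or SQLContext) instance named sqlContext and the df defined above:

    df.registerTempTable("sales")
    
    sqlContext.sql("""
        SELECT id, weekly_sale,
            CASE
                WHEN weekly_sale < 100000 THEN 'under 100000'
                WHEN weekly_sale < 200000 THEN 'between 100000 and 200000'
                WHEN weekly_sale < 300000 THEN 'between 200000 and 300000'
                ELSE 'more than 300000'
            END AS label
        FROM sales
    """)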
    

    The least efficient approach is a UDF.
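
    For reference, a minimal sketch of that UDF variant, assuming the splits, labels and imports defined above (assign_label is just an illustrative name):

    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType
    
    def assign_label(sale):
        # Plain Python linear scan over the split points; it runs row by row
        # in the Python workers, which is why this is the slowest option
        for (lower, upper), label in zip(sliding_window(2, splits), labels):
            if lower <= sale < upper:
                return label
    
    assign_label_udf = udf(assign_label, StringType())
    
    df.withColumn("label", assign_label_udf(col("weekly_sale")))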