
Find the closest value of each value in a column compared to another column in the same PySpark dataframe with exclusion criteria


I have two PySpark dataframes, loaded as shown below:

dfamr = spark.read.csv("/C:/Sushant Workspace/Tech/Self/pyspark datasets/amrrates.csv", header="true")
dfexclusion = spark.read.csv("/C:/Sushant Workspace/Tech/Self/pyspark datasets/exclusionrates.csv", header="true")
dfamr.show()

[image: output of dfamr.show()]

dfexclusion.show()

[image: output of dfexclusion.show()]

The task is to find, for each row, the "ratecodeid" whose rate is closest to "offer1" (saved as "offer1CodeId") and likewise for "offer2" (saved as "offer2CodeId"). There is one restriction: offer1CodeId and offer2CodeId must never be R4 or R5, the codes listed in the exclusion rates, even when they are the closest. In that case I should select the next closest rate from the ratecodeid column. The output table should look like the one below.

[image: expected output table]

- Row 5 (ratecodeid=R5): offer1CodeId would be R4, because 5.4 is closest to 5.5, but it is R3 instead because R4 is in the exclusion rates and cannot be selected.
- Row 7 (ratecodeid=R7): offer2CodeId would be R4, because 5.6 is closest to 5.5, but it is R3 instead because R4 is in the exclusion rates and cannot be selected.
- Row 6 (ratecodeid=R6): offer1CodeId would be R5, because 5.85 is closest to 6, but it is R6 instead because R5 and the next closest, R4, are both in the exclusion rates and cannot be selected.
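
The arithmetic behind the row 6 case can be checked by hand. Here is a minimal plain-Python sketch that ranks the nearby rates by their distance from offer1 = 5.85; the rate values are an assumption taken from the sample dataframe in the answer, since the screenshots are not available.

```python
# Rates near the offer, assumed from the sample data used in the answer.
rates = {"R3": 5.0, "R4": 5.5, "R5": 6.0, "R6": 6.5, "R7": 7.0}
excluded = {"R4", "R5"}
offer1 = 5.85  # offer1 on the R6 row

# Rank every rate code by its absolute distance from the offer.
ranked = sorted(rates, key=lambda code: abs(rates[code] - offer1))
print(ranked)  # ['R5', 'R4', 'R6', 'R3', 'R7']

# The first code in the ranking that is not excluded is the one to keep.
chosen = next(code for code in ranked if code not in excluded)
print(chosen)  # R6
```

R5 and R4 rank first and second but are skipped, which leaves R6 as the closest permitted code.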


Solution

  • Input dataframes:

    from pyspark.sql import functions as F, Window as W
    dfamr = spark.createDataFrame(
        [( 'R1', 4.0, 3.60, 3.2),
         ( 'R2', 4.5, 4.05, 3.6),
         ( 'R3', 5.0, 4.50, 4.0),
         ( 'R4', 5.5, 4.95, 4.4),
         ( 'R5', 6.0, 5.40, 4.8),
         ( 'R6', 6.5, 5.85, 5.2),
         ( 'R7', 7.0, 6.30, 5.6),
         ( 'R8', 7.5, 6.75, 6.0),
         ( 'R9', 8.0, 7.20, 6.4),
         ('R10', 8.5, 7.65, 6.8)],
        ['ratecodeid', 'weeklyrate', 'offer1', 'offer2'])
    
    dfexclusion = spark.createDataFrame([('R4',), ('R5',)], ['exclusionrate'])
    

  • Script:

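    # Collect the excluded rate codes to the driver as a plain Python list.
    exclusions = [r[0] for r in dfexclusion.select('exclusionrate').collect()]
    
    def closest(col):
        # Every (weeklyrate, ratecodeid) pair in the dataframe, as one array per row.
        rates = F.collect_list(F.struct('weeklyrate', 'ratecodeid')).over(W.orderBy())
        return F.array_sort(F.transform(
            # Drop the excluded codes before measuring distances.
            F.filter(rates, lambda x: ~x.ratecodeid.isin(exclusions)),
            # 'diff' goes first so that array_sort orders the structs by it.
            lambda x: F.struct(
                F.abs(F.col(col) - x['weeklyrate']).alias('diff'),
                x['weeklyrate'].alias('weeklyrate'),
                x['ratecodeid'].alias('ratecodeid'),
            )
        ))[0]['ratecodeid'].alias(f'{col}Ratecode')
    
    df = dfamr.select('*', closest('offer1'), closest('offer2'))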
    
    df.show()
    # +----------+----------+------+------+--------------+--------------+
    # |ratecodeid|weeklyrate|offer1|offer2|offer1Ratecode|offer2Ratecode|
    # +----------+----------+------+------+--------------+--------------+
    # |        R1|       4.0|   3.6|   3.2|            R1|            R1|
    # |        R2|       4.5|  4.05|   3.6|            R1|            R1|
    # |        R3|       5.0|   4.5|   4.0|            R2|            R1|
    # |        R4|       5.5|  4.95|   4.4|            R3|            R2|
    # |        R5|       6.0|   5.4|   4.8|            R3|            R3|
    # |        R6|       6.5|  5.85|   5.2|            R6|            R3|
    # |        R7|       7.0|   6.3|   5.6|            R6|            R3|
    # |        R8|       7.5|  6.75|   6.0|            R6|            R6|
    # |        R9|       8.0|   7.2|   6.4|            R7|            R6|
    # |       R10|       8.5|  7.65|   6.8|            R8|            R7|
    # +----------+----------+------+------+--------------+--------------+
    

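    collect_list, applied over a window with no partitioning, gives every row an array of all (weeklyrate, ratecodeid) pairs.
    filter removes the pairs whose ratecodeid appears in dfexclusion.
    transform rebuilds each pair as a struct whose first field, diff, is the absolute difference from the offer column.
    array_sort sorts the structs field by field, so the one with the smallest diff comes first (ties fall to the lower weeklyrate); [0]['ratecodeid'] then takes its code.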
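
The four steps above can also be traced without Spark. The following is a plain-Python sketch of the same logic, not the PySpark code itself: `closest_rate` is a hypothetical helper name, and Python tuples stand in for Spark structs because both compare element by element.

```python
# Sample data copied from the input dataframes above:
# (ratecodeid, weeklyrate, offer1, offer2)
rows = [
    ("R1", 4.0, 3.60, 3.2), ("R2", 4.5, 4.05, 3.6),
    ("R3", 5.0, 4.50, 4.0), ("R4", 5.5, 4.95, 4.4),
    ("R5", 6.0, 5.40, 4.8), ("R6", 6.5, 5.85, 5.2),
    ("R7", 7.0, 6.30, 5.6), ("R8", 7.5, 6.75, 6.0),
    ("R9", 8.0, 7.20, 6.4), ("R10", 8.5, 7.65, 6.8),
]
exclusions = {"R4", "R5"}


def closest_rate(offer, rows, exclusions):
    """Hypothetical helper: code of the non-excluded rate nearest to offer."""
    # collect_list: gather a (weeklyrate, ratecodeid) pair for every row.
    rates = [(weekly, code) for code, weekly, _, _ in rows]
    # filter: drop the pairs whose code is excluded.
    allowed = [(weekly, code) for weekly, code in rates if code not in exclusions]
    # transform: put the absolute difference first in each tuple.
    scored = [(abs(offer - weekly), weekly, code) for weekly, code in allowed]
    # array_sort: the smallest difference sorts first; ties go to the lower rate.
    return sorted(scored)[0][2]


offer1_codes = [closest_rate(o1, rows, exclusions) for _, _, o1, _ in rows]
offer2_codes = [closest_rate(o2, rows, exclusions) for _, _, _, o2 in rows]

for (code, _, _, _), c1, c2 in zip(rows, offer1_codes, offer2_codes):
    print(code, c1, c2)
```

The printed codes should line up with the offer1Ratecode and offer2Ratecode columns in the table above. The R8 row shows the tie-break: 6.75 is equally far from 6.5 and 7.0, and the lower rate, R6, sorts first.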