
Looping through an array-of-floats column in a PySpark DataFrame to find which values meet a condition


I am using PySpark 2.4.0.

I have a list of floats like this:

[0.6067762380157111,
 0.4595708424660512,
 0.20093402090021173,
 0.5736288504883545,
 0.46593043507957116,
 0.5734057882715504,
 0.6940067723754003,
 0.30921836906829625,
 0.768595041322314...]

which are the phonetic transcriptions of a list of words.

I have multiple PySpark DataFrames like this one:

+------------------------+--------------------------+---------------------------------------------+
|com                     |split                     |phoned                                       |
+------------------------+--------------------------+---------------------------------------------+
|sans option             |[sans, option]            |[0.6832268970698724, 0.6248699945979845]     |
|                        |[]                        |[]                                           |
|fermer l hiver          |[fermer, l, hiver]        |[0.3154196179245309, 0.5, 0.3828829842720629]|
+------------------------+--------------------------+---------------------------------------------+

The idea is to check, for every value in the "phoned" column of all my PySpark DataFrames, whether any element of the array is "close" to a value in the list (i.e. within a given threshold that I can change).

So, for every DataFrame, I want to "loop" through the values in the "phoned" column, loop through each array, compute the difference between each element and every element of the list, and, whenever a difference is below the threshold, store the matching value in another column. If possible, I would like to get all the values that are within the threshold.

For example, if my list of words were:

["sans",
 "opssion",
 "test",
 "ferme",...]

its phonetic transcription would be:

[0.6832268970698724,
0.625416705239052,
0.7390120210368145,
0.3154165838771569,...]

and the result I would like, for a threshold of 0.01, is:

+------------------------+--------------------------+---------------------------------------------+---------------+
|com                     |split                     |phoned                                       |result         |
+------------------------+--------------------------+---------------------------------------------+---------------+
|sans option             |[sans, option]            |[0.6832268970698724, 0.6248699945979845]     |[sans, opssion]|
|                        |[]                        |[]                                           |[]             |
|fermer l hiver          |[fermer, l, hiver]        |[0.3154196179245309, 0.5, 0.3828829842720629]|[ferme]        |
+------------------------+--------------------------+---------------------------------------------+---------------+

I've tried several UDF-based approaches, but none of them gives this kind of result. I have a dozen DataFrames, some with multiple "com" columns and some with ~1 million records, so I can't handle this in pandas.

Thanks!
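With ~1 million rows and a potentially long reference list, scanning the whole list for every array element gets expensive. One option, offered here as an assumption and not something from the question, is to sort the reference values once and use binary search (`bisect`) to jump straight to the window `[value - threshold, value + threshold]`. The names `build_index` and `close_words_sorted` are hypothetical; in Spark the same function body could run inside a UDF, with the sorted lists shipped to the executors as a broadcast variable.

```python
from bisect import bisect_left, bisect_right


def build_index(words, word_values):
    """Sort (value, word) pairs once so each lookup can use binary search."""
    pairs = sorted(zip(word_values, words))
    return [v for v, _ in pairs], [w for _, w in pairs]


def close_words_sorted(phoned, sorted_values, sorted_words, threshold):
    """Return the words whose value lies in [value - threshold, value + threshold]."""
    result = []
    for value in phoned or []:
        lo = bisect_left(sorted_values, value - threshold)
        hi = bisect_right(sorted_values, value + threshold)
        result.extend(sorted_words[lo:hi])
    return result


words = ["sans", "opssion", "test", "ferme"]
word_values = [0.6832268970698724, 0.625416705239052,
               0.7390120210368145, 0.3154165838771569]
sorted_values, sorted_words = build_index(words, word_values)

print(close_words_sorted([0.6832268970698724, 0.6248699945979845],
                         sorted_values, sorted_words, 0.01))
```

Each lookup costs O(log n) instead of O(n) in the size of the reference list. Up to floating-point rounding exactly at the boundary, it returns the same words as comparing `abs(value - ref) <= threshold` against every entry.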


Solution

  • Your DataFrame (df_1):

    +--------------+------------------+---------------------------------------------+
    |com           |split             |phoned                                       |
    +--------------+------------------+---------------------------------------------+
    |sans option   |[sans, option]    |[0.6832268970698724, 0.6248699945979845]     |
    |              |[]                |[]                                           |
    |fermer l hiver|[fermer, l, hiver]|[0.3154196179245309, 0.5, 0.3828829842720629]|
    +--------------+------------------+---------------------------------------------+
    

    Your constants:

    threshold = 0.01
    
    match = [0.6067762380157111,0.4595708424660512,
             0.20093402090021173,0.5736288504883545,
             0.46593043507957116,0.5734057882715504,
             0.6940067723754003,0.30921836906829625,0.768595041322314]
    

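    The function to use as a UDF:

    def between_threshold(element, threshold):
        # Nulls (e.g. from empty arrays) never match
        if element is None:
            return False
        # True if the element is within `threshold` of any value in `match`
        return any(abs(element - match_element) <= threshold for match_element in match)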
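Because `between_threshold` is an ordinary Python function until it is wrapped with `udf`, it can be checked locally without a Spark session. A minimal sketch, with `threshold` and `match` copied from above so the snippet runs on its own:

```python
threshold = 0.01
match = [0.6067762380157111, 0.4595708424660512,
         0.20093402090021173, 0.5736288504883545,
         0.46593043507957116, 0.5734057882715504,
         0.6940067723754003, 0.30921836906829625, 0.768595041322314]


def between_threshold(element, threshold):
    # Nulls coming from Spark arrive as None; treat them as "no match"
    if element is None:
        return False
    return any(abs(element - m) <= threshold for m in match)


for value in [0.3154196179245309, 0.5, 0.6832268970698724, None]:
    print(value, between_threshold(value, threshold))
```

Only 0.3154196179245309 passes: it is about 0.0062 away from 0.30921836906829625, while 0.6832268970698724 misses 0.6940067723754003 by slightly more than 0.01.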
    

    Import the necessary packages:

    from pyspark.sql import Window
    from pyspark.sql.functions import lit, col, udf, row_number, posexplode_outer, collect_list
    
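    # Step 1: create a unique row identifier
    # (a window without partitionBy moves all rows to a single partition,
    # which can be slow on large DataFrames)
    row_window_spec = Window.orderBy(lit(1))
    
    df_2 = df_1.withColumn("row_num", row_number().over(row_window_spec))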
    
    # Step 2: one row per word (with its position), pick the matching "phoned"
    # value, and flag whether it is within the threshold using the UDF
    df_3 = df_2.select("row_num", posexplode_outer("split").alias("index", "split"), "phoned")
    
    df_4 = df_3.withColumn("phoned", col("phoned")[col("index")])
    
    between_threshold_udf = udf(between_threshold, "boolean")
    df_5 = df_4.withColumn("between_threshold", between_threshold_udf(col("phoned"), lit(threshold)))
    
    # Step 3: keep the matching words, collect them per row, and join back
    # to the original DataFrame
    df_6 = df_5.filter(col("between_threshold")) \
            .groupBy("row_num") \
            .agg(collect_list("split").alias("result"))
    
    df_2.join(df_6, "row_num", "left").drop("row_num").show(truncate=False)
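To see what each step produces, the pipeline can be mirrored in plain Python on the three rows of `df_1`. This only illustrates the data flow and is not a Spark implementation: lists of tuples stand in for DataFrames, `enumerate` plays the role of `posexplode_outer`, and a dict stands in for the `groupBy`/`collect_list` result.

```python
threshold = 0.01
match = [0.6067762380157111, 0.4595708424660512,
         0.20093402090021173, 0.5736288504883545,
         0.46593043507957116, 0.5734057882715504,
         0.6940067723754003, 0.30921836906829625, 0.768595041322314]


def between_threshold(element, threshold):
    if element is None:
        return False
    return any(abs(element - m) <= threshold for m in match)


# Each tuple is (row_num, split, phoned); row_num mimics the row_number() column
df_2 = [
    (1, ["sans", "option"], [0.6832268970698724, 0.6248699945979845]),
    (2, [], []),
    (3, ["fermer", "l", "hiver"], [0.3154196179245309, 0.5, 0.3828829842720629]),
]

# posexplode_outer: one row per word with its index; an empty array still
# yields one row whose index, word and phoned value are all None
exploded = []
for row_num, split, phoned in df_2:
    if not split:
        exploded.append((row_num, None, None, None))
    for index, word in enumerate(split):
        exploded.append((row_num, index, word, phoned[index]))

# Filter on the threshold check, then collect the words per row_num
# (in Spark, collect_list does not guarantee element order)
grouped = {}
for row_num, index, word, value in exploded:
    if between_threshold(value, threshold):
        grouped.setdefault(row_num, []).append(word)

# Left join back: rows without any match get None (null in Spark)
for row_num, split, phoned in df_2:
    print(split, grouped.get(row_num))
```

This matches the output below: only the third row gets a non-null result.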
    

    Output

    +--------------+------------------+---------------------------------------------+--------+
    |com           |split             |phoned                                       |result  |
    +--------------+------------------+---------------------------------------------+--------+
    |sans option   |[sans, option]    |[0.6832268970698724, 0.6248699945979845]     |null    |
    |              |[]                |[]                                           |null    |
    |fermer l hiver|[fermer, l, hiver]|[0.3154196179245309, 0.5, 0.3828829842720629]|[fermer]|
    +--------------+------------------+---------------------------------------------+--------+
    

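    Note: in this dataset only fermer is within the 0.01 threshold of a value in match. The result column holds the words from split whose phoned value matched, and rows without any match get null.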