Search code examples
Tags: python, pyspark, user-defined-functions

PySpark DataFrame inside a UDF


I have two PySpark DataFrames: qnotes_df (2 columns) and part_numbers_df (1 column). In qnotes_df I have a column named 'LONG_TEXT'. I want to analyze this column and extract any part numbers that may appear in the text. These part numbers should be matched against part_numbers_df. I have already done the tokenization and the rest of the preprocessing, but comparing every word with part_numbers_df is not possible because you can't access a PySpark DataFrame inside a UDF. Any suggestions on how to do this?

Here is my code:

# Define a UDF to extract part numbers from a text.
# NOTE: this is the version that does NOT work. The list comprehension below
# references part_numbers_df (a PySpark DataFrame) from inside the UDF body,
# which is not allowed, as the question states. Even if it were, it would run
# one Spark filter + count() job per token.
def extract_part_numbers_udf(text):
    # Tokenize the text and filter part numbers.
    # NOTE(review): `nlp` is defined elsewhere; presumably a spaCy pipeline,
    # given the token.is_punct / token.is_space attributes used below. Confirm.
    tokens = nlp(text)
    # Keep non-punctuation, non-whitespace tokens whose string form appears in
    # part_numbers_df.PART_NUMBER. This DataFrame lookup inside the UDF is the
    # part that fails.
    matches = [str(token) for token in tokens if not token.is_punct and not token.is_space and part_numbers_df.filter(col("PART_NUMBER") == str(token)).count() > 0]
    return matches

# Wrap the Python function as a Spark UDF whose result is array<string>.
udf_extract_part_numbers = udf(
    extract_part_numbers_udf,
    returnType=ArrayType(StringType()),
)

# Add the extracted part numbers as a new "REPLACEMENTS" column.
long_text_col = qnotes_df["LONG_TEXT"]
qnotes_df = qnotes_df.withColumn("REPLACEMENTS", udf_extract_part_numbers(long_text_col))

# Display the result without truncating long cell values.
qnotes_df.show(truncate=False)

Solution

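  • Try this instead. Collect the part numbers to the driver once and broadcast them as a plain Python set. The UDF then only does a set-membership test per token and never touches a DataFrame: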

    from pyspark.sql.functions import udf, broadcast
    from pyspark.sql.types import ArrayType, StringType
    from pyspark.ml.feature import Tokenizer
    import re
    
    # Pull the single-column part-number table to the driver once and turn
    # it into a set, so each membership check inside the UDF is a hash lookup.
    part_numbers_set = {row[0] for row in part_numbers_df.collect()}

    # Ship one read-only copy of the set to the executors; the UDF reads it via
    # `.value` instead of referencing part_numbers_df.
    broadcast_part_numbers = spark.sparkContext.broadcast(part_numbers_set)
    
    def extract_part_numbers_udf(text, part_numbers=None):
        """Return the tokens of ``text`` that are known part numbers.

        The text is split on runs of non-word characters, and every non-empty
        token present in the lookup set is returned in order of appearance
        (repeated matches are kept).

        Args:
            text: The LONG_TEXT value of one row. A null cell arrives as None.
            part_numbers: Optional set of part numbers to match against.
                Defaults to ``broadcast_part_numbers.value`` so the UDF works
                on executors without referencing a DataFrame.

        Returns:
            A list of matching tokens. It is empty when ``text`` is None or
            empty, or when nothing matches.
        """
        if not text:
            # re.split raises TypeError on None (failing the Spark task) and
            # returns [''] for "", so short-circuit both cases.
            return []
        if part_numbers is None:
            part_numbers = broadcast_part_numbers.value
        # NOTE(review): \W+ also splits on '-', '.', '/', so a part number that
        # contains such separators (e.g. "AB-123") can never match. Adjust the
        # pattern if the part-number catalogue uses them.
        tokens = re.split(r'\W+', text)  # simple tokenization using regex
        # Skip the empty strings re.split yields at leading/trailing separators
        # so they cannot match an empty entry in the set.
        return [token for token in tokens if token and token in part_numbers]
    
    # Register the Python function as a Spark UDF producing array<string>.
    udf_extract_part_numbers = udf(
        extract_part_numbers_udf,
        returnType=ArrayType(StringType()),
    )

    # Add the matches as a new "REPLACEMENTS" column; the UDF runs per row.
    long_text = qnotes_df["LONG_TEXT"]
    qnotes_df = qnotes_df.withColumn("REPLACEMENTS", udf_extract_part_numbers(long_text))

    # Print the rows without truncating the long text or the match arrays.
    qnotes_df.show(truncate=False)