Tags: pyspark, azure-notebooks

PythonException - AssertionError raised within the Spark worker


For every row in input_table, X rows should be created in output_table, where X is the number of days in the year starting from StartDate.

The Info field should contain Y characters, where Y = X * 2; if it is shorter, it should be padded with additional # characters.

In output_table the AM and PM columns are filled with the Info characters in order, so that each AM and PM field holds exactly one character.

Here is the code:

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType, DateType, StructField, StructType, TimestampType, ArrayType
from datetime import timedelta

# Connection details for input table
url="..."
user="..."
password="..."
input_table="..."
output_table="..."

# Define schema for input table
input_schema = StructType([
    StructField("ID1", IntegerType(), True),
    StructField("ID2", IntegerType(), True),
    StructField("StartDate", TimestampType(), True),
    StructField("Info", StringType(), True),
    StructField("Extracted", TimestampType(), True)
])

# Define schema for output table
output_schema = StructType([
    StructField("ID1", IntegerType(), True),
    StructField("ID2", IntegerType(), True),
    StructField("Date", DateType(), True),
    StructField("AM", StringType(), True),
    StructField("PM", StringType(), True),
    StructField("CurrentYear", StringType(), True)
])

# Initialize SparkSession
spark = SparkSession.builder.getOrCreate()

# Register UDF for padding Info
pad_info_udf = udf(lambda info, days: info.ljust(days, '#')[:days], StringType())

# Register UDF for creating rows
create_rows_udf = udf(lambda start_date, info, days: [(start_date.date() + timedelta(days=i // 2), info[i], info[i + 1]) for i in range(0, days, 2)],
                     ArrayType(StructType([
                         StructField("Date", DateType(), True),
                         StructField("AM", StringType(), True),
                         StructField("PM", StringType(), True),
                     ])))

# Define function to pad marks and create rows
def process_row(row):
    id1 = row["ID1"]
    id2 = row["ID2"]
    start_date = row["StartDate"]
    info = row["Info"]
    extracted = row["Extracted"]

    # Calculate number of days * 2
    days = (start_date.year % 4 == 0 and 366 or 365) * 2

    # Pad info
    padded_info = pad_info_udf(info, days)

    # Create rows
    rows = create_rows_udf(start_date, padded_info, days)

    # Prepare output rows
    output_rows = []
    for r in rows:
        date = r["Date"]
        am = r["AM"]
        pm = r["PM"]
        current_year = f"{start_date.year}/{start_date.year + 1}"

        output_rows.append((id1, id2, date, am, pm, current_year))

    return output_rows

# Load input table as DataFrame
df_input = spark.read \
    .format("jdbc") \
    .option("url", url) \
    .option("dbtable", input_table) \
    .option("user", user) \
    .option("password", password) \
    .schema(input_schema) \
    .load()

# Apply processing to input table
output_rows = df_input.rdd.flatMap(process_row)

# Create DataFrame from output rows
df_output = spark.createDataFrame(output_rows, output_schema)

# Write DataFrame to output table
df_output.write \
    .format("jdbc") \
    .option("url", url) \
    .option("user", user) \
    .option("password", password) \
    .option("dbtable", output_table) \
    .mode("append") \
    .save()

Similar code works in plain Python with no problems, but when translated to PySpark it throws an AssertionError. The job must not modify input_table and should only append the transformed rows from input_table to output_table.


Solution

  • The reason is that Spark UDFs are not meant to be called inside functions applied to RDDs (here, process_row passed to flatMap). Plain Python functions should be used there instead; Spark UDFs can only be invoked from Spark SQL / DataFrame expressions.

    The reason the code appeared to work on the local machine is that in local mode the executor runs in the same JVM as the driver.
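
    In practice that means replacing pad_info_udf and create_rows_udf with ordinary Python helpers and calling them from process_row, so nothing inside flatMap touches a Spark UDF. Below is a minimal sketch of that rewrite (helper names are illustrative; it reuses df_input, output_schema and spark from the question and uses calendar.isleap for the leap-year check):

from datetime import timedelta
import calendar

def pad_info(info, length):
    # Plain Python helper: pad Info with '#' up to the required length, then truncate.
    return (info or "").ljust(length, "#")[:length]

def expand_days(start_date, info, length):
    # One (Date, AM, PM) tuple per day: characters 2*i and 2*i + 1 of Info.
    return [
        (start_date.date() + timedelta(days=i // 2), info[i], info[i + 1])
        for i in range(0, length, 2)
    ]

def process_row(row):
    start_date = row["StartDate"]
    # Two characters (AM + PM) per day of the year starting at StartDate.
    length = (366 if calendar.isleap(start_date.year) else 365) * 2
    padded = pad_info(row["Info"], length)
    current_year = f"{start_date.year}/{start_date.year + 1}"
    return [
        (row["ID1"], row["ID2"], date, am, pm, current_year)
        for date, am, pm in expand_days(start_date, padded, length)
    ]

output_rows = df_input.rdd.flatMap(process_row)
df_output = spark.createDataFrame(output_rows, output_schema)

    The rest of the job (the JDBC read and the append write) can stay exactly as in the question; only the UDF registrations become unnecessary.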