Tags: python, mysql, apache-spark, pyspark

Vanishing data in PySpark: How to get it to stop vanishing?


I'm having a problem with my PySpark script. My task is basically to:

  1. Import data into PySpark from a MySQL database.
  2. Do some transformations.
  3. Write the transformed data back to the MySQL database.

I can't show the full code, but here is an outline of what it looks like:

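import numpy as np
from pyspark.sql.functions import spark_partition_id
from sshtunnel import SSHTunnelForwarder

# load the SparkSession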
configs = get_configs_for_spark()
spark = get_spark_session(configs)

# open SSH tunnel with VM 
with SSHTunnelForwarder(**get_ssh_tunnel_args(configs)) as tunnel: 
    # grab unprocessed data
    df = get_raw_data(configs, tunnel, spark)

    # transform the data 
    df = transform_my_data_1(configs, df)

    # Get the number of rows in the dataframe
    num_rows = df.count()

    # Determine the number of subsets (each with 100 rows)
    subset_size = 100
    num_partitions = int(np.ceil(num_rows / subset_size))

    # Split the DataFrame into num_partitions partitions
    df = df.repartition(num_partitions)  # Specify the desired number of partitions

    # Add a partition_id column to your DataFrame
    df = df.withColumn("partition_id", spark_partition_id())

    # Group by partition_id and count the occurrences
    partition_counts = df.groupBy("partition_id").count()

    # create the Azure client to store data to blob storage
    container_client = create_azure_client(configs)  


    # Iterate over the subsets
    for i in range(num_partitions):
        subset_df = df.filter(df.partition_id == i)

        # Run transform_my_data_2 on the subset dataframe
        subset_df = transform_my_data_2(configs, container_client, subset_df)

        # Write the subset dataframe to MySQL DB
        write_data_to_mysql_db(configs, tunnel, subset_df.drop("partition_id"))

        # Print progress: expected vs. actual size of this subset
        expected_df_count = partition_counts.filter(partition_counts.partition_id == i).select("count").first()[0]
        print(f"{i+1} / {num_partitions} partitions processed. || subset df Size: Expected({expected_df_count}) Actual({subset_df.count()}) ")

    
    # end spark session 
    spark.stop()

My function for writing the data to the MySQL DB is:

def write_data_to_mysql_db(configs, tunnel, df):
    database = configs["MySQL"]["database"]
    username = configs["MySQL"]["username"]
    password = configs["MySQL"]["password"]
    driver = configs["MySQL"]["driver"]
    table = configs["MySQL"]["table"]

    # Define the JDBC URL for the MySQL database
    url = f"jdbc:mysql://localhost:{tunnel.local_bind_port}/{database}"
    # Define the database properties for authentication
    properties = {
        "user": username,
        "password": password,
        "driver": driver
    }
    df.write.jdbc(url=url, table=table, mode="append", properties=properties)

My function for reading the data from the MySQL DB is:


def get_raw_data(configs, tunnel, spark_session):
    """
    gets raw data
    """
    database = configs["MySQL"]["database"]
    username = configs["MySQL"]["username"]
    password = configs["MySQL"]["password"]

    # Define the MySQL JDBC URL
    url = f'jdbc:mysql://localhost:{tunnel.local_bind_port}/{database}'
   
    # Create a dataframe by reading the data from the MySQL database
    df = spark_session.read.jdbc(url=url, table=transription_table, 
                         properties={"user": username, 
                                     "password": password,
                                     "characterSetResults": "utf8mb4"})


    # Execute custom SQL query to get data
    my_query = f"""QUERY REDACTED"""

    df = spark_session.read.jdbc(url=url, table=f"({my_query}) as temp", properties={"user": username, "password": password})


    return df

I checked the data after running

# grab unprocessed data
df = get_raw_data(configs, tunnel, spark)

# transform the data 
df = transform_my_data_1(configs, df)

and the data looked exactly as expected. However, when I run the full script, many of the rows that should be written back vanish. Initially, the DataFrame was split into several partitions of approximately equal size.

On the small dataset I tried, the initial partition counts were:

+------------+-----+
|partition_id|count|
+------------+-----+
|           0|   90|
|           1|   90|
|           2|   90|
|           3|   90|
|           4|   90|
|           5|   89|
|           6|   89|
|           7|   90|
|           8|   90|
+------------+-----+

But after a few write operations, the counts suddenly changed, implying that there are fewer rows in my DataFrame and that data is being lost:

+------------+-----+
|partition_id|count|
+------------+-----+
|           0|   87|
|           1|   87|
|           2|   87|
|           3|   87|
|           4|   87|
|           5|   86|
|           6|   86|
|           7|   87|
|           8|   87|
+------------+-----+

When I checked the MySQL database after the script finished, some of the data was indeed missing.

The fact that data vanishes is very strange. It shouldn't happen, because DataFrames are immutable in Spark.

To narrow down the problem, I've tried several debugging strategies.

Debugging 1: Trying code on fake data

One test that seemed very informative: instead of using the data from the MySQL database, I generated a fake dataset and ran my code on that. This time, the number of rows in each partition stayed consistently at:

+------------+-----+
|partition_id|count|
+------------+-----+
|           0|  100|
|           1|  100|
|           2|  100|
|           3|  100|
|           4|  100|
|           5|  100|
|           6|  100|
|           7|  100|
|           8|  100|
|           9|  100|
+------------+-----+

In other words, no rows were lost, and the code ran exactly as I expected.

But obviously using fake data isn’t what I want. I need my code to work on the actual data.

Debugging 2: Using df.persist()

I also tried using df.persist(). This seems to resolve the problem, but it makes the computation too slow to be feasible:

# Iterate over the subsets
for i in range(num_partitions):
    if not df.is_cached:  # is_cached is an attribute in PySpark, not a method
        df.persist() 

    subset_df = df.filter(df.partition_id == i)

    # Run transform_my_data_2 on the subset dataframe
    subset_df = transform_my_data_2(configs, container_client, subset_df)

    # Write the subset dataframe to MySQL DB
    write_data_to_mysql_db(configs, tunnel, subset_df.drop("partition_id"))

    subset_df.unpersist()

    # print the size of the data frame and expected size 
    # Print progress when partitions have been executed
    expected_df_count = partition_counts.filter(partition_counts.partition_id == i).select("count").first()[0]
    print(f"{i+1} / {num_partitions} partitions processed. || subset df Size: Expected({expected_df_count}) Actual({subset_df.count()}) ")
df.unpersist()

This seemed to solve the issue, insofar as the code ran and no data was lost, but it runs too slowly to scale.

Questions

I don't understand what this tells me about what's causing my problem.

  1. Why does my script work when I use data generated by the script instead of data imported from the MySQL database?
  2. Why does using df.persist() solve this vanishing-data problem, yet make the script too slow to run on anything but tiny DataFrames (e.g. 1000 rows)?
  3. How can I run my code on the real data without having these rows vanish?

Solution

  • To clarify: every time you run an action (like count()), Spark recomputes the whole DataFrame, including reading from the data source. So if you do read -> transformation_1 -> count (ACTION) -> transformation_2 -> count (ACTION) -> save (ACTION), the lineage is executed three times:

    • read -> transformation_1 -> count
    • read -> transformation_1 -> transformation_2 -> count
    • read -> transformation_1 -> transformation_2 -> save

    Adding persist() or cache() keeps a copy of the DataFrame so it isn't recomputed; the copy is stored the first time an action computes it. So if you do read -> transformation_1 -> persist_1 -> count (ACTION) -> transformation_2 -> persist_2 -> count (ACTION) -> save (ACTION), it is executed as:

    • read -> transformation_1 -> persist_1 -> count
    • persist_1 -> transformation_2 -> persist_2 -> count
    • persist_2 -> save
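The same toy idea with caching, as a plain-Python sketch rather than Spark's actual implementation (the LazyFrame class is hypothetical): persist() only marks the frame, the first action fills the cache, and later actions reuse it, so the source is read once for all three actions.

```python
# Toy model (plain Python, not Spark) of persist(): the first action computes
# the result and keeps it; later actions reuse it instead of re-reading.
class LazyFrame:
    def __init__(self, compute):
        self._compute = compute   # the "lineage": how to build the rows
        self._cached = None
        self._persisted = False

    def persist(self):
        # Like Spark, this only marks the frame; nothing is computed yet.
        self._persisted = True
        return self

    def rows(self):
        if self._persisted:
            if self._cached is None:
                self._cached = self._compute()   # filled by the first action
            return self._cached
        return self._compute()                   # not persisted: recompute

    def map(self, fn):
        return LazyFrame(lambda: fn(self.rows()))

    def count(self):
        return len(self.rows())

reads = 0

def read_source():
    """Stand-in for the JDBC read; counts how often the source is hit."""
    global reads
    reads += 1
    return list(range(10))

persist_1 = LazyFrame(read_source).map(lambda rs: [r * 2 for r in rs]).persist()
persist_2 = persist_1.map(lambda rs: [r for r in rs if r % 4 == 0]).persist()

persist_1.count()               # action 1: reads the source once and caches
persist_2.count()               # action 2: built from persist_1's cache
saved = list(persist_2.rows())  # action 3: served from persist_2's cache
print("source reads:", reads)   # source reads: 1
```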

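    I don't have the full context of what these transformations are, but my guess is that running multiple actions without persist/cache causes the DataFrame to be recomputed (and re-read from MySQL) each time, which can put rows into different partitions on each run, or even return a different set of rows if the source changed in between. Fake data generated by the script is the same on every recomputation, so it likely ends up in the same partitions each time.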
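One concrete way the guessed mechanism can lose rows, sketched in plain Python under a loudly stated assumption: the redacted query is imagined to return only rows not yet present in the target table (read_unprocessed is hypothetical), and partition_id is assigned round-robin by position, a simplification of repartition(n) without columns. Because the question's loop re-reads and re-partitions on every iteration, rows shift between partitions after each write; some rows are never picked up, while computing the partitions once covers every row exactly once.

```python
# Toy model (plain Python, not Spark). Hypothetical assumption: the source
# query returns only rows that are not yet in the target table, so every
# re-read after a write returns fewer rows.

def read_unprocessed(all_rows, target):
    """Stand-in for the redacted query: rows not yet written, in sorted order."""
    return sorted(set(all_rows) - set(target))

def round_robin(rows, num_partitions):
    """Tag each row with partition_id = position % num_partitions."""
    return [(pos % num_partitions, row) for pos, row in enumerate(rows)]

def loop_recomputing(all_rows, num_partitions):
    """Like the question's loop: each iteration re-reads and re-partitions."""
    target = []
    for i in range(num_partitions):
        tagged = round_robin(read_unprocessed(all_rows, target), num_partitions)
        target.extend(row for pid, row in tagged if pid == i)
    return target

def loop_read_once(all_rows, num_partitions):
    """Same loop, but the partitioned rows are computed once and reused."""
    target = []
    tagged = round_robin(read_unprocessed(all_rows, target), num_partitions)
    for i in range(num_partitions):
        target.extend(row for pid, row in tagged if pid == i)
    return target

all_rows = list(range(12))
recomputed = loop_recomputing(all_rows, num_partitions=3)
print(sorted(set(all_rows) - set(recomputed)))  # [1, 4, 8, 10] never written
print(sorted(loop_read_once(all_rows, num_partitions=3)) == all_rows)  # True
```

With fake data generated inside the script, read_unprocessed would be replaced by a source that never changes, and both loops would write the same rows.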

    A few notes on fixing the persist() solution:

    • You don't need to add a partition_id column and filter on it: each filter does a full scan of the DataFrame, so just apply transformation_2 directly to the whole DataFrame.
    • Calling persist()/unpersist() inside the for loop defeats the purpose; persist once, before the loop (or before the single transformation), so the data is queried only once.
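A plain-Python toy (not Spark; both functions are hypothetical) of the first note's cost argument: filtering by partition_id in a loop looks at every row once per partition, whereas one transformation over the whole dataset looks at each row once. In Spark without persist(), each of those scans would also re-run the upstream read.

```python
from collections import Counter

def run_filter_loop(rows, num_partitions, stats):
    """One full scan per partition, like filter(partition_id == i) in a loop."""
    out = []
    for i in range(num_partitions):
        for pid, value in rows:        # a filter has to look at every row
            stats["visits"] += 1
            if pid == i:
                out.append(value)
    return out

def run_single_pass(rows, stats):
    """One scan, like applying the transformation to the whole DataFrame."""
    out = []
    for pid, value in rows:
        stats["visits"] += 1
        out.append(value)
    return out

num_rows, num_partitions = 1000, 10
rows = [(pos % num_partitions, pos) for pos in range(num_rows)]

loop_stats, single_stats = Counter(), Counter()
a = run_filter_loop(rows, num_partitions, loop_stats)
b = run_single_pass(rows, single_stats)
print(sorted(a) == sorted(b))                         # True: same rows either way
print(loop_stats["visits"], single_stats["visits"])   # 10000 1000
```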

    Something like this could speed up the processing:

    import numpy as np
    from sshtunnel import SSHTunnelForwarder

    # open SSH tunnel with VM
    with SSHTunnelForwarder(**get_ssh_tunnel_args(configs)) as tunnel:
        # grab unprocessed data
        df = get_raw_data(configs, tunnel, spark)

        # transform the data
        df = transform_my_data_1(configs, df)

        # Persist before count so later actions don't re-run the MySQL query
        df = df.persist()

        # Get the number of rows in the dataframe (this action fills the cache)
        num_rows = df.count()

        # Determine the number of partitions (each with ~100 rows, at least 1)
        subset_size = 100
        num_partitions = max(1, int(np.ceil(num_rows / subset_size)))

        # Repartition once and persist the result as well
        repartitioned = df.repartition(num_partitions).persist()

        # create the Azure client to store data to blob storage
        container_client = create_azure_client(configs)

        # Apply transform_my_data_2 to the whole DataFrame (no per-partition filter loop)
        result = transform_my_data_2(configs, container_client, repartitioned)

        # Write the whole DataFrame to MySQL DB
        write_data_to_mysql_db(configs, tunnel, result)

        # Release both cached DataFrames
        repartitioned.unpersist()
        df.unpersist()

        # end spark session
        spark.stop()
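The question's tables are also consistent with a source that returns fewer rows on a later re-read. Assuming (hypothetically, since the query is redacted) that the query for "unprocessed" data skips rows already written to MySQL, a recomputation after some writes would read fewer rows, and a round-robin split would spread the shortfall across all partitions. The plain-Python arithmetic below uses only the totals from the two tables; round_robin_sizes is a made-up helper. Spark's round-robin does not necessarily start at partition 0, so only the sizes, not which partitions are smaller, are comparable.

```python
# Arithmetic check (plain Python): partition sizes of a round-robin split.
def round_robin_sizes(num_rows, num_partitions):
    """The first num_rows % num_partitions partitions get one extra row."""
    base, extra = divmod(num_rows, num_partitions)
    return [base + 1 if i < extra else base for i in range(num_partitions)]

before_total = 90 * 7 + 89 * 2   # total from the question's first table
after_total = 87 * 7 + 86 * 2    # total from the question's second table
print(before_total, after_total)           # 808 781
print(round_robin_sizes(before_total, 9))  # [90, 90, 90, 90, 90, 90, 90, 89, 89]
print(round_robin_sizes(after_total, 9))   # [87, 87, 87, 87, 87, 87, 87, 86, 86]
```

If that assumption holds, persisting before the first write is what preserves the original rows, and only while the cached blocks survive: Spark recomputes evicted or lost cache blocks from the lineage, which would re-run the query.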