python, apache-spark, pyspark

Proper way to handle data from a generator using PySpark and write it to parquet?


I have a data generator that returns an iterable. It fetches data from a specified date range, and the total it will fetch is close to a billion objects. My goal is to fetch all of this data, write it to a folder on the local filesystem, and then (already set up and working) use a PySpark readStream to read those files and write them to my database (Cassandra). The scope of my question is limited to fetching the data and writing it to the local filesystem.

I am trying to:

  1. Fetch data using a generator.
  2. Accumulate a batch of data.
  3. When the batch reaches batch_size, create a Spark DataFrame, and
  4. Write this DataFrame out in Parquet format.

However, the issues I run into are segmentation faults (core dumped) and a Java "Connection reset" error. I am very new to PySpark and am trying to educate myself on how to properly set it up and implement the workflow I am going for. Specifically, I would appreciate help and feedback on the Spark configuration and on the primary error I keep getting consistently:

Failed to write to data/data/polygon/trades/batch_99 on attempt 1: An error occurred while calling o1955.parquet.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 8 in stage 99.0 failed 1 times, most recent failure: Lost task 8.0 in stage 99.0 (TID 3176) (furkan-desktop executor driver): java.net.SocketException: Connection reset

Here is a screenshot of the Spark UI:

[Spark UI screenshot]

Current Implementation:

from datetime import datetime
import logging
import time
from dotenv import load_dotenv
import pandas as pd
import os
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    StructType,
    StructField,
    IntegerType,
    StringType,
    LongType,
    DoubleType,
    ArrayType,
)
import uuid
from polygon import RESTClient

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    filename="spark_logs/logfile.log",
    filemode="w",
)
from_date = datetime(2021, 3, 10)
to_date = datetime(2021, 3, 31)
load_dotenv()
client = RESTClient(os.getenv("POLYGON_API_KEY"))

# Create Spark session
spark = (
    SparkSession.builder.appName("TradeDataProcessing")
    .master("local[*]")
    .config("spark.driver.memory", "16g")
    .config("spark.executor.instances", "8")
    .config("spark.executor.memory", "16g")
    .config("spark.executor.memoryOverhead", "4g")
    .config("spark.executor.cores", "4")
    .config("spark.memory.offHeap.enabled", "true")
    .config("spark.memory.offHeap.size", "4g")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryoserializer.buffer.max", "512m")
    .config("spark.network.timeout", "800s")
    .config("spark.executor.heartbeatInterval", "20000ms")
    .config("spark.dynamicAllocation.enabled", "true")
    .config("spark.dynamicAllocation.minExecutors", "1")
    .config("spark.dynamicAllocation.maxExecutors", "8")
    .config("spark.dynamicAllocation.initialExecutors", "4")
    .getOrCreate()
)

# Define the schema corresponding to the JSON structure
schema = StructType(
    [
        StructField("exchange", IntegerType(), False),
        StructField("id", StringType(), False),
        StructField("participant_timestamp", LongType(), False),
        StructField("price", DoubleType(), False),
        StructField("size", DoubleType(), False),
        StructField("conditions", ArrayType(IntegerType()), True),
    ]
)


def ensure_directory_exists(path):
    """Ensure directory exists, create if it doesn't"""
    if not os.path.exists(path):
        os.makedirs(path)


# Convert dates to timestamps or use them directly based on your API requirements
from_timestamp = from_date.timestamp() * 1e9  # Adjusting for nanoseconds
to_timestamp = to_date.timestamp() * 1e9

# Initialize the trades iterator with the specified parameters
trades_iterator = client.list_trades(
    "X:BTC-USD",
    timestamp_gte=int(from_timestamp),
    timestamp_lte=int(to_timestamp),
    limit=1_000,
    sort="asc",
    order="asc",
)

trades = []
file_index = 0
output_dir = "data/data/polygon/trades"  # Output directory
ensure_directory_exists(output_dir)  # Make sure the output directory exists


def robust_write(df, path, max_retries=3, retry_delay=5):
    """Attempts to write a DataFrame to a path with retries on failure."""
    for attempt in range(max_retries):
        try:
            df.write.partitionBy("exchange").mode("append").parquet(path)
            print(f"Successfully written to {path}")
            return
        except Exception as e:
            logging.error(f"Failed to write to {path} on attempt {attempt+1}: {e}")
            time.sleep(retry_delay)  # Wait before retrying
    logging.critical(f"Failed to write to {path} after {max_retries} attempts.")


for trade in trades_iterator:
    trade_data = {
        "exchange": int(trade.exchange),
        "id": str(uuid.uuid4()),
        "participant_timestamp": trade.participant_timestamp,
        "price": float(trade.price),
        "size": float(trade.size),
        "conditions": trade.conditions if trade.conditions else [],
    }
    trades.append(trade_data)

    if len(trades) == 10000:
        df = spark.createDataFrame(trades, schema=schema)
        file_name = f"{output_dir}/batch_{file_index}"
        robust_write(df, file_name)

        trades = []
        file_index += 1

if trades:
    df = spark.createDataFrame(trades, schema=schema)
    file_name = f"{output_dir}/batch_{file_index}"
    robust_write(df, file_name)

Solution

  • This is not a perfect solution, but since a streaming approach would be more suitable, I am providing it as an option; a sketch of how to plug the question's trade generator into it follows the example.

    Adapted from the socket examples below:

    https://github.com/abulbasar/pyspark-examples/blob/master/structured-streaming-socket.py

    https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html (search for 'socket' in this webpage)

    To figure out whether processing is finished, just check for this line in the logs (a programmatic alternative is sketched after the full example below):

    WARN TextSocketMicroBatchStream: Stream closed by localhost:9979

    Just one caveat: the number of rows per file may not be exactly num_rows_per_batch. You can set a trigger timer based on how much time it takes the iterator to generate 10000 rows; a short trigger sketch follows the full example below.

    https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.streaming.DataStreamWriter.trigger.html

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, from_json
    from pyspark.sql.types import StructType, StructField, StringType
    import json
    import threading
    import socket
    
    spark = SparkSession.builder \
        .appName("Example") \
        .getOrCreate()
    
    schema = StructType([
        StructField("column1", StringType()),
        StructField("column2", StringType()),
    ])
    
    
    def data_iterator():
        for i in range(100):
            yield {"column1": f"value1_{i}", "column2": f"value2_{i}"}
    
    
    host_given, port_given = "localhost", 9979
    
    
    def socket_server():
        # Serve the generator's rows to Spark as newline-delimited JSON over TCP.
        host = host_given
        port = port_given
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind((host, port))
            s.listen(1)
            conn, addr = s.accept()
            with conn:
                for row in data_iterator():
                    data = json.dumps(row) + "\n"
                    conn.sendall(data.encode())


    server_thread = threading.Thread(target=socket_server)
    server_thread.start()
    
    # Read the socket as a text stream and parse each JSON line with the schema.
    df = spark.readStream \
        .format("socket") \
        .option("host", host_given) \
        .option("port", port_given) \
        .load() \
        .select(from_json(col("value"), schema).alias("data")) \
        .select("data.*")
    
    output_hello = "/path/to/data_output/parquet_so/"
    checkpoint_hello = "/path/to/data_output/parquet_checkpoint/"
    
    num_rows_per_batch = 20
    
    # Swap "csv" for "parquet" below if you want Parquet files, as in the question.
    query = df.writeStream \
        .format("csv") \
        .option("path", output_hello) \
        .option("checkpointLocation", checkpoint_hello) \
        .option("maxRecordsPerFile", num_rows_per_batch) \
        .start()

    query.awaitTermination()
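
    As mentioned in the caveat above, a processing-time trigger can be attached to the writer. A minimal sketch, assuming a 30-second interval (tune it to how long your iterator needs to produce num_rows_per_batch rows) and the Parquet sink the question ultimately wants:

    # Same writer as above, plus a fixed micro-batch interval.
    query = df.writeStream \
        .format("parquet") \
        .option("path", output_hello) \
        .option("checkpointLocation", checkpoint_hello) \
        .option("maxRecordsPerFile", num_rows_per_batch) \
        .trigger(processingTime="30 seconds") \
        .start()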
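
    Instead of watching the logs for the "Stream closed" warning, you can also poll the query's status and stop it once the source reports no more available data. This is a rough sketch that would replace the blocking query.awaitTermination() call above; the 10-second poll interval and the five-poll threshold are arbitrary assumptions:

    import time

    idle_polls = 0
    while idle_polls < 5:  # assumed threshold: five consecutive idle polls
        time.sleep(10)  # assumed poll interval
        if query.status["isDataAvailable"]:  # status dict also has "message" and "isTriggerActive"
            idle_polls = 0
        else:
            idle_polls += 1
    query.stop()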
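
    Finally, to plug the question's Polygon trade generator into this approach, data_iterator can yield the trade dictionaries directly, and from_json can use the trade schema from the question instead of the two toy string columns. A rough sketch under those assumptions (trades_iterator, the trade fields, and schema come from the question's code; everything else is illustrative):

    import uuid

    def data_iterator():
        # One JSON-serializable dict per trade, mirroring the question's batching loop.
        for trade in trades_iterator:
            yield {
                "exchange": int(trade.exchange),
                "id": str(uuid.uuid4()),
                "participant_timestamp": trade.participant_timestamp,
                "price": float(trade.price),
                "size": float(trade.size),
                "conditions": trade.conditions if trade.conditions else [],
            }

    # Parse the streamed JSON lines with the question's trade schema.
    trades_stream = spark.readStream \
        .format("socket") \
        .option("host", host_given) \
        .option("port", port_given) \
        .load() \
        .select(from_json(col("value"), schema).alias("data")) \
        .select("data.*")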