Tags: python, database, apache-spark, pyspark, psycopg2

How to pass psycopg2 cursor object to foreachPartition() in pyspark?


I'm getting the following error:

Traceback (most recent call last):
  File "/databricks/spark/python/pyspark/serializers.py", line 473, in dumps
    return cloudpickle.dumps(obj, pickle_protocol)
  File "/databricks/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/databricks/spark/python/pyspark/cloudpickle/cloudpickle_fast.py", line 563, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'psycopg2.extensions.cursor' object
PicklingError: Could not serialize object: TypeError: cannot pickle 'psycopg2.extensions.cursor' object

while running the script below:

import psycopg2

# brConnect, spark, myq, load_data, insert_ops and process_row are defined elsewhere


def get_connection():
    conn_props = brConnect.value
    print(conn_props)
    # extract values from the broadcast variable
    database = conn_props.get("database")
    user = conn_props.get("user")
    pwd = conn_props.get("password")
    host = conn_props.get("host")
    db_conn = psycopg2.connect(
        host=host,
        user=user,
        password=pwd,
        database=database,
        port=5432,
    )
    return db_conn


def process_partition_up(partition, db_cur):
    updated_rows = 0
    try:
        for row in partition:
            process_row(row, myq, db_cur)
    except Exception as e:
        print("Not connected")
    return updated_rows


def update_final(df, db_cur):
    df.rdd.coalesce(2).foreachPartition(lambda x: process_partition_up(x, db_cur))


def etl_process():
    for id in ['003']:
        conn = get_connection()
        for t in ['email_table']:
            query = f'''(select * from public.{t} where id= '{id}') as tab'''
            df_updated = load_data(query)
            if df_updated.count() > 0:
                q1 = insert_ops(df_updated, t)  # assume this function returns an insert query
                query_props = q1
                sc = spark.sparkContext
                brConnectQ = sc.broadcast(query_props)
                db_conn = get_connection()
                db_cur = db_conn.cursor()
                update_final(df_updated, db_cur)
        conn.commit()
        conn.close()

Explanation:

  • Here, etl_process() calls get_connection(), which returns a psycopg2 connection object. It then calls update_final(), which takes a DataFrame and a psycopg2 cursor object as arguments.
  • update_final() then calls process_partition_up() on each partition (df.rdd.coalesce(2).foreachPartition), passing it the partition's rows and the psycopg2 cursor object.
  • However, instead of receiving the cursor object inside process_partition_up(), I get the error above.

Can anyone help me resolve this error?

Thank you.


Solution

  • I think you're misunderstanding where this code actually runs.

    You are creating a database connection in your driver (etl_process) and then trying to ship that live connection, via its cursor, from the driver across the network to an executor to do the work. (The lambda you pass to foreachPartition is executed on the executors.)

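    That is what Spark is telling you with "cannot pickle 'psycopg2.extensions.cursor' object": it can't serialize your live database cursor to ship it to an executor. A cursor is bound to an open connection (a network socket plus session state) inside the driver process, and that state cannot be recreated from bytes on another machine.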
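To see the distinction without a cluster, here is a small stand-in sketch. It uses the standard library's sqlite3 cursor in place of psycopg2's (an assumption, purely for illustration) and plain pickle rather than Spark's cloudpickle. A plain dict of connection properties, like the one broadcast in brConnect, serializes fine, while a live cursor does not. The helper name is_picklable is hypothetical.

```python
import pickle
import sqlite3


def is_picklable(obj) -> bool:
    """Return True if `obj` survives pickle.dumps, the kind of step Spark needs to ship it."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False


# Plain connection properties are just strings in a dict: safe to broadcast or capture.
conn_props = {"host": "example-host", "user": "etl", "database": "mydb"}

# A live cursor wraps an open connection handle, so it cannot be serialized.
live_cursor = sqlite3.connect(":memory:").cursor()

print(is_picklable(conn_props))   # True
print(is_picklable(live_cursor))  # False
```

This is why broadcasting the connection properties works while capturing the cursor does not: only the former is plain data.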

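    You need to call get_connection() from inside process_partition_up() instead, so that the connection to the database (and any other bookkeeping you need to do) is initialized on the executor itself. This works because get_connection() only reads brConnect.value, and broadcast variables are available on the executors. Open one connection per partition rather than per row, and commit and close it inside that function: the conn.commit() in etl_process() runs on a different connection and will not commit work done on the executors.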
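A minimal sketch of that pattern, assuming any DB-API 2.0 driver. Instead of a cursor, the function receives a zero-argument connect callable (in the question's code that would be get_connection) and the SQL text (for example brConnectQ.value), and it opens, commits, and closes its own connection. The name process_partition and the one-statement-per-row logic are hypothetical stand-ins for process_partition_up() and process_row(), and it assumes the query has one placeholder per column. sqlite3 is used only so the sketch runs without a Postgres server (psycopg2 uses %s placeholders instead of ?).

```python
import os
import sqlite3
import tempfile
from typing import Any, Callable, Iterable


def process_partition(rows: Iterable[Any], connect: Callable[[], Any], sql: str) -> int:
    """Execute `sql` once per row over a connection opened inside this call.

    Run as a foreachPartition callback, this body executes on an executor,
    so the connection is created there and never has to be pickled.
    """
    conn = connect()  # one connection per partition, opened where the rows are
    try:
        cur = conn.cursor()
        count = 0
        for row in rows:
            cur.execute(sql, tuple(row))  # pyspark Row is a tuple subclass
            count += 1
        conn.commit()  # commit on the same connection that did the work
        return count
    except Exception:
        conn.rollback()  # discard this partition's writes if any row fails
        raise
    finally:
        conn.close()


# Hypothetical wiring inside etl_process(), where brConnectQ is in scope
# (needs pyspark and psycopg2; not run here):
# df_updated.rdd.coalesce(2).foreachPartition(
#     lambda rows: process_partition(rows, get_connection, brConnectQ.value))

if __name__ == "__main__":
    # Local demo: sqlite3 stands in for Postgres; a file is used so that every
    # connect() call sees the same database.
    path = os.path.join(tempfile.mkdtemp(), "demo.db")
    setup = sqlite3.connect(path)
    setup.execute("CREATE TABLE email_table (id TEXT, email TEXT)")
    setup.commit()
    setup.close()

    n = process_partition(
        [("003", "a@example.com"), ("003", "b@example.com")],
        lambda: sqlite3.connect(path),
        "INSERT INTO email_table VALUES (?, ?)",
    )
    print(n)  # 2
```

Only the function, a module-level helper, and the broadcast handle are captured by the lambda, and all of those are picklable; the connection itself comes into existence on the executor.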

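    FYI: one more thing to call out is that the code inside foreachPartition never runs in the same Python process as your driver code, not even on your local machine. PySpark pickles the function on the driver and runs it in separate Python worker processes, so a live cursor or connection cannot be captured this way in local mode either.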