Tags: python, pyspark, databricks, azure-databricks

Efficient Merge Code in PySpark / Databricks


I have a library built out for handling MERGE statements on Databricks Delta tables. The code for these statements is pretty straightforward, and for almost every table it resembles the following:

def execute_call_data_pipeline(self, df_mapped_data: DataFrame, call_data_type: str = 'columns:mapped'):
        dt_call_data = get_delta_table(self.Spark, self.Catalog, self.Schema, 'call_data')
        dt_call_data.alias('old').merge(
                source=df_mapped_data.alias('new'),
                condition=expr('old.callDataId = new.callDataId')
        ).whenMatchedUpdate(set=
            {
                'callId': col('new.callId') if 'callId' in df_mapped_data.columns else col('old.callId'),
                'callDataMapId': col('new.callDataMapId') if 'callDataMapId' in df_mapped_data.columns else col('old.callDataMapId'),
                'callDataType': col('new.callDataType') if 'callDataType' in df_mapped_data.columns else col('old.callDataType'),
                'legacyColumn': col('new.legacyColumn') if 'legacyColumn' in df_mapped_data.columns else col('old.legacyColumn'),
                'dataValue': col('new.dataValue') if 'dataValue' in df_mapped_data.columns else col('old.dataValue'),
                'isEncrypted': col('new.isEncrypted') if 'isEncrypted' in df_mapped_data.columns else col('old.isEncrypted'),
                'silverUpdateOn': to_date(lit(datetime.now(timezone.utc)), 'yyyy-MM-dd HH:mm:ss.S')
            }
        ).whenNotMatchedInsert(values=
            {
                'callId': col('new.callId'),
                'callDataMapId': col('new.callDataMapId') if 'callDataMapId' in df_mapped_data.columns else lit(None),
                'callDataType': col('new.callDataType') if 'callDataType' in df_mapped_data.columns else lit(call_data_type),
                'legacyColumn': col('new.legacyColumn') if 'legacyColumn' in df_mapped_data.columns else lit(None),
                'dataValue': col('new.dataValue'),
                'isEncrypted': col('new.isEncrypted') if 'isEncrypted' in df_mapped_data.columns else lit(False),
                'silverCreateOn': to_date(lit(datetime.now(timezone.utc)), 'yyyy-MM-dd HH:mm:ss.S')
            }
        ).execute()

The code executes fine but is rather tedious to write out, as I have to spell out every column on every table across three different catalogs (medallion lakehouse architecture). I was looking to make this a bit more efficient to write, so I abstracted out the dictionary creation (since the rules are almost always the same for every column) and developed the following:

def get_column_value(df: DataFrame, column_name: str, table_alias: str = 'new', default_value=None) -> Column:
    default_value = default_value if default_value is Column else lit(default_value)
    return col(f'{table_alias}.{column_name}') if column_name in df.columns else default_value


def build_update_values(df: DataFrame, update_on_field, update_columns: list, new_alias: str = 'new', old_alias: str = 'old') -> dict:
    update_values = dict(map(lambda x: (x, get_column_value(df, x, new_alias, col(f'{old_alias}.{x}'))), update_columns))
    update_values.update({update_on_field: to_date(lit(datetime.now(timezone.utc)), 'yyyy-MM-dd HH:mm:ss.S')})
    return update_values


def build_insert_values(df: DataFrame, create_on_field: str, update_columns: list, field_defaults: dict = None, table_alias: str = 'new', identity_col: str = None) -> dict:
    column_list = update_columns + [identity_col] if identity_col is not None else update_columns
    column_defaults = dict(map(lambda c: (c, None), update_columns))
    if field_defaults is not None:
        column_defaults = column_defaults | field_defaults
    insert_values = dict(map(lambda x: (x, get_column_value(df, x, table_alias, column_defaults[x])), column_defaults.keys()))
    insert_values.update({create_on_field: to_date(lit(datetime.now(timezone.utc)), 'yyyy-MM-dd HH:mm:ss.S')})
    return insert_values


def build_update_columns(df: DataFrame, skip_columns: list) -> list:
    return list(set(df.columns) - set(skip_columns))

def execute_call_data_pipeline(self, df_mapped_data: DataFrame, call_data_type: str = 'columns:mapped'):
        dt_call_data = get_delta_table(self.Spark, self.Catalog, self.Schema, 'call_data')

        insert_defaults = dict([('callDataType', call_data_type),
                                ('isEncrypted', False)])
        update_columns = build_update_columns(df_mapped_data, self.__helper.SkipColumns + ['callDataId'])
        update_values = build_update_values(df_mapped_data, 'silverUpdateOn', update_columns)
        insert_values = build_insert_values(df_mapped_data, 'silverCreateOn', update_columns, field_defaults=insert_defaults)

        dt_call_data.alias('old').merge(
            source=df_mapped_data.alias('new'),
            condition=expr('(old.callDataId = new.callDataId) OR (old.callId = new.callId AND old.callDataType = new.callDataType AND old.legacyColumn = new.legacyColumn)')
        ).whenMatchedUpdate(set=update_values).whenNotMatchedInsert(values=insert_values).execute()

This makes the code MUCH easier to write, but I'm noticing a SIGNIFICANT performance degradation. Jobs that used to execute in a handful of minutes now take hours (or sometimes just hang). I'm admittedly not that comfortable with debugging the nitty-gritty in Spark using the Spark UI, but the only thing that really jumped out at me was some significant memory spill (~20 GB, although I wasn't seeing any out-of-memory errors).

I ran a few experiments to confirm it was the code changes that caused the issue, and that is definitely the case, but I'm confused as to why. I'd love to find a way to keep the new code, since it's much faster to write when onboarding new tables, but as it stands it's unusable given the performance. Can anyone point me towards what the problem may be, or even where I might want to look in the Spark UI to identify the issue?

I am on Azure Databricks working with Unity Catalog.
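
For reference, here is a minimal sketch of how the per-merge operation metrics that Delta records in the table history could be compared between the old and new code paths (the catalog/schema names below are placeholders for the ones used above):

from pyspark.sql.functions import col

# Pull the MERGE entries from the table history; operationMetrics includes values
# such as executionTimeMs, scanTimeMs, rewriteTimeMs, numTargetFilesAdded and numTargetFilesRemoved.
history = spark.sql('DESCRIBE HISTORY my_catalog.my_schema.call_data')
(history
 .filter(col('operation') == 'MERGE')
 .select('version', 'timestamp', 'operationMetrics')
 .show(truncate=False))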


Solution

  • Here is how you can optimize your code.

    from pyspark.sql import DataFrame, Column
    from pyspark.sql.functions import col, lit, to_date
    from datetime import datetime, timezone

    def get_column_value(df: DataFrame, column_name: str, table_alias: str = 'new', default_value=None) -> Column:
        # Take the column from the source when it exists; otherwise fall back to the
        # default, wrapping plain Python values (None, strings, booleans) in lit()
        # so the merge always receives a Column.
        if column_name in df.columns:
            return col(f'{table_alias}.{column_name}')
        return default_value if isinstance(default_value, Column) else lit(default_value)

    def build_update_values(df: DataFrame, update_on_field: str, update_columns: list, new_alias: str = 'new', old_alias: str = 'old') -> dict:
        # Matched rows: take the new value when the source has the column, otherwise keep the old value.
        update_values = {
            col_name: get_column_value(df, col_name, new_alias, col(f'{old_alias}.{col_name}'))
            for col_name in update_columns
        }
        update_values[update_on_field] = to_date(lit(datetime.now(timezone.utc)), 'yyyy-MM-dd HH:mm:ss.S')
        return update_values

    def build_insert_values(df: DataFrame, create_on_field: str, update_columns: list, field_defaults: dict = None, table_alias: str = 'new', identity_col: str = None) -> dict:
        # Build a new list so the caller's update_columns is not mutated.
        column_list = update_columns + [identity_col] if identity_col else list(update_columns)

        column_defaults = {c: None for c in column_list}
        if field_defaults:
            column_defaults.update(field_defaults)

        # Not-matched rows: take the new value when present, otherwise the configured default.
        insert_values = {
            col_name: get_column_value(df, col_name, table_alias, column_defaults[col_name])
            for col_name in column_defaults
        }
        insert_values[create_on_field] = to_date(lit(datetime.now(timezone.utc)), 'yyyy-MM-dd HH:mm:ss.S')
        return insert_values
    

    Here, in get_column_value, instead of dynamically evaluating the default, simply check whether the column exists in the source and return the required column, falling back to the default otherwise.
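
    For illustration (a hypothetical call using column names from the question), with a source DataFrame that contains callId but not legacyColumn, the helper resolves as follows:

    # 'callId' exists in the source, so the new value is used
    get_column_value(df_mapped_data, 'callId')  # -> col('new.callId')

    # 'legacyColumn' is missing from the source, so the supplied default (the old value) is kept
    get_column_value(df_mapped_data, 'legacyColumn', default_value=col('old.legacyColumn'))  # -> col('old.legacyColumn')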

    Next, for build_update_values and build_insert_values, construct the dictionaries directly with dict comprehensions and avoid map/lambda.

    In addition to this, partition (or Z-ORDER) the data on the merge key callDataId to reduce shuffling while merging, as sketched below.
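
    The sketch below shows both common options, repartitioning the source DataFrame on the merge key before the merge and Z-ordering the target table (DataFrame and column names are taken from the question; the OPTIMIZE table name is a placeholder). Since callDataId looks like a high-cardinality ID, Z-ordering is usually the safer choice over physical partitioning.

    # Option 1: repartition the source on the merge key so that rows with the
    # same callDataId end up in the same partitions before the merge runs.
    df_mapped_data = df_mapped_data.repartition('callDataId')

    # Option 2: co-locate target rows by the merge key with Z-ordering
    # (run periodically, e.g. after large loads).
    spark.sql('OPTIMIZE my_catalog.my_schema.call_data ZORDER BY (callDataId)')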