Tags: list, apache-spark, pyspark, compare, multiple-columns

How can I compare pairs of columns in a PySpark DataFrame and count the number of records that changed?


I have a situation where I need to compare multiple pairs of columns (the number of pairs will vary and can come from a list, as shown in the code snippet below) and get a 1/0 flag for match/mismatch respectively. Eventually I want to use this to identify the number of records/rows with a mismatch and the percentage of records mismatched.

    NONKEYCOLS = ['Marks', 'Qualification']

The first image shows the source DataFrame and the second shows the expected result.

[Image: Source DF]

[Image: Expected Results]

Since this happens for multiple pairs in a loop, it is very slow for about a billion records. I need help with something more efficient.

I have the code below, but the part that calculates the changed records is taking a long time.

    for ind, cols in enumerate(NONKEYCOLS):
        print(ind)
        print(cols)
        # running sum of the per-pair change flag over window w1
        globals()['new_dataset' + '_char_changes_tmp'] = globals()['new_dataset' + '_char_changes_tmp'] \
            .withColumn("records_changed" + str(ind),
                        F.sum(col("records_ch_flag_" + str(ind))).over(w1))
        # count of the per-pair change flag over window w1
        globals()['new_dataset' + '_char_changes_tmp'] = globals()['new_dataset' + '_char_changes_tmp'] \
            .withColumn("records_changed_cnt" + str(ind),
                        F.count(col("records_ch_flag_" + str(ind))).over(w1))
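
For context, `w1` is a window specification and the `records_ch_flag_<i>` columns are the per-pair 1/0 mismatch flags built earlier in the job; they are not shown above. A rough, assumed sketch of what they might look like (the window partitioning, the `df` variable, and the `<col>`/`<col>_new` naming are placeholders, not from the original code):

    # assumed sketch - w1 and the flag columns are not shown in the original post
    from pyspark.sql import Window
    from pyspark.sql import functions as F
    from pyspark.sql.functions import col

    # hypothetical window spec covering the whole dataframe
    w1 = Window.partitionBy()

    # hypothetical per-pair 1/0 mismatch flag, assuming the two versions of each
    # column live in <col> and <col>_new, and df is the working dataframe
    for ind, cols in enumerate(NONKEYCOLS):
        df = df.withColumn("records_ch_flag_" + str(ind),
                           (col(cols) != col(cols + "_new")).cast("int"))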

Solution

  • I'm not sure what loop you are running, but here's an implementation using a list comprehension within a select.

    from pyspark.sql import functions as func

    data_ls = [
        (10, 11, 'foo', 'foo'),
        (12, 12, 'bar', 'bar'),
        (10, 12, 'foo', 'bar')
    ]

    data_sdf = spark.sparkContext.parallelize(data_ls). \
        toDF(['marks_1', 'marks_2', 'qualification_1', 'qualification_2'])
    
    col_pairs = ['marks','qualification']
    
    data_sdf. \
        select('*',
               *[(func.col(c+'_1') == func.col(c+'_2')).cast('int').alias(c+'_check') for c in col_pairs]
               ). \
        show()
    
    # +-------+-------+---------------+---------------+-----------+-------------------+
    # |marks_1|marks_2|qualification_1|qualification_2|marks_check|qualification_check|
    # +-------+-------+---------------+---------------+-----------+-------------------+
    # |     10|     11|            foo|            foo|          0|                  1|
    # |     12|     12|            bar|            bar|          1|                  1|
    # |     10|     12|            foo|            bar|          0|                  0|
    # +-------+-------+---------------+---------------+-----------+-------------------+
    

    where the list comprehension would yield the following

    [(func.col(c+'_1') == func.col(c+'_2')).cast('int').alias(c+'_check') for c in col_pairs]
    # [Column<'CAST((marks_1 = marks_2) AS INT) AS `marks_check`'>,
    #  Column<'CAST((qualification_1 = qualification_2) AS INT) AS `qualification_check`'>]
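
    if your pair columns are driven by a list like the question's NONKEYCOLS, the same comprehension can be built from that list directly - a small sketch, assuming the <name>_1 / <name>_2 naming convention used in the example above (the lower-casing is only there to line up with those column names)

    # sketch: derive the pairs from the question's NONKEYCOLS list
    # (assumes the paired columns follow the <name>_1 / <name>_2 convention above)
    NONKEYCOLS = ['Marks', 'Qualification']
    col_pairs = [c.lower() for c in NONKEYCOLS]

    check_cols = [
        (func.col(c + '_1') == func.col(c + '_2')).cast('int').alias(c + '_check')
        for c in col_pairs
    ]

    data_sdf.select('*', *check_cols).show()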
    

    EDIT

    based on the additional (updated) info, you need the count of unmatched records for each pair, and then you want to calculate the unmatched percentage.

    reversing the aforementioned logic to count the unmatched records

    col_pairs = ['marks','qualification']
    
    data_sdf. \
        agg(*[func.sum((func.col(c+'_1') != func.col(c+'_2')).cast('int')).alias(c+'_unmatch') for c in col_pairs],
            func.count('*').alias('row_cnt')
            ). \
        select('*',
               *[(func.col(c+'_unmatch') / func.col('row_cnt')).alias(c+'_unmatch_perc') for c in col_pairs]
               ). \
        show()
    
    # +-------------+---------------------+-------+------------------+--------------------------+
    # |marks_unmatch|qualification_unmatch|row_cnt|marks_unmatch_perc|qualification_unmatch_perc|
    # +-------------+---------------------+-------+------------------+--------------------------+
    # |            2|                    1|      3|0.6666666666666666|        0.3333333333333333|
    # +-------------+---------------------+-------+------------------+--------------------------+
    

    the code flags (as 1) the records where the pair does not match and takes a sum of the flag, which gives us the pair's unmatched record count. dividing that by the total row count gives the percentage.

    the list comprehension will yield the following

    [func.sum((func.col(c+'_1') != func.col(c+'_2')).cast('int')).alias(c+'_unmatch') for c in col_pairs]
    # [Column<'sum(CAST((NOT (marks_1 = marks_2)) AS INT)) AS `marks_unmatch`'>,
    #  Column<'sum(CAST((NOT (qualification_1 = qualification_2)) AS INT)) AS `qualification_unmatch`'>]
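
    once the aggregation is done, the single result row can be pulled back to python for reporting - a small usage sketch (the summary variable name is just for illustration)

    # sketch: collect the one-row aggregate into a plain python dict
    summary = data_sdf. \
        agg(*[func.sum((func.col(c+'_1') != func.col(c+'_2')).cast('int')).alias(c+'_unmatch') for c in col_pairs],
            func.count('*').alias('row_cnt')
            ). \
        collect()[0]. \
        asDict()

    # e.g. {'marks_unmatch': 2, 'qualification_unmatch': 1, 'row_cnt': 3}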
    

    this is much more efficient, as all of it happens in a single select statement, which projects only once in the spark plan - as opposed to your approach, which projects every time you call withColumn, and that is inefficient for spark.
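
    if you want to see this for yourself, explain(True) prints the parsed/analyzed/optimized/physical plans for both approaches - a sketch for comparison (note that spark's optimizer can collapse adjacent projections later, but every withColumn call still adds a projection to the analyzed plan plus extra analysis overhead)

    # sketch: compare the plans of the two approaches (reuses data_sdf and col_pairs from above)
    # chained withColumn - one Project per call in the analyzed logical plan
    data_sdf. \
        withColumn('marks_check', (func.col('marks_1') == func.col('marks_2')).cast('int')). \
        withColumn('qualification_check', (func.col('qualification_1') == func.col('qualification_2')).cast('int')). \
        explain(True)

    # single select - one Project covering all the check columns
    data_sdf. \
        select('*',
               *[(func.col(c+'_1') == func.col(c+'_2')).cast('int').alias(c+'_check') for c in col_pairs]
               ). \
        explain(True)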