Tags: apache-spark, pyspark, filter, conditional-statements, row

Remove rows where groups of two columns have differences


Is it possible to remove rows where a value in the Block column occurs at least twice with different values in the ID column?

My data looks like this:

ID Block
1 A
1 C
1 C
3 A
3 B

In the above case, the value A in the Block column occurs twice, once with ID 1 and once with ID 3. So those rows are removed.

The expected output should be:

ID Block
1 C
1 C
3 B

I tried to use dropDuplicates after a groupBy, but I don't know how to filter on this kind of condition. It seems I would need the set of ID values for each Block value to check against.
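In other words, the rule could be sketched in plain Python (just for intuition, not a Spark solution; `drop_conflicting_blocks` is a hypothetical helper name):

```python
from collections import defaultdict

def drop_conflicting_blocks(rows):
    """Drop every row whose Block value is paired with more than one distinct ID."""
    ids_per_block = defaultdict(set)
    for id_, block in rows:
        ids_per_block[block].add(id_)          # collect distinct IDs per Block
    # keep only rows whose Block maps to exactly one ID
    return [(id_, block) for id_, block in rows if len(ids_per_block[block]) == 1]

rows = [(1, 'A'), (1, 'C'), (1, 'C'), (3, 'A'), (3, 'B')]
print(drop_conflicting_blocks(rows))  # [(1, 'C'), (1, 'C'), (3, 'B')]
```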


Solution

  • One way to do it is with window functions. The first (lag) flags a row if its ID differs from the previous row's ID within the same Block. The second (sum) then marks every row of a Block that contains a flagged row. Lastly, the marked rows and the helper (_flag) column are dropped.

    Input:

    from pyspark.sql import functions as F, Window as W
    df = spark.createDataFrame(
        [(1, 'A'),
         (1, 'C'),
         (1, 'C'),
         (3, 'A'),
         (3, 'B')],
        ['ID', 'Block'])
    

    Script:

    # w1: rows within each Block, ordered by ID (required for lag)
    w1 = W.partitionBy('Block').orderBy('ID')
    # w2: all rows of a Block, for the aggregate sum
    w2 = W.partitionBy('Block')
    # 1 if this row's ID differs from the previous row's ID in the same Block
    # (lag is null on the first row of each Block, so when() falls through to 0)
    grp = F.when(F.lag('ID').over(w1) != F.col('ID'), 1).otherwise(0)
    # keep only Blocks where no row was flagged (sum of flags is 0)
    df = df.withColumn('_flag', F.sum(grp).over(w2) == 0) \
        .filter('_flag').drop('_flag')
    
    df.show()
    # +---+-----+
    # | ID|Block|
    # +---+-----+
    # |  3|    B|
    # |  1|    C|
    # |  1|    C|
    # +---+-----+