
Pyspark - Reject Values based on multiple conditions


I have a dataframe with values:

+--------------+---------------+------------------+------------+-----+
|registerNumber|joint_indicator|share_class_number|no_of_shares|mn_id|
+--------------+---------------+------------------+------------+-----+
|      XXXXXXXX|              1|                 1|   000000001|    1|
|      XXXXXXXX|              1|                 1|   000000000|    2|
|      XXXXXXXX|              1|                 1|   000000001|    3|
|      XXXXXXXX|              1|                 1|   000000000|    4|
|      XXXXXXXX|              1|                 1|   000000001|    5|
|      XXXXXXXX|              1|                 1|   000000000|    6|
|      XXXXXXXX|              1|                 1|   000000001|    7|
|      XXXXXXXX|              1|                 1|   000000000|    8|
|      XXXXXXXX|               |                 1|   000000001|    9|
|      XXXXXXXX|              1|                 1|   000000000|   10|
|      XXXXXXXX|              1|                 1|   000000000|   11|
|      XXXXXXXX|              1|                 1|   000000001|   12|
|      XXXXXXXX|              1|                 1|   000000000|   13|
|      XXXXXXXX|              1|                 1|   000000001|   14|
|      XXXXXXXX|              1|                 1|   000000000|   15|
|      YYYYYYYY|              1|                 1|   000000001|   16|
|      YYYYYYYY|              1|                 1|   000000001|   17|
|      YYYYYYYY|              1|                 1|   000000000|   18|
|      YYYYYYYY|              1|                 1|   000000001|   19|
|      ZZZZZZZZ|              1|                 2|   000000001|   20|
|      ZZZZZZZZ|              1|                 2|   000000000|   21|
|      ZZZZZZZZ|              1|                 2|   000000000|   22|
|      ZZZZZZZZ|               |                 2|   000000001|   23|
+--------------+---------------+------------------+------------+-----+

Each registerNumber represents a set of records. Where joint_indicator is 1, the record is part of a joint shareholding; where it is blank, it represents an individual shareholder. The first record with joint_indicator 1 starts the joint holding and must carry the total no_of_shares held (which must be greater than zero); all subsequent records of that holding must have 0 no_of_shares.
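The joint-holding rule above can be illustrated in plain Python for a single joint holding. This is only a sketch under two assumptions: no_of_shares arrives as a zero-padded string (as in the table), and the hypothetical helper check_joint_holding receives the share values of one joint holding already in mn_id order.

```python
import re
from typing import List


def is_zero(shares: str) -> bool:
    # "000000000" -> True, "000000001" -> False (the whole string must be zeros).
    return re.fullmatch(r"0+", shares) is not None


def check_joint_holding(shares: List[str]) -> List[bool]:
    """Return one error flag per record of a single joint holding (hypothetical helper).

    The first record must carry the total (non-zero); every later record
    must have zero shares.
    """
    flags = []
    for position, value in enumerate(shares):
        if position == 0:
            flags.append(is_zero(value))      # first record: error if zero
        else:
            flags.append(not is_zero(value))  # later records: error if non-zero
    return flags


print(check_joint_holding(["000000001", "000000000"]))  # [False, False]
print(check_joint_holding(["000000001", "000000001"]))  # [False, True]
```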

Reject values under these conditions:

  1. When joint_indicator is blank (an individual shareholder), no_of_shares will be greater than 0. The next record after it, which has joint_indicator populated, starts a new joint holding, so its no_of_shares should not be 0; if it is 0, it is an error record. Example:

+--------------+---------------+------------------+------------+-----+
|registerNumber|joint_indicator|share_class_number|no_of_shares|mn_id|
+--------------+---------------+------------------+------------+-----+
|      XXXXXXXX|               |                 1|   000000001|    9|
|      XXXXXXXX|              1|                 1|   000000000|   10|
+--------------+---------------+------------------+------------+-----+

  2. In the second set of records (YYYYYYYY), the first record with joint_indicator 1 starts the joint holding and holds the total no_of_shares (which must be greater than zero), so all subsequent records must have 0 no_of_shares. Here the subsequent record is not 0, so this is also an error record. Example:

+--------------+---------------+------------------+------------+-----+
|registerNumber|joint_indicator|share_class_number|no_of_shares|mn_id|
+--------------+---------------+------------------+------------+-----+
|      YYYYYYYY|              1|                 1|   000000001|   16|
|      YYYYYYYY|              1|                 1|   000000001|   17|
+--------------+---------------+------------------+------------+-----+

Whereas this should not be flagged as an error:

+--------------+---------------+------------------+------------+-----+
|registerNumber|joint_indicator|share_class_number|no_of_shares|mn_id|
+--------------+---------------+------------------+------------+-----+
|      YYYYYYYY|              1|                 1|   000000001|   16|
|      YYYYYYYY|              1|                 1|   000000000|   17|
+--------------+---------------+------------------+------------+-----+
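Both reject conditions compare a record with the one immediately before it in the same registerNumber. A plain-Python sketch of that pairwise test, checked against the three examples above, could look like this; the Record tuple and the is_error_pair function are hypothetical names, and a blank joint_indicator is modelled as None.

```python
import re
from typing import NamedTuple, Optional


class Record(NamedTuple):
    joint_indicator: Optional[str]  # None models a blank indicator
    no_of_shares: str               # zero-padded, e.g. "000000001"
    mn_id: int


def is_zero(shares: str) -> bool:
    return re.fullmatch(r"0+", shares) is not None


def is_error_pair(prev: Record, cur: Record) -> bool:
    """Flag cur using the record just before it in the same registerNumber."""
    if cur.joint_indicator is None:
        return False  # individual shareholders are not checked by these two rules
    if prev.joint_indicator is None:
        # Condition 1: after an individual record, a joint record starts a
        # new holding, so it must not have 0 shares.
        return is_zero(cur.no_of_shares)
    # Condition 2: once a joint holding has started with a non-zero total,
    # the following joint record must have 0 shares.
    return not is_zero(prev.no_of_shares) and not is_zero(cur.no_of_shares)


print(is_error_pair(Record(None, "000000001", 9), Record("1", "000000000", 10)))  # True
print(is_error_pair(Record("1", "000000001", 16), Record("1", "000000001", 17)))  # True
print(is_error_pair(Record("1", "000000001", 16), Record("1", "000000000", 17)))  # False
```

Note that a non-zero joint record following a zero one (mn_id 11 then 12 in XXXXXXXX) is not flagged: under this reading it simply starts a new joint holding.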

I am supposed to fetch the registerNumber of the error records. Any leads to solve this would be highly appreciated. Expected output:

+--------------+---------------+------------------+------------+-----+-------+
|registerNumber|joint_indicator|share_class_number|no_of_shares|mn_id|  error|
+--------------+---------------+------------------+------------+-----+-------+
|      XXXXXXXX|              1|                 1|   000000001|    1|  false|
|      XXXXXXXX|              1|                 1|   000000000|    2|  false|
|      XXXXXXXX|              1|                 1|   000000001|    3|  false|
|      XXXXXXXX|              1|                 1|   000000000|    4|  false|
|      XXXXXXXX|              1|                 1|   000000001|    5|  false|
|      XXXXXXXX|              1|                 1|   000000000|    6|  false|
|      XXXXXXXX|              1|                 1|   000000001|    7|  false|
|      XXXXXXXX|              1|                 1|   000000000|    8|  false|
|      XXXXXXXX|           null|                 1|   000000001|    9|  false|
|      XXXXXXXX|              1|                 1|   000000000|   10|   true|
|      XXXXXXXX|              1|                 1|   000000000|   11|  false|
|      XXXXXXXX|              1|                 1|   000000001|   12|  false|
|      XXXXXXXX|              1|                 1|   000000000|   13|  false|
|      XXXXXXXX|              1|                 1|   000000001|   14|  false|
|      XXXXXXXX|              1|                 1|   000000000|   15|  false|
|      YYYYYYYY|              1|                 1|   000000001|   16|  false|
|      YYYYYYYY|              1|                 1|   000000001|   17|   true|
|      YYYYYYYY|              1|                 1|   000000000|   18|  false|
|      YYYYYYYY|              1|                 1|   000000001|   19|   true|
|      ZZZZZZZZ|              1|                 2|   000000001|   20|  false|
|      ZZZZZZZZ|              1|                 2|   000000000|   21|  false|
|      ZZZZZZZZ|              1|                 2|   000000000|   22|  false|
|      ZZZZZZZZ|           null|                 2|   000000001|   23|  false|
+--------------+---------------+------------------+------------+-----+-------+

Solution

  • From what I understand, a record is in error if:

    • the previous joint_indicator is null (or there is no previous record in the registerNumber) and no_of_shares is 0
    • the previous joint_indicator is not null, the previous no_of_shares is not 0 and the current no_of_shares is not 0
    • the joint_indicator is not null, no_of_shares is not 0 and it is the last record of its registerNumber (we check that the next mn_id, obtained with lead over the window, is null)

    You can simply check that with a window ordered by mn_id and partitioned by registerNumber like this:

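    from pyspark.sql import Window
    from pyspark.sql import functions as F

    # One window per registerNumber, rows in mn_id order (assumes mn_id is numeric;
    # a string mn_id would sort lexicographically, "10" before "2").
    win = Window.partitionBy("registerNumber").orderBy("mn_id")

    # no_of_shares is a zero-padded string, so rlike("^0+$") means "zero shares".
    result = df\
      .withColumn("error_1",
           # previous joint_indicator is null (or no previous row) and current shares are 0
           F.lag("joint_indicator").over(win).isNull() &
           F.col("no_of_shares").rlike("^0+$")
      ).withColumn("error_2",
           # previous row is joint with non-zero shares, and current shares are non-zero too
           F.lag("joint_indicator").over(win).isNotNull() &
           ~F.col("no_of_shares").rlike("^0+$") &
           ~F.lag(F.col("no_of_shares")).over(win).rlike("^0+$")
      ).withColumn("error_3",
           # last joint row of the registerNumber still holds non-zero shares
           F.col("joint_indicator").isNotNull() &
           ~F.col("no_of_shares").rlike("^0+$") &
           F.lead("mn_id").over(win).isNull()
      )
    result.show(30)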
    
    +--------------+---------------+------------------+------------+-----+-------+-------+-------+
    |registerNumber|joint_indicator|share_class_number|no_of_shares|mn_id|error_1|error_2|error_3|
    +--------------+---------------+------------------+------------+-----+-------+-------+-------+
    |      XXXXXXXX|              1|                 1|   000000001|    1|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000000|    2|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000001|    3|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000000|    4|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000001|    5|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000000|    6|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000001|    7|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000000|    8|  false|  false|  false|
    |      XXXXXXXX|           null|                 1|   000000001|    9|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000000|   10|   true|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000000|   11|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000001|   12|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000000|   13|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000001|   14|  false|  false|  false|
    |      XXXXXXXX|              1|                 1|   000000000|   15|  false|  false|  false|
    |      YYYYYYYY|              1|                 1|   000000001|   16|  false|  false|  false|
    |      YYYYYYYY|              1|                 1|   000000001|   17|  false|   true|  false|
    |      YYYYYYYY|              1|                 1|   000000000|   18|  false|  false|  false|
    |      YYYYYYYY|              1|                 1|   000000001|   19|  false|  false|   true|
    |      ZZZZZZZZ|              1|                 2|   000000001|   20|  false|  false|  false|
    |      ZZZZZZZZ|              1|                 2|   000000000|   21|  false|  false|  false|
    |      ZZZZZZZZ|              1|                 2|   000000000|   22|  false|  false|  false|
    |      ZZZZZZZZ|           null|                 2|   000000001|   23|  false|  false|  false|
    +--------------+---------------+------------------+------------+-----+-------+-------+-------+
    

    I separated the three error cases for clarity. You can then combine them into a single error column:

    result = result\
        .withColumn("error", F.col("error_1") | F.col("error_2") | F.col("error_3"))\
        .drop("error_1", "error_2", "error_3")
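To sanity-check the three rules without a Spark cluster, the window logic can be emulated in plain Python: rows are grouped by registerNumber (partitionBy), sorted numerically by mn_id (orderBy), and lag/lead become the neighbouring list elements. This is an illustrative sketch, not the Spark code itself; the names PATTERNS, BLANK_IDS and flag_errors are hypothetical, the question's sample data is encoded compactly, and share_class_number is omitted because no rule uses it. It also collects the distinct registerNumbers of the flagged records, which is what the question ultimately asks for; in PySpark the analogous step would be something like result.filter(F.col("error")).select("registerNumber").distinct().

```python
import re
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

# (registerNumber, joint_indicator, no_of_shares, mn_id); None models a blank
# joint_indicator. share_class_number is left out because no rule uses it.
Row = Tuple[str, Optional[str], str, int]

# The question's sample data, one digit per record in mn_id order
# ("1" -> "000000001", "0" -> "000000000").
PATTERNS = [("XXXXXXXX", "101010101001010", 1),
            ("YYYYYYYY", "1101", 16),
            ("ZZZZZZZZ", "1001", 20)]
BLANK_IDS = {9, 23}  # records whose joint_indicator is blank

ROWS: List[Row] = [
    (reg, None if start + i in BLANK_IDS else "1", "00000000" + digit, start + i)
    for reg, digits, start in PATTERNS
    for i, digit in enumerate(digits)
]


def is_zero(shares: str) -> bool:
    # Mirrors rlike("^0+$"): the whole string consists of zeros.
    return re.fullmatch(r"0+", shares) is not None


def flag_errors(rows: List[Row]) -> Dict[int, bool]:
    """Emulate error_1 | error_2 | error_3; returns mn_id -> error."""
    partitions: Dict[str, List[Row]] = defaultdict(list)
    for row in rows:
        partitions[row[0]].append(row)  # partitionBy("registerNumber")

    flags: Dict[int, bool] = {}
    for part in partitions.values():
        part.sort(key=lambda r: r[3])  # orderBy("mn_id"), numerically
        for i, (_, joint, shares, mn_id) in enumerate(part):
            prev = part[i - 1] if i > 0 else None             # lag(...)
            nxt = part[i + 1] if i + 1 < len(part) else None  # lead(...)
            # lag("joint_indicator") is null both for a blank value and for
            # the first row of the partition.
            prev_joint = prev[1] if prev else None
            error_1 = prev_joint is None and is_zero(shares)
            error_2 = (prev_joint is not None and not is_zero(shares)
                       and not is_zero(prev[2]))
            error_3 = joint is not None and not is_zero(shares) and nxt is None
            flags[mn_id] = error_1 or error_2 or error_3
    return flags


flags = flag_errors(ROWS)
flagged = sorted(mn_id for mn_id, is_error in flags.items() if is_error)
# The question's end goal: registerNumbers that contain at least one error.
error_registers = sorted({reg for reg, _, _, mn_id in ROWS if flags[mn_id]})
print(flagged)          # [10, 17, 19]
print(error_registers)  # ['XXXXXXXX', 'YYYYYYYY']
```

The flagged mn_ids match the true rows of the expected output in the question. If mn_id is stored as a string in the real DataFrame, casting it to an integer before ordering the window would keep Spark's ordering consistent with this numeric sort.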