Tags: python, pyspark, subset, data-manipulation

Marking the result rows inside the initial DataFrame in PySpark


I'm currently performing calculations on a database that contains transaction information. It is a large dataset that consumes a lot of resources, and I have run into the question of how to optimize my current solution.

My initial dataframe looks like this:

Name    ID     ContractDate LoanSum Status
A       ID1    2022-10-10   10      Closed 
A       ID1    2022-10-15   13      Active
A       ID1    2022-10-30   20      Active
B       ID2    2022-11-05   30      Active
C       ID3    2022-12-10   40      Closed
C       ID3    2022-12-12   43      Active
C       ID3    2022-12-19   46      Active
D       ID4    2022-12-10   10      Closed
D       ID4    2022-12-12   30      Active

I need to create a DataFrame that contains all loans issued to specific borrowers (grouped by ID) where the number of days between two loans of the same ID is less than 15 and the difference between the loan sums issued to that borrower is less than or equal to 3.
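A small plain-Python reference check can make the rule concrete on the sample rows. This is a sketch, not part of the original post: `flag_close_loans` is a hypothetical helper, and it assumes (like the window code below) that each loan is compared only with the previous loan of the same ID and that the sum condition is the signed difference `LoanSum - PreviousLoanSum <= 3`. Use `abs()` instead if an absolute difference is meant. Both loans of a qualifying pair are marked.

```python
from datetime import date

# Sample rows from the question: (Name, ID, ContractDate, LoanSum, Status)
rows = [
    ("A", "ID1", date(2022, 10, 10), 10, "Closed"),
    ("A", "ID1", date(2022, 10, 15), 13, "Active"),
    ("A", "ID1", date(2022, 10, 30), 20, "Active"),
    ("B", "ID2", date(2022, 11, 5), 30, "Active"),
    ("C", "ID3", date(2022, 12, 10), 40, "Closed"),
    ("C", "ID3", date(2022, 12, 12), 43, "Active"),
    ("C", "ID3", date(2022, 12, 19), 46, "Active"),
    ("D", "ID4", date(2022, 12, 10), 10, "Closed"),
    ("D", "ID4", date(2022, 12, 12), 30, "Active"),
]


def flag_close_loans(rows, max_days=15, max_sum_diff=3):
    """Return the indices of rows that belong to at least one qualifying pair of
    consecutive loans of the same ID (ordered by ContractDate)."""
    # Sort row indices by (ID, ContractDate), like Window.partitionBy('ID').orderBy('ContractDate')
    order = sorted(range(len(rows)), key=lambda i: (rows[i][1], rows[i][2]))
    flagged = set()
    for prev, cur in zip(order, order[1:]):
        if rows[prev][1] != rows[cur][1]:
            continue  # neighbours belong to different borrowers
        days = (rows[cur][2] - rows[prev][2]).days
        sum_diff = rows[cur][3] - rows[prev][3]  # signed, as in the Spark expression
        if days < max_days and sum_diff <= max_sum_diff:
            flagged.update((prev, cur))  # mark both loans of the pair
    return flagged


print(sorted(flag_close_loans(rows)))  # [0, 1, 4, 5, 6]
```

Indices 0, 1, 4, 5 and 6 are the two earlier A loans and all three C loans, the same rows the Spark query returns.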

My solution:

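from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql import Window

spark = SparkSession.builder.getOrCreate()

# Sample rows from the table above (dates kept as ISO strings)
data = [('A', 'ID1', '2022-10-10', 10, 'Closed'), ('A', 'ID1', '2022-10-15', 13, 'Active'),
        ('A', 'ID1', '2022-10-30', 20, 'Active'), ('B', 'ID2', '2022-11-05', 30, 'Active'),
        ('C', 'ID3', '2022-12-10', 40, 'Closed'), ('C', 'ID3', '2022-12-12', 43, 'Active'),
        ('C', 'ID3', '2022-12-19', 46, 'Active'), ('D', 'ID4', '2022-12-10', 10, 'Closed'),
        ('D', 'ID4', '2022-12-12', 30, 'Active')]
df = spark.createDataFrame(data).toDF('Name','ID','ContractDate','LoanSum','Status')
df.show()

cols = df.columns
# One window per borrower, ordered by contract date
w = Window.partitionBy('ID').orderBy('ContractDate')

# Compare every loan with the borrower's previous loan, then use lead() so that
# the earlier loan of each qualifying pair is flagged as well
new_df = df.withColumn('PreviousContractDate', f.lag('ContractDate').over(w)) \
  .withColumn('PreviousLoanSum', f.lag('LoanSum').over(w)) \
  .withColumn('Target', f.expr('datediff(ContractDate, PreviousContractDate) < 15 and LoanSum - PreviousLoanSum <= 3')) \
  .withColumn('Target', f.col('Target') | f.lead('Target').over(w)) \
  .filter('Target == True') \
  .select(*cols)
new_df.show()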

+----+---+------------+-------+------+
|Name| ID|ContractDate|LoanSum|Status|
+----+---+------------+-------+------+
|   A|ID1|  2022-10-10|     10|Closed|
|   A|ID1|  2022-10-15|     13|Active|
|   C|ID3|  2022-12-10|     40|Closed|
|   C|ID3|  2022-12-12|     43|Active|
|   C|ID3|  2022-12-19|     46|Active|
+----+---+------------+-------+------+
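To see why the `lead()` step and the `== True` filter yield exactly these rows, the sketch below emulates the window logic for a single partition (ID1) in plain Python. `lag`, `lead`, `sql_and` and `sql_or` are hypothetical helpers written for illustration, not PySpark functions; they mimic Spark SQL's null semantics, where `null OR true` is `true` but `null OR false` stays `null`, and a filter drops rows whose condition is `null`.

```python
from datetime import date


def lag(values):
    # Emulates f.lag(col) in one ordered partition: previous row's value, None for the first row
    return [None] + values[:-1]


def lead(values):
    # Emulates f.lead(col): next row's value, None for the last row
    return values[1:] + [None]


def sql_and(a, b):
    # SQL three-valued AND: False wins, otherwise null (None) propagates
    if a is False or b is False:
        return False
    if a is None or b is None:
        return None
    return True


def sql_or(a, b):
    # SQL three-valued OR: True wins, otherwise null (None) propagates
    if a is True or b is True:
        return True
    if a is None or b is None:
        return None
    return False


# Partition ID1 from the question, already ordered by ContractDate
dates = [date(2022, 10, 10), date(2022, 10, 15), date(2022, 10, 30)]
sums = [10, 13, 20]

target = []
for d, prev_d, s, prev_s in zip(dates, lag(dates), sums, lag(sums)):
    if prev_d is None:
        # datediff(...) and LoanSum - PreviousLoanSum are null, so both comparisons are null
        target.append(sql_and(None, None))
    else:
        target.append(sql_and((d - prev_d).days < 15, s - prev_s <= 3))

# f.col('Target') | f.lead('Target').over(w): also flag the earlier loan of each pair
flag = [sql_or(t, nxt) for t, nxt in zip(target, lead(target))]

# filter('Target == True') keeps only rows whose flag is exactly True
kept = [i for i, v in enumerate(flag) if v is True]
print(target, flag, kept)  # [None, True, False] [True, True, None] [0, 1]
```

The third A loan comes 15 days after the second one (not `< 15`), so its own `Target` is `false`, and with no following row `lead()` gives `null`; the row ends up with a `null` flag and is filtered out.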

As you can see, my results are stored in a separate DataFrame. My next goal is to remove the rows of “new_df” from the initial DataFrame “df” so that I can continue working with the remaining rows.

If I use the obvious solution below, it runs very slowly, especially when I have to subtract DataFrames one by one at each step:

df_sub = df.subtract(new_df)

My question: is it possible (and if so, how) to avoid creating a new DataFrame and instead mark the rows that end up in new_df inside the original DataFrame df? For example, could I add a new column that flags those rows, so that I can filter on it later in the analysis?

Thank you in advance!


Solution

  • You can do it in two ways:

    1. Use a left anti join
    2. Instead of creating another table, add your target flag to the same table

    Option 1 -

    Left anti join: keeps the rows of the left table that have no matching row in the right table.

    # Same window logic as in the question (reuses w and cols); keep only the flagged rows
    new_df = df.withColumn('PreviousContractDate', f.lag('ContractDate').over(w)) \
      .withColumn('PreviousLoanSum', f.lag('LoanSum').over(w)) \
      .withColumn('Target', f.expr('datediff(ContractDate, PreviousContractDate) < 15 and LoanSum - PreviousLoanSum <= 3')) \
      .withColumn('RemoverRowsFlag', f.col('Target') | f.lead('Target').over(w)) \
      .filter('RemoverRowsFlag == True') \
      .select(*cols, 'RemoverRowsFlag')
    
    # left_anti keeps only the rows of df that have no match in new_df on all original columns
    df = df.join(new_df, on=[*cols], how='left_anti')
    df.show()
    

    Output

    +----+---+------------+-------+------+
    |Name| ID|ContractDate|LoanSum|Status|
    +----+---+------------+-------+------+
    |   A|ID1|  2022-10-30|     20|Active|
    |   B|ID2|  2022-11-05|     30|Active|
    |   D|ID4|  2022-12-10|     10|Closed|
    |   D|ID4|  2022-12-12|     30|Active|
    +----+---+------------+-------+------+
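For the question's data (no nulls, no duplicated rows) the left anti join returns the same rows as `df.subtract(new_df)`, but the two are not interchangeable in general: `subtract` behaves like SQL `EXCEPT DISTINCT` (the result is de-duplicated and nulls compare as equal), while a join on `on=[*cols]` uses plain equality and keeps duplicates. The plain-Python sketch below illustrates the difference; `left_anti`, `except_distinct` and the sample rows are hypothetical, and unlike Spark the sketch preserves row order.

```python
def left_anti(left, right):
    """Rows of `left` with no equal row in `right`, joining on every column with '='.
    Duplicates in `left` are kept; a row containing None never matches (null = null is not true)."""
    right_keys = {r for r in right if None not in r}
    return [r for r in left if None in r or r not in right_keys]


def except_distinct(left, right):
    """Set difference as in SQL EXCEPT DISTINCT (what DataFrame.subtract does):
    nulls compare as equal and the result is de-duplicated."""
    right_set = set(right)
    out = []
    for r in left:
        if r not in right_set and r not in out:
            out.append(r)
    return out


# Hypothetical rows (Name, ID, LoanSum): a duplicated row and a row with a null LoanSum
left = [("B", "ID2", 30), ("B", "ID2", 30), ("E", "ID5", None), ("A", "ID1", 10)]
right = [("A", "ID1", 10), ("E", "ID5", None)]

print(left_anti(left, right))
# [('B', 'ID2', 30), ('B', 'ID2', 30), ('E', 'ID5', None)]
print(except_distinct(left, right))
# [('B', 'ID2', 30)]
```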
    

    Option 2 -

    This is pretty straightforward: add the flag column to the same table and filter on it.

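    # Keep every row of df and add the flag; coalesce turns null flags (lag()/lead()
    # return null at the edges of each ID partition) into False
    df = df.withColumn('PreviousContractDate', f.lag('ContractDate').over(w)) \
      .withColumn('PreviousLoanSum', f.lag('LoanSum').over(w)) \
      .withColumn('Target', f.expr('datediff(ContractDate, PreviousContractDate) < 15 and LoanSum - PreviousLoanSum <= 3')) \
      .withColumn('RemoverRowsFlag', f.coalesce(f.col('Target') | f.lead('Target').over(w), f.lit(False))) \
      .select(*cols, 'RemoverRowsFlag')
    
    # Both subsets are plain filters on the same flagged DataFrame, no subtract() needed
    flagged_df = df.filter('RemoverRowsFlag == True')
    remaining_df = df.filter('RemoverRowsFlag == False')
    flagged_df.show()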
    

    Output

    +----+---+------------+-------+------+---------------+
    |Name| ID|ContractDate|LoanSum|Status|RemoverRowsFlag|
    +----+---+------------+-------+------+---------------+
    |   A|ID1|  2022-10-10|     10|Closed|           true|
    |   A|ID1|  2022-10-15|     13|Active|           true|
    |   C|ID3|  2022-12-10|     40|Closed|           true|
    |   C|ID3|  2022-12-12|     43|Active|           true|
    |   C|ID3|  2022-12-19|     46|Active|           true|
    +----+---+------------+-------+------+---------------+
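One caveat when keeping the flag in the same table: `f.col('Target') | f.lead('Target').over(w)` is `null` (not `false`) for rows such as the last A loan, B's single loan and both D loans, because `lag()`/`lead()` return `null` at the edges of each partition and `null OR false` is `null`. A filter like `RemoverRowsFlag == False` would then miss those rows, just as `== True` does. The sketch below checks this in plain Python; the flag values are worked out by hand from the sample data under SQL null semantics (not copied from a Spark run), and `sql_eq`, `sql_filter` and `coalesce` are hypothetical helpers.

```python
def sql_eq(a, b):
    # SQL '=': any comparison with null yields null
    if a is None or b is None:
        return None
    return a == b


def sql_filter(rows, predicate):
    # DataFrame.filter keeps a row only when the predicate is true (null counts as not true)
    return [r for r in rows if predicate(r) is True]


def coalesce(*values):
    # Like f.coalesce: first non-null argument
    return next((v for v in values if v is not None), None)


# (ID, ContractDate, RemoverRowsFlag) as produced by col('Target') | lead('Target')
rows = [
    ("ID1", "2022-10-10", True), ("ID1", "2022-10-15", True), ("ID1", "2022-10-30", None),
    ("ID2", "2022-11-05", None),
    ("ID3", "2022-12-10", True), ("ID3", "2022-12-12", True), ("ID3", "2022-12-19", True),
    ("ID4", "2022-12-10", None), ("ID4", "2022-12-12", None),
]

raw_true = sql_filter(rows, lambda r: sql_eq(r[2], True))
raw_false = sql_filter(rows, lambda r: sql_eq(r[2], False))  # misses every null-flag row

fixed = [(i, d, coalesce(flag, False)) for i, d, flag in rows]
fixed_false = sql_filter(fixed, lambda r: sql_eq(r[2], False))

print(len(raw_true), len(raw_false), len(fixed_false))  # 5 0 4
```

In PySpark the corresponding fix is to wrap the flag expression in `f.coalesce(..., f.lit(False))` before filtering.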
    
