
PySpark filtering


Currently I'm doing calculations on a database that contains information on how loans are repaid by borrowers.

From a technical point of view, I'm using PySpark and have just run into an issue with how to use advanced filtering operations.

My dataframe looks like this:

ID     ContractDate LoanSum  ClosingDate Status Bank
ID3    2024-06-10   20                   Active A
ID3    2024-06-11   30                   Active A
ID3    2024-06-12   50                   Active A
ID3    2024-06-12   15                   Active B
ID3    2024-06-12   5        2024-06-18  Closed A
ID3    2024-06-13   40       2024-06-20  Closed A
ID3    2024-06-22   50                   Active A
ID4    2024-07-11   20                   Active A
ID4    2024-07-12   30                   Active B
ID4    2024-07-13   50                   Active B
ID4    2024-07-11   5        2024-08-20  Closed A

My goal is to calculate the sum of the "LoanSum" field for borrowers who have 3 or more active loans (Active status only) issued by the same bank within 3 days of the date the first loan was issued.

In my case it should be 20 + 30 + 50 = 100 for ID3 (the three active loans from bank A).

What I have done so far:

from pyspark.sql import functions as f
from pyspark.sql import Window

df = spark.createDataFrame(data).toDF('ID','ContractDate','LoanSum','ClosingDate', 'Status', 'Bank')
df.show()

cols = df.columns
w = Window.partitionBy('ID').orderBy('ContractDate')

# Flag loans where the gap to the borrower's previous loan is between 0 and 3 days
df.withColumn('PreviousContractDate', f.lag('ContractDate').over(w)) \
  .withColumn('Target', f.expr('datediff(ContractDate, PreviousContractDate) >= 0 AND datediff(ContractDate, PreviousContractDate) <= 3')) \
  .withColumn('Target', f.col('Target') | f.lead('Target').over(w)) \
  .filter('Target == True')

This code only helps me catch loans issued to the same borrower based on ContractDate.

How can I add more conditions?


Solution

  • To resolve your issue, please follow the code below. As a sample, I am using the dataframe from the question.

    Code:

    from pyspark.sql import functions as f
    from pyspark.sql import Window
    
    data1 = [
        ('ID3', '2024-06-10', 20, None, 'Active', 'A'),
        ('ID3', '2024-06-11', 30, None, 'Active', 'A'),
        ('ID3', '2024-06-12', 50, None, 'Active', 'A'),
        ('ID3', '2024-06-12', 15, None, 'Active', 'B'),
        ('ID3', '2024-06-12', 5, '2024-06-18', 'Closed', 'A'),
        ('ID3', '2024-06-13', 40, '2024-06-20', 'Closed', 'A'),
        ('ID3', '2024-06-22', 50, None, 'Active', 'A'),
        ('ID4', '2024-07-11', 20, None, 'Active', 'A'),
        ('ID4', '2024-07-12', 30, None, 'Active', 'B'),
        ('ID4', '2024-07-13', 50, None, 'Active', 'B'),
        ('ID4', '2024-07-11', 5, '2024-08-20', 'Closed', 'A'),
    ]
    
    
    # Build the sample dataframe, cast ContractDate to a date and keep only active loans
    df12 = spark.createDataFrame(data1, ['ID', 'ContractDate', 'LoanSum', 'ClosingDate', 'Status', 'Bank']) \
        .withColumn('ContractDate', f.to_date('ContractDate')) \
        .filter(f.col('Status') == 'Active')
    
    # Window per borrower and bank, ordered by contract date
    w = Window.partitionBy('ID', 'Bank').orderBy('ContractDate')

    # Cumulative count of loans issued within 3 days of the previous loan;
    # the first loan in each partition (lag is null) always counts as 1
    df = df12.withColumn('CumulativeCount', f.sum(
        f.when(f.datediff(f.col('ContractDate'), f.lag('ContractDate').over(w)).isNull(), 1)
        .when(f.datediff(f.col('ContractDate'), f.lag('ContractDate').over(w)) <= 3, 1)
        .otherwise(0)
    ).over(w))
    
    # Keep the rows where the cumulative count has reached 3 or more,
    # then sum their loan amounts per borrower and bank
    df1 = df.filter(f.col('CumulativeCount') >= 3) \
        .groupBy('ID', 'Bank') \
        .agg(f.sum('LoanSum').alias('TotalLoanSum'))

    # display() renders the result in Databricks / notebook environments;
    # use df1.show() in a plain PySpark session
    display(df1)
    

    Output:

    ID    Bank  TotalLoanSum
    ID3   A     100
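
    A note on the approach: the cumulative count above compares each loan with the immediately preceding one, so it effectively detects chains of loans issued at most 3 days apart. If you want to enforce the requirement literally, i.e. count only loans issued within 3 days of the first loan per borrower and bank, one option is to anchor the comparison to the earliest ContractDate in each partition. Below is a minimal sketch of that variant; it reuses df12 from the code above, and the names w_first, FirstDate and LoanCount are illustrative, not part of the original answer.

    from pyspark.sql import functions as f
    from pyspark.sql import Window

    # Earliest active loan date per borrower and bank
    # (no ordering needed, the aggregate spans the whole partition)
    w_first = Window.partitionBy('ID', 'Bank')

    result = (
        df12  # Active-only dataframe built above
        .withColumn('FirstDate', f.min('ContractDate').over(w_first))
        # keep only loans issued within 3 days of the first loan
        .filter(f.datediff(f.col('ContractDate'), f.col('FirstDate')) <= 3)
        .groupBy('ID', 'Bank')
        .agg(
            f.count('*').alias('LoanCount'),
            f.sum('LoanSum').alias('TotalLoanSum')
        )
        # require 3 or more qualifying loans
        .filter(f.col('LoanCount') >= 3)
        .select('ID', 'Bank', 'TotalLoanSum')
    )

    result.show()

    For the sample data this also returns 100 for ID3 at bank A, made up of the loans issued on 2024-06-10, 2024-06-11 and 2024-06-12 (20 + 30 + 50), which matches the breakdown described in the question.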