Tags: apache-spark, pyspark, databricks, azure-databricks

Complex Joins (PySpark) - Range and Categorical


table 1:

| pol_no | version | cover | rf_1             | rf_2      | rf_3                                            | rl_1               | rl_2     | rl_3          |
|--------|---------|-------|------------------|-----------|-------------------------------------------------|--------------------|----------|---------------|
| abc123 | 1       | a     | ["base","usage"] | ["group"] | null                                            | ["500","private"]  | ["blue"] | null          |
| cde111 | 1       | a     | ["base","usage"] | ["age"]   | ["protection","gold_mem","more_than_one_claim"] | ["500","private"]  | ["9"]    | ["Y","Y","N"] |
| cde222 | 1       | a     | ["base","usage"] | ["group"] | null                                            | ["300","business"] | ["gold"] | null          |

table 2:

| rating_factor_1 | rating_factor_2 | rating_factor_3     | rating_factor_cat_a | rating_factor_cat_b | rating_factor_cat_c | rating_factor_start | rating_factor_end | rating_factor_amt | rating_factor_coeff | rating_factor_rate | version | cover |
|-----------------|-----------------|---------------------|---------------------|---------------------|---------------------|---------------------|-------------------|-------------------|---------------------|--------------------|---------|-------|
| base            | usage           | null                | private             | null                | null                | 400                 | 550               | 50                | 0.2                 | null               | 1       | a     |
| base            | usage           | null                | business            | null                | null                | 200                 | 300               | 70                | 0.4                 | null               | 1       | a     |
| group           | null            | null                | blue                | null                | null                | null                | null              | 20                | null                | 0.5                | 1       | a     |
| group           | null            | null                | gold                | null                | null                | null                | null              | 30                | null                | 0.8                | 1       | a     |
| protection      | gold_mem        | more_than_one_claim | Y                   | Y                   | N                   | null                | null              | 10                | 0.4                 | null               | 1       | a     |
| age             | null            | null                | null                | null                | null                | 0                   | 5                 | 15                | 0.5                 | null               | 1       | a     |
| age             | null            | null                | null                | null                | null                | 6                   | 10                | 20                | 0.8                 | null               | 1       | a     |
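
A minimal sketch to reproduce the sample tables (the schemas are assumptions: rf_*/rl_* as array<string>, the range bounds as int, and amt/coeff/rate as double; adjust to the real types):

# Reconstruction of the two sample tables; schemas are assumed, not given
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame(
    [
        ("abc123", 1, "a", ["base", "usage"], ["group"], None,
         ["500", "private"], ["blue"], None),
        ("cde111", 1, "a", ["base", "usage"], ["age"],
         ["protection", "gold_mem", "more_than_one_claim"],
         ["500", "private"], ["9"], ["Y", "Y", "N"]),
        ("cde222", 1, "a", ["base", "usage"], ["group"], None,
         ["300", "business"], ["gold"], None),
    ],
    "pol_no string, version int, cover string, "
    "rf_1 array<string>, rf_2 array<string>, rf_3 array<string>, "
    "rl_1 array<string>, rl_2 array<string>, rl_3 array<string>",
)

df2 = spark.createDataFrame(
    [
        ("base", "usage", None, "private", None, None, 400, 550, 50.0, 0.2, None, 1, "a"),
        ("base", "usage", None, "business", None, None, 200, 300, 70.0, 0.4, None, 1, "a"),
        ("group", None, None, "blue", None, None, None, None, 20.0, None, 0.5, 1, "a"),
        ("group", None, None, "gold", None, None, None, None, 30.0, None, 0.8, 1, "a"),
        ("protection", "gold_mem", "more_than_one_claim", "Y", "Y", "N", None, None, 10.0, 0.4, None, 1, "a"),
        ("age", None, None, None, None, None, 0, 5, 15.0, 0.5, None, 1, "a"),
        ("age", None, None, None, None, None, 6, 10, 20.0, 0.8, None, 1, "a"),
    ],
    "rating_factor_1 string, rating_factor_2 string, rating_factor_3 string, "
    "rating_factor_cat_a string, rating_factor_cat_b string, rating_factor_cat_c string, "
    "rating_factor_start int, rating_factor_end int, "
    "rating_factor_amt double, rating_factor_coeff double, rating_factor_rate double, "
    "version int, cover string",
)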

I have to join Table 1 to Table 2. The catch is that there are both range and categorical joins. I will join them by version and cover. Then, for instance, rf_1 in table 1 would match rating_factor_1 and rating_factor_2 in table 2; looking at the values in rl_1, I would have to join "500" by range against rating_factor_start/rating_factor_end, and join the second value, "private", with rating_factor_cat_a. Finally, I append the amt, coeff, and rate values back to table 1.

Output table:

| pol_no | version | cover | rf_1 | rf_2 | rf_3 | rl_1 | rl_2 | rl_3 | amt_1 | coeff_1 | rate_1 | amt_2 | coeff_2 | rate_2 | amt_3 | coeff_3 | rate_3 |
|--------|---------|-------|------|------|------|------|------|------|-------|---------|--------|-------|---------|--------|-------|---------|--------|
| abc123 | 1 | a | ["base","usage"] | ["group"] | null | ["500","private"] | ["blue"] | null | 50 | 0.2 | null | 20 | null | 0.5 | null | null | null |
| cde111 | 1 | a | ["base","usage"] | ["age"] | ["protection","gold_mem","more_than_one_claim"] | ["500","private"] | ["9"] | ["Y","Y","N"] | 50 | 0.2 | null | 20 | 0.8 | null | 10 | 0.4 | null |
| cde222 | 1 | a | ["base","usage"] | ["group"] | null | ["300","business"] | ["gold"] | null | 70 | 0.4 | null | 30 | null | 0.8 | null | null | null |

I think we would have to do permutations of the various join types, such as [category, range], [range, category], [category, category], and [range, range], and see which join yields at least one of the values amt, coeff, or rate in the final table.

I cannot hardcode the amt, coeff, and rate lookups; I have to use loops, because in the actual case there are tens of thousands of rows like this.

My code, modified from Spark SQL to PySpark:

from pyspark.sql import functions as F

# Build the per-factor conditional expressions for the select statement
query_expressions = []

for i in range(1, 4):
    rf = f'rf_{i}'
    rl = f'rl_{i}'
    jtmp_col = f'jtmp_{i}'

    # Range branch: the table 2 row has no categories, so compare the
    # first element of rl_i (cast to int) against the range bounds.
    expr = (
        F.when(
            F.col(rf).isNotNull() &
            (F.size(F.col('tab2_cat_values')) == 0) &
            (F.col(rl).getItem(0).cast('int') >= F.col('rating_factor_start')) &
            (F.col(rl).getItem(0).cast('int') <= F.col('rating_factor_end')),
            F.struct(
                'rating_factor_amt',
                'rating_factor_coeff',
                'rating_factor_rate'
            )
        )
        # Add the other when conditions (categorical, mixed) similarly
        .alias(jtmp_col)
    )

    # Referencing jtmp_i within the same select relies on lateral
    # column aliases (Spark 3.4+ / recent Databricks runtimes).
    query_expressions.extend([
        expr,
        F.col(f'{jtmp_col}.rating_factor_amt').alias(f'amt_{i}'),
        F.col(f'{jtmp_col}.rating_factor_coeff').alias(f'coeff_{i}'),
        F.col(f'{jtmp_col}.rating_factor_rate').alias(f'rate_{i}')
    ])

# Creating the DataFrame
df = (
    df1.join(df2, ['version', 'cover'])
    .select(
        "pol_no", "version", "cover",
        "rf_1", "rl_1", "rf_2", "rf_3", "rl_2", "rl_3",
        "rating_factor_1", "rating_factor_2", "rating_factor_3",
        # array_compact drops the nulls, so the size == 0 check above works
        F.array_compact(
            F.array("rating_factor_cat_a", "rating_factor_cat_b", "rating_factor_cat_c")
        ).alias("tab2_cat_values"),
        *query_expressions
    )
    # Keep only table 2 rows whose factor list matches one of rf_1..rf_3
    .filter(
        (F.array_compact(F.array("rating_factor_1", "rating_factor_2", "rating_factor_3")) == F.col("rf_1")) |
        (F.array_compact(F.array("rating_factor_1", "rating_factor_2", "rating_factor_3")) == F.col("rf_2")) |
        (F.array_compact(F.array("rating_factor_1", "rating_factor_2", "rating_factor_3")) == F.col("rf_3"))
    )
)

# Show or perform any further actions as needed
df.display()

Solution

  • What you are asking for, a join based on range and category conditions simultaneously for all rf factors, is not possible with join operations alone, but a combination of filter and join can achieve your output.

    Here are the queries for it.

    First, create temporary views of the two tables.

    df1.createOrReplaceTempView("df1")
    df2.createOrReplaceTempView("df2")
    

    Then build the query conditions that satisfy all the rf factors.

    querycon = ''
    for i in range(1, 4):
        rf = f'rf_{i}'
        rl = f'rl_{i}'
        # One CASE expression per factor: range-only, category-only,
        # then mixed category + range.
        q = f''',case
    when ((d1.{rf} is not null) and (tab2_cat_values==array()) and ((cast(d1.{rl}[0] as int) >= d2.rating_factor_start) and (cast(d1.{rl}[0] as int) <= d2.rating_factor_end))) then struct(d2.rating_factor_amt,d2.rating_factor_coeff,d2.rating_factor_rate)
    when ((d1.{rf} is not null) and (tab2_cat_values!=array()) and (array_size(d1.{rl})==array_size(tab2_cat_values) and (d1.{rl}==tab2_cat_values))) then struct(d2.rating_factor_amt,d2.rating_factor_coeff,d2.rating_factor_rate)
    when ((d1.{rf} is not null) and (tab2_cat_values!=array()) and (array_size(d1.{rl})>1) and (array_contains(d1.{rl},tab2_cat_values[0])) and ((cast(d1.{rl}[0] as int) >= d2.rating_factor_start) and (cast(d1.{rl}[0] as int) <= d2.rating_factor_end))) then struct(d2.rating_factor_amt,d2.rating_factor_coeff,d2.rating_factor_rate)
    else null end as jtmp_{i},
    jtmp_{i}.rating_factor_amt as amt_{i},
    jtmp_{i}.rating_factor_coeff as coeff_{i},
    jtmp_{i}.rating_factor_rate as rate_{i}
    '''
        querycon += q
    
    

    Here I am looping through all the factors; in your case there are 3. If more factors come in the future, you need to extend the range and the cases accordingly.
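
    To sanity-check the assembled conditions before running the final query, you can print the string the loop produces:

    print(querycon)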

    In this CASE statement:

    1. It first checks the category list; if it is empty, it falls back to the range check and takes that output.
    2. Then it checks whether the categories are the same in both tables.
    3. Last, if the factor's value list has more than one entry, it checks both the category and the range conditions (see the spot check after this list).
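
    As a hypothetical spot check of that third branch, here is how it evaluates for pol_no abc123's factor 1 against the first table 2 row, with the values inlined as literals:

    spark.sql("""
        select case
          when array_size(array('500','private')) > 1
               and array_contains(array('500','private'), array('private')[0])
               and cast(array('500','private')[0] as int) >= 400
               and cast(array('500','private')[0] as int) <= 550
          then struct(50 as rating_factor_amt,
                      0.2 as rating_factor_coeff,
                      cast(null as double) as rating_factor_rate)
          end as jtmp_1
    """).show(truncate=False)
    # The branch fires: 'private' matches the category and 500 falls in
    # [400, 550], so jtmp_1 becomes the struct (50, 0.2, null).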

    These CASE expressions are then combined with the main join query, as below.

    fquery=f'''select
    d1.pol_no,d1.version,d1.cover,d1.rf_1,d1.rl_1,d1.rf_2,d1.rf_3,d1.rl_2,d1.rl_3,
    (array_compact(array(d2.rating_factor_cat_a,d2.rating_factor_cat_b,d2.rating_factor_cat_c))) as tab2_cat_values
    {querycon}
    from df1 d1 
    join df2 d2 on d1.version=d2.version 
    and d1.cover=d2.cover 
    and  (
        (array_compact(array(d2.rating_factor_1,d2.rating_factor_2,d2.rating_factor_3))==d1.rf_1)
        or (array_compact(array(d2.rating_factor_1,d2.rating_factor_2,d2.rating_factor_3))==d1.rf_2)
        or (array_compact(array(d2.rating_factor_1,d2.rating_factor_2,d2.rating_factor_3))==d1.rf_3))
    '''
    
    df = spark.sql(fquery)
    

    Next, aggregate away the duplicate rows produced by the join and keep the required columns.

    from pyspark.sql.functions import collect_set, explode_outer, col
    
    result_cols = ["amt_1", "coeff_1", "rate_1", "amt_2", "coeff_2", "rate_2", "amt_3", "coeff_3", "rate_3"]
    # collect_set collapses the duplicate join rows and drops nulls
    aggexpr = [collect_set(i).alias(i) for i in result_cols]
    df = df.groupBy(*df.columns[:9]).agg(*aggexpr)
    
    # explode_outer unwraps each single-element set, keeping empty sets as null
    for i in result_cols:
        df = df.withColumn(i, explode_outer(col(i)))
    display(df)
    

    Here, *df.columns[:9] covers the columns from table 1 up to rl_3. If you have more factors in the future, you need to extend this slice accordingly.
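
    Alternatively, a sketch that does not rely on column positions (my addition, not part of the approach above) would replace the groupBy line like this:

    # Assumption: derive the grouping keys by name so that adding a
    # fourth factor later cannot silently break the [:9] slice.
    helper_cols = {"tab2_cat_values"} | {f"jtmp_{i}" for i in range(1, 4)}
    key_cols = [c for c in df.columns if c not in result_cols and c not in helper_cols]
    df = df.groupBy(*key_cols).agg(*aggexpr)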

    Output:

    | pol_no | version | cover | rf_1 | rl_1 | rf_2 | rf_3 | rl_2 | rl_3 | amt_1 | coeff_1 | rate_1 | amt_2 | coeff_2 | rate_2 | amt_3 | coeff_3 | rate_3 |
    |--------|---------|-------|------|------|------|------|------|------|-------|---------|--------|-------|---------|--------|-------|---------|--------|
    | abc123 | 1 | a | ["base","usage"] | ["500","private"] | ["group"] | null | ["blue"] | null | 50 | 0.2 | null | 20 | 0.5 | null | null | null | null |
    | cde111 | 1 | a | ["base","usage"] | ["500","private"] | ["age"] | ["protection","gold_mem","more_than_one_claim"] | ["9"] | ["Y","Y","N"] | 50 | 0.2 | null | 20 | 0.8 | null | 10 | 0.4 | null |
    | cde222 | 1 | a | ["base","usage"] | ["300","business"] | ["group"] | null | ["gold"] | null | 70 | 0.4 | null | 30 | 0.8 | null | null | null | null |
