Tags: azure, apache-spark, pyspark, azure-databricks

Retrieving values from the same table using array columns (PySpark)


policies table:

Columns: pol_no, version, version_number, base, permitted_usage, discount_group, vehicle_age, rf_1, rf_2, rf_3

The rf_1, rf_2 and rf_3 columns hold arrays whose elements are names of other columns in the same table.

For instance, rf_1 = ["base", "permitted_usage"].

I need to match each array element against the column names of the policies table, retrieve that column's value from the same row, and collect the results into a new array column, rl_1.

The same steps should then be repeated for the remaining rf_<n> columns, writing the results to the corresponding rl_<n> columns.

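    +------+----+---------------+------+--------+--------------------------+------------+----------------+-----+
    |pol_no|base|permitted_usage|claims|discount|                      rf_1|        rf_2|            rl_1| rl_2|
    +------+----+---------------+------+--------+--------------------------+------------+----------------+-----+
    |     1| 500|        private|     0|       Y|['base','permitted_usage']|['discount']| [500,'private']|['Y']|
    |     2| 600|       business|     1|       N|['base','permitted_usage']|['discount']|[600,'business']|['N']|
    +------+----+---------------+------+--------+--------------------------+------------+----------------+-----+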

That is how the final output should look.

*Please note that more rf columns (rf_4, rf_5, rf_6, etc.) may be added in the future, so the list of rf columns should not be hard-coded.
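To make the expected semantics concrete without a Spark session, here is a plain-Python sketch that treats each row as a dict (an illustrative representation, not the question's data model; the function name `fill_rl_columns` is hypothetical). It discovers rf_<n> keys with a regex, so future rf_4, rf_5, ... columns need no code changes, and it returns None for a name that matches no column.

```python
import re

# Matches rf_1, rf_2, ..., rf_10, ... but not e.g. "rf_" or "rf_1a".
RF_PATTERN = re.compile(r"rf_(\d+)")


def fill_rl_columns(row: dict) -> dict:
    """Return a copy of `row` with an rl_<n> list added for every rf_<n> key.

    Each rl_<n> element is the row's own value for the column named by the
    corresponding rf_<n> element (None if no such column exists).
    """
    rf_cols = [c for c in row if RF_PATTERN.fullmatch(c)]
    # Only non-rf columns are valid lookup targets.
    lookup = {c: v for c, v in row.items() if c not in rf_cols}
    out = dict(row)
    for c in rf_cols:
        n = RF_PATTERN.fullmatch(c).group(1)
        out[f"rl_{n}"] = [lookup.get(name) for name in row[c]]
    return out


rows = [
    {"pol_no": 1, "base": 500, "permitted_usage": "private", "claims": 0,
     "discount": "Y", "rf_1": ["base", "permitted_usage"], "rf_2": ["discount"]},
    {"pol_no": 2, "base": 600, "permitted_usage": "business", "claims": 1,
     "discount": "N", "rf_1": ["base", "permitted_usage"], "rf_2": ["discount"]},
]

for r in map(fill_rl_columns, rows):
    print(r["pol_no"], r["rl_1"], r["rl_2"])
# 1 [500, 'private'] ['Y']
# 2 [600, 'business'] ['N']
```

The sketch keeps Python's original types (500 stays an int). A Spark map column has a single value type, so a Spark version has to settle on one type, such as string, for every looked-up value.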


Solution

  • Dynamically select the rf columns with a regex, so any future rf_N columns are picked up automatically:

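    from itertools import chain

    from pyspark.sql import functions as F

    # colRegex takes a backtick-quoted regex; this selects rf_1, rf_2, ..., rf_N
    cols = df.select(df.colRegex(r'`rf_\d+`')).columns

    # ['rf_1', 'rf_2']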
    

  • Create a map expression that, for a given row, maps each non-rf column name to that row's value in the column:

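    # create_map needs a single value type, so every column is cast to string
    # (implicit int/string coercion can fail when ANSI mode is enabled)
    mapping = F.create_map(*chain(*[(F.lit(c), F.col(c).cast('string')) for c in df.drop(*cols).columns]))

    # map(pol_no -> value, base -> value, permitted_usage -> value, claims -> value, discount -> value)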
    

  • For each rf column, use F.transform (part of the Python API since Spark 3.1) to replace every array element with the matching value from the map, aliasing the result as the corresponding rl column:

    # Look up each array element in the map; rf_N is renamed to rl_N
    new_cols = [F.transform(c, lambda v: mapping[v]).alias(c.replace('rf_', 'rl_')) for c in cols]
    result = df.select('*', *new_cols)
    result.show()
    
    # +------+----+---------------+------+--------+--------------------+----------+---------------+----+
    # |pol_no|base|permitted_usage|claims|discount|                rf_1|      rf_2|           rl_1|rl_2|
    # +------+----+---------------+------+--------+--------------------+----------+---------------+----+
    # |     1| 500|        private|     0|       Y|[base, permitted_...|[discount]| [500, private]| [Y]|
    # |     2| 600|       business|     1|       N|[base, permitted_...|[discount]|[600, business]| [N]|
    # +------+----+---------------+------+--------+--------------------+----------+---------------+----+
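The Python-side plumbing in this solution can be checked without a cluster. The sketch below replaces Spark Column objects with plain strings (stand-ins such as "lit(base)", purely illustrative) and uses a hypothetical column list with a future rf_10 column. It shows two things: the regex selects rf_10 as readily as rf_1, assuming colRegex matches the pattern against the whole column name, and chain(*[...]) flattens the (key, value) pairs into the alternating key, value, key, value argument list that F.create_map expects.

```python
import re
from itertools import chain

# Hypothetical column list, including a future rf_10 column.
columns = ["pol_no", "base", "permitted_usage", "claims", "discount",
           "rf_1", "rf_2", "rf_10"]

# Same pattern as df.colRegex(r'`rf_\d+`'), matched against the whole name.
rf_cols = [c for c in columns if re.fullmatch(r"rf_\d+", c)]

# Equivalent of df.drop(*cols).columns: everything that is not an rf column.
other_cols = [c for c in columns if c not in rf_cols]

# Each pair stands in for (F.lit(c), F.col(c)); chain(*pairs) flattens them
# into k1, v1, k2, v2, ... which is the argument order F.create_map expects.
pairs = [(f"lit({c})", f"col({c})") for c in other_cols]
map_args = list(chain(*pairs))

# Output column names for the .alias(...) call.
rl_cols = [c.replace("rf_", "rl_") for c in rf_cols]

print(rf_cols)       # ['rf_1', 'rf_2', 'rf_10']
print(map_args[:4])  # ['lit(pol_no)', 'col(pol_no)', 'lit(base)', 'col(base)']
print(rl_cols)       # ['rl_1', 'rl_2', 'rl_10']
```

Because of the unpacking, F.create_map receives its key and value columns in alternating order, exactly as if they had been listed by hand.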