PySpark/Python - best way to create new calculated columns from variable number of column inputs


Given this data structure:

  • age_group{#}

    • The count of records that fall into a user-defined age group (e.g., age_group1 is the count of records for ages 0-10, age_group2 for ages 11-21, and so on). There can be up to 10 age groups.
  • bin_number

    • Identifies one of the user-defined age groupings. For example, bin 1 might define age_group1 = 0-10 and age_group2 = 11-21, while bin 2 defines age_group1 = 0-17 and age_group2 = 18-44. There can be any number of bins.
  • t_group

    • The total census population for an age group, by county. The column name follows the format t_group{age group #}_{bin #}. For example, t_group1_1 is the total population for ages 0-10 in bin 1.
county  bin_number  age_group1  age_group2  t_group1_1  t_group2_1  t_group1_2  t_group2_2
01001   1           5           10          200         100         300         400
01001   2           1           2           100         200         300         400
01003   1           5           10          200         100         300         400
01003   2           1           2           100         200         300         400
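
The naming convention described above can be read mechanically from the column names. Below is a minimal plain-Python sketch (no Spark needed), assuming the four total columns are t_group1_1, t_group2_1, t_group1_2 and t_group2_2. The helper names `parse_t_group` and `t_groups_by_bin` are illustrative and not part of the original post.

```python
import re

# Matches names like "t_group1_2": age group 1, bin 2.
T_GROUP_PATTERN = re.compile(r"^t_group(\d+)_(\d+)$")


def parse_t_group(column_name):
    """Return (age_group_number, bin_number) for a t_group column, else None."""
    match = T_GROUP_PATTERN.match(column_name)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def t_groups_by_bin(columns):
    """Group t_group column names by bin number, ordered by age group."""
    by_bin = {}
    for name in columns:
        parsed = parse_t_group(name)
        if parsed is not None:
            age_group, bin_number = parsed
            by_bin.setdefault(bin_number, []).append((age_group, name))
    return {b: [n for _, n in sorted(pairs)] for b, pairs in sorted(by_bin.items())}


columns = ["county", "bin_number", "age_group1", "age_group2",
           "t_group1_1", "t_group2_1", "t_group1_2", "t_group2_2"]
print(t_groups_by_bin(columns))
# {1: ['t_group1_1', 't_group2_1'], 2: ['t_group1_2', 't_group2_2']}
```

Parsing both numbers, rather than testing the suffix with `endswith`, avoids confusing bin 1 with bin 11 once there are ten or more bins.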

The goal is to add new columns with the following calculations. Based on this cut of sample data, the new columns would be:
where bin_number = 1: (age_group1 / t_group1_1) * 100000
where bin_number = 1: (age_group2 / t_group2_1) * 100000
where bin_number = 2: (age_group1 / t_group1_2) * 100000
where bin_number = 2: (age_group2 / t_group2_2) * 100000
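
To make the arithmetic concrete, here is a small plain-Python sketch of the rate formula applied to values from the sample rows for county 01001. The function name `crude_rate` and the guard for a missing or zero population are assumptions added for illustration; they are not in the original post.

```python
RATE = 100000  # rate per 100,000 population, as in the formulas above


def crude_rate(count, population, rate=RATE):
    """Return (count / population) * rate, or None if population is missing or zero."""
    if not population:
        return None
    return (count / population) * rate


# Bin 1: age_group1 / t_group1_1 and age_group2 / t_group2_1
print(crude_rate(5, 200))   # 2500.0
print(crude_rate(10, 100))  # 10000.0
# Bin 2: age_group2 / t_group2_2
print(crude_rate(2, 400))   # 500.0
```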

The best I have come up with is to loop through the number of bins (hardcoded for the 2 in this sample), filter by the bin_number, select/alias the calculated columns, and then union each separate dataframe back together.

I'm stuck trying to think of other ways to do this that might be cleaner/more efficient. I keep coming back to groupBy or window partitions but the complexity of the data structure has me lost.

Note: The data structure is flexible and can be changed to simplify a solution, but I do need the final output to be in the same structure as the table above. Thank you for any feedback!

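    from functools import reduce
    
    import pyspark.sql.functions as F
    from pyspark.sql import DataFrame
    
    RATE = 100000
    dfs = []
    for i in range(1, 3):  # bins 1 and 2, hardcoded for this sample
        dfGroup = df.filter(F.col("bin_number") == i)  # df is the table in this post
        
        # Census total columns that belong to bin i, e.g. t_group1_1, t_group2_1
        totalBins = [x for x in df.columns if x.startswith("t_group") and x.endswith(f"_{i}")]
        dfGroup = dfGroup.select(
            "*",
            *[((F.col(f"age_group{x}") / F.col(f"t_group{x}_{i}")) * RATE).alias(
                    f"crude_rate_age_group_bin_{x}"
                ) for x in range(1, len(totalBins) + 1)],
        )
        
        dfs.append(dfGroup)
    
    dfRate = reduce(DataFrame.unionAll, dfs)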

Solution

  • Here's one way to do it. I took your question as a challenge.

    from pyspark import SparkContext, SQLContext
    from pyspark.sql.functions import col, struct, udf
    from pyspark.sql.types import FloatType, StructField, StructType
    
    sc = SparkContext('local')
    sqlContext = SQLContext(sc)
    
    
    data1 = [
        ["01001", 1, 5, 10, 200, 100, 300, 400],
        ["01001", 2, 1, 2, 100, 200, 300, 400],
        ["01003", 1, 5, 10, 200, 100, 300, 400],
        ["01003", 2, 1, 2, 100, 200, 300, 400],
    
          ]
    
    df1Columns = ["county", "bin_number",   "age_group1",   "age_group2",   "t_group1_1",   "t_group2_1",   "t_group1_2",   "t_group2_2"]
    df1 = sqlContext.createDataFrame(data=data1, schema = df1Columns)
    
    columns_list = list(df1.columns)
    print(columns_list)
    
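    # distinct() gives no ordering guarantee, so sort to keep the bins aligned with the schema below
    list_bin_values = sorted(df1.select('bin_number').distinct().rdd.flatMap(lambda x: x).collect())
    print(list_bin_values)
    
    bin_tgroup_mapping = []
    for bin_ele in list_bin_values:
        # Match on "_<bin>" so that bin 1 does not also pick up columns for bin 11
        totalBinsModified = [x for x in columns_list if x.startswith("t_group") and x.endswith(f"_{bin_ele}")]
        # totalBinsModified is a list of column names; bin_ele is the bin number
        bin_tgroup_mapping.append((bin_ele, totalBinsModified))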
    
    schema = StructType([
        StructField("bin1_group1", FloatType(), True),
        StructField("bin1_group2", FloatType(), True),
        StructField("bin2_group1", FloatType(), True),
        StructField("bin2_group2", FloatType(), True),
    ])
    
    print(bin_tgroup_mapping)
    
    
    RATE = 100000
    def evaluate_helper(row):
        # Build one flat list of rates: real values for the row's own bin,
        # None placeholders for every other bin.
        all_bin_result_set = []
        ii = row[1]  # bin_number is the second column
        single_bin_result_set = []
    
        for tt in bin_tgroup_mapping:
            if tt[0] == ii:
                xx = list(range(1, len(tt[1]) + 1))
                for x_ele in xx:
                    age_group_col = columns_list.index(f"age_group{x_ele}")
                    t_group_col = columns_list.index(f"t_group{x_ele}_{ii}")
                    inter_result = (row[age_group_col] / row[t_group_col]) * RATE
                    single_bin_result_set.append(inter_result)
                all_bin_result_set.append(single_bin_result_set)
            else:
                xx = tt[1]
                temp_list = [None] * len(xx)
                all_bin_result_set.append(temp_list)
    
        flat_answer_list = [item for sublist in all_bin_result_set for item in sublist]
        return flat_answer_list
    
    
    calculate_main_udf = udf(evaluate_helper, schema)
    
    answer_df = df1.select(*df1.columns,  calculate_main_udf(struct([df1[col] for col in columns_list])).alias("struct_combo"))
    print("answer_df dataframe")
    answer_df.show(truncate=False)
    
    
    
    answer_df_separated = answer_df.select(*answer_df.columns,
                                           col("struct_combo.bin1_group1").alias("bin1_group1"),
                                           col("struct_combo.bin1_group2").alias("bin1_group2"),
                                           col("struct_combo.bin2_group1").alias("bin2_group1"),
                                           col("struct_combo.bin2_group2").alias("bin2_group2"))
    print("answer_df dataframe")
    answer_df_separated.show(truncate=False)
    

    The output is as follows:

    ['county', 'bin_number', 'age_group1', 'age_group2', 't_group1_1', 't_group2_1', 't_group1_2', 't_group2_2']
    [1, 2]
    [(1, ['t_group1_1', 't_group2_1']), (2, ['t_group1_2', 't_group2_2'])]
    answer_df dataframe
    +------+----------+----------+----------+----------+----------+----------+----------+------------------------------+
    |county|bin_number|age_group1|age_group2|t_group1_1|t_group2_1|t_group1_2|t_group2_2|struct_combo                  |
    +------+----------+----------+----------+----------+----------+----------+----------+------------------------------+
    |01001 |1         |5         |10        |200       |100       |300       |400       |{2500.0, 10000.0, null, null} |
    |01001 |2         |1         |2         |100       |200       |300       |400       |{null, null, 333.33334, 500.0}|
    |01003 |1         |5         |10        |200       |100       |300       |400       |{2500.0, 10000.0, null, null} |
    |01003 |2         |1         |2         |100       |200       |300       |400       |{null, null, 333.33334, 500.0}|
    +------+----------+----------+----------+----------+----------+----------+----------+------------------------------+
    
    answer_df dataframe
    +------+----------+----------+----------+----------+----------+----------+----------+------------------------------+-----------+-----------+-----------+-----------+
    |county|bin_number|age_group1|age_group2|t_group1_1|t_group2_1|t_group1_2|t_group2_2|struct_combo                  |bin1_group1|bin1_group2|bin2_group1|bin2_group2|
    +------+----------+----------+----------+----------+----------+----------+----------+------------------------------+-----------+-----------+-----------+-----------+
    |01001 |1         |5         |10        |200       |100       |300       |400       |{2500.0, 10000.0, null, null} |2500.0     |10000.0    |null       |null       |
    |01001 |2         |1         |2         |100       |200       |300       |400       |{null, null, 333.33334, 500.0}|null       |null       |333.33334  |500.0      |
    |01003 |1         |5         |10        |200       |100       |300       |400       |{2500.0, 10000.0, null, null} |2500.0     |10000.0    |null       |null       |
    |01003 |2         |1         |2         |100       |200       |300       |400       |{null, null, 333.33334, 500.0}|null       |null       |333.33334  |500.0      |
    +------+----------+----------+----------+----------+----------+----------+----------+------------------------------+-----------+-----------+-----------+-----------+
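
As an alternative to the UDF shown above, the same per-bin selection can probably be expressed with Spark SQL `CASE` expressions, which keeps the work in native column operations. This is a hedged sketch and not the answer's method: the helper `rate_expressions` is hypothetical, it only builds expression strings in plain Python, and it yields one rate column per age group (named as in the question's code) rather than one column per bin and group.

```python
import re

RATE = 100000


def rate_expressions(columns, bins, rate=RATE):
    """Build one SQL CASE expression string per age group.

    Each expression picks the census total that matches the row's bin_number,
    so no filtering or unioning of DataFrames is needed.
    """
    age_groups = []
    for name in columns:
        match = re.fullmatch(r"age_group(\d+)", name)
        if match:
            age_groups.append(int(match.group(1)))

    expressions = []
    for g in sorted(age_groups):
        branches = " ".join(
            f"WHEN {b} THEN age_group{g} / t_group{g}_{b}"
            for b in bins
            if f"t_group{g}_{b}" in columns  # skip totals that do not exist
        )
        if not branches:
            continue  # a CASE with no WHEN branch would be invalid SQL
        expressions.append(
            f"(CASE bin_number {branches} END) * {rate} AS crude_rate_age_group_bin_{g}"
        )
    return expressions


columns = ["county", "bin_number", "age_group1", "age_group2",
           "t_group1_1", "t_group2_1", "t_group1_2", "t_group2_2"]
for expression in rate_expressions(columns, [1, 2]):
    print(expression)
```

With a DataFrame shaped like `df1`, the strings could then be passed to `df1.selectExpr("*", *rate_expressions(df1.columns, [1, 2]))`. The list of bins is supplied by the caller here; it could also come from the sorted distinct values of bin_number.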