Tags: python, pyspark, user-defined-functions

How do I write a Pyspark UDF to generate all possible combinations of column totals?


I have the following pandas code, which adds a new column for each combination of two or more of the original columns, with no duplicate combinations:

import itertools as it
import pandas as pd

df = pd.DataFrame({
    'a': [3, 4, 5, 6, 3],
    'b': [5, 7, 1, 0, 5],
    'c': [3, 4, 2, 1, 3],
    'd': [2, 0, 1, 5, 9]
})

# For every combination of 2..4 original columns, add a new column
# named after the combination ('a_b', 'a_b_c', ...) holding the row-wise sum.
orig_cols = df.columns
for r in range(2, df.shape[1] + 1):
    for cols in it.combinations(orig_cols, r):
        df["_".join(cols)] = df.loc[:, cols].sum(axis=1)

df

(Screenshot of the resulting DataFrame: the four original columns plus one summed column per combination, a_b through a_b_c_d.)

I need to generate the same results in PySpark, through a UDF. What would be the equivalent PySpark code?


Solution

  • There's no need to use a UDF. We can use native Spark functions:

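    If you are starting from the pandas frame above, here is a minimal
    setup sketch (assuming a local SparkSession is available) to turn it
    into the Spark DataFrame that the snippet below operates on:

    from pyspark.sql import SparkSession

    # Assumed setup: convert the pandas frame above into a Spark DataFrame.
    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(df)
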
    from itertools import combinations
    from pyspark.sql import functions as F

    # Build one Column expression per combination of 2..N columns:
    # sum(map(F.col, c)) adds the columns element-wise, and the alias
    # names the result after the combination, e.g. 'a_b_c'.
    sums = [
        sum(map(F.col, c)).alias('_'.join(c))
        for r in range(2, len(df.columns) + 1)
        for c in combinations(df.columns, r)
    ]

    df = df.select('*', *sums)
    

    df.show()
    
    +---+---+---+---+---+---+---+---+---+---+-----+-----+-----+-----+-------+
    |  a|  b|  c|  d|a_b|a_c|a_d|b_c|b_d|c_d|a_b_c|a_b_d|a_c_d|b_c_d|a_b_c_d|
    +---+---+---+---+---+---+---+---+---+---+-----+-----+-----+-----+-------+
    |  3|  5|  3|  2|  8|  6|  5|  8|  7|  5|   11|   10|    8|   10|     13|
    |  4|  7|  4|  0| 11|  8|  4| 11|  7|  4|   15|   11|    8|   11|     15|
    |  5|  1|  2|  1|  6|  7|  6|  3|  2|  3|    8|    7|    8|    4|      9|
    |  6|  0|  1|  5|  6|  7| 11|  1|  5|  6|    7|   11|   12|    6|     12|
    |  3|  5|  3|  9|  8|  6| 12|  8| 14| 12|   11|   17|   15|   17|     20|
    +---+---+---+---+---+---+---+---+---+---+-----+-----+-----+-----+-------+
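
  • The title asks for a UDF, so for completeness here is a minimal sketch
    of the same result with a plain Python UDF. This is an illustration,
    not the recommended route: a UDF serializes every row to a Python
    worker, so the native expressions above will be faster. It hard-codes
    orig_cols (an assumption for this sketch), since df now already
    carries the combination columns added above.

    from itertools import combinations
    from pyspark.sql import functions as F
    from pyspark.sql.types import LongType

    # Assumption: restrict to the original four columns, because df
    # already contains the combination columns from the select above.
    orig_cols = ['a', 'b', 'c', 'd']

    # A plain Python UDF that sums its arguments row by row.
    @F.udf(returnType=LongType())
    def sum_cols(*vals):
        return sum(vals)

    sums = [
        sum_cols(*c).alias('_'.join(c))
        for r in range(2, len(orig_cols) + 1)
        for c in combinations(orig_cols, r)
    ]

    df.select(*orig_cols, *sums).show()

    This produces the same table as above, just computed row by row in
    Python instead of natively on the JVM.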