
How to loop multiple decay rates over multiple columns in PySpark


I am trying to pass a list as a parameter to my function.

The list contains different coefficients to be applied to lagged versions of several columns.

However, I only manage to generate the columns in my dataframe for the first value of the list.

This is my actual result:

"col1", "col2", "col1_0.2", "col2_0.2"

What I expect:

"col1", "col2", "col1_0.2", "col2_0.2", "col1_0.4", "col2_0.4", "col1_0.6", "col2_0.6"

I must have missed something in my loop?

selected_col = col_selector(df, ["col1", "col2"])


w = Window.partitionBy("student").orderBy("date")
coef = (.1,.4,.6)

def custom_coef(col, w, coef):
    for x in coef:
        return sum(
            pow(i, x) * F.lag(F.col(col), i, default=0).over(w)
            for i in range(1)
        ).alias(col +"_"+str(x))

new_df = df.select(
    F.col("*"),
    *[custom_coef(col, w, coef) for col in selected_col]
)

Thanks


Solution

  • The return statement in the custom_coef function ends the function after the first iteration of the loop over coef. This means that custom_coef always returns only the first column definition, which is the one for coefficient 0.1. Since the function is called once per column in selected_col, you get exactly the result you are describing.

    One way to fix the problem without changing the structure of the code is to replace return with yield. This way custom_coef produces one column definition per coefficient, yielding a generator for each element of selected_col. These generators can be flattened with itertools.chain, and the result can be unpacked into the select statement:

    from itertools import chain  # needed to flatten the generators below

    def custom_coef(col, w, coef):
        for x in coef:
            yield sum(  #use yield instead of return
                pow(i, x) * F.lag(F.col(col), i, default=0).over(w)
                for i in range(1)
            ).alias(col +"_"+str(x))
    
    new_df = df.select(
        F.col("*"),
        *chain(*[custom_coef(col, w, coef) for col in selected_col]) #chain the generators
    )
    new_df.show()
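The return-vs-yield difference can be seen without Spark at all. Below is a minimal plain-Python sketch of the same pattern; the function names (with_return, with_yield) are hypothetical stand-ins for custom_coef, and the strings stand in for the column expressions:

```python
from itertools import chain

def with_return(name, coefs):
    for c in coefs:
        return f"{name}_{c}"   # exits the function on the first iteration

def with_yield(name, coefs):
    for c in coefs:
        yield f"{name}_{c}"    # produces one value per coefficient

coefs = (0.1, 0.4, 0.6)

# return: only the first coefficient survives
print(with_return("col1", coefs))
# → col1_0.1

# yield + chain: every (column, coefficient) pair is produced
print(list(chain(*[with_yield(col, coefs) for col in ("col1", "col2")])))
# → ['col1_0.1', 'col1_0.4', 'col1_0.6', 'col2_0.1', 'col2_0.4', 'col2_0.6']
```

This is why `*chain(*...)` is needed in the select: each call to the generator version yields several column definitions, and chain flattens them into the single argument list that select expects.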