
Pyspark StandardScaler over a Window


I want to apply the scaler pyspark.ml.feature.StandardScaler over a window of my data.

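from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

df4 = spark.createDataFrame(
    [
        (1, 1, 'X', 'a'),
        (2, 1, 'X', 'a'),
        (3, 9, 'X', 'b'),
        (5, 1, 'X', 'b'),
        (6, 2, 'X', 'c'),
        (7, 2, 'X', 'c'),
        (8, 10, 'Y', 'a'),
        (9, 45, 'Y', 'a'),
        (10, 3, 'Y', 'a'),
        (11, 3, 'Y', 'b'),
        (12, 6, 'Y', 'b'),
        (13, 19, 'Y', 'b'),
    ],
    ['id', 'feature', 'txt', 'cat'],
)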

w = Window().partitionBy(..)

I can do this over the whole DataFrame by calling the .fit and .transform methods, but not over the window w, the way I normally would with an expression such as F.col('feature') - F.mean('feature').over(w).

I could transform all my windowed/grouped data into separate columns, put them in a DataFrame, apply StandardScaler to that, and then transform the result back to 1D. Is there another method? The ultimate goal is to try different scalers, including pyspark.ml.feature.RobustScaler.


Solution

  • I eventually wrote my own scaler class. The pyspark StandardScaler is not a good fit for this problem: it fits a single set of statistics over an entire vector column and is meant for end-to-end ML pipeline transformations, not per-group scaling. My scaler does not actually use a pyspark Window; it gets the same per-group effect with groupBy (to compute the statistics) and a join (to apply them).

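    from functools import reduce

    import pyspark.sql.functions as F


    class StandardScaler:
        """Per-group z-score scaler (not pyspark.ml.feature.StandardScaler)."""

        tol = 0.000001  # added to the stddev so a constant group does not divide by zero

        def __init__(self, colsTotransform, groupbyCol='txt', orderBycol='id'):
            self.colsTotransform = colsTotransform
            self.groupbyCol = groupbyCol
            self.orderBycol = orderBycol

        def _temp_names(self):
            # (temporary name, original name) pairs for each column to scale
            return [(f"{colname}_transformed", colname) for colname in self.colsTotransform]

        def fit(self, df):
            # One row per group with the mean and sample stddev of each column.
            # Explicit aliases keep the names independent of Spark's auto-generated ones.
            exprs = []
            for name in self.colsTotransform:
                exprs.append(F.mean(name).alias(f'avg({name})'))
                exprs.append(F.stddev(name).alias(f'stddev_samp({name})'))
            self.stats = df.groupBy(self.groupbyCol).agg(*exprs)
            return self

        def _transform_one(self, df_with_stats, newName, colName):
            # (x - mean) / (stddev + tol), then store the result under the original name
            return (df_with_stats
                    .withColumn(newName,
                                (F.col(colName) - F.col(f'avg({colName})'))
                                / (F.col(f'stddev_samp({colName})') + self.tol))
                    .drop(colName)
                    .withColumnRenamed(newName, colName))

        def transform(self, df):
            # Inner join: rows whose group was not seen in fit() are dropped
            df_with_stats = df.join(self.stats, on=self.groupbyCol, how='inner').orderBy(self.orderBycol)
            scaled = reduce(lambda acc, kv: self._transform_one(acc, *kv),
                            self._temp_names(), df_with_stats)
            # Keep only the original columns, in their original order
            return scaled.select(df.columns)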

    Usage:

    ss = StandardScaler(colsTotransform=['feature'], groupbyCol='txt', orderBycol='id')
    ss.fit(df4)
    ss.stats.show()
    
    +---+------------------+--------------------+
    |txt|      avg(feature)|stddev_samp(feature)|
    +---+------------------+--------------------+
    |  Y|14.333333333333334|  16.169930941926335|
    |  X|2.6666666666666665|  3.1411250638372654|
    +---+------------------+--------------------+
    
    df4.show()
    
    +---+-------+---+---+
    | id|feature|txt|cat|
    +---+-------+---+---+
    |  1|      1|  X|  a|
    |  2|      1|  X|  a|
    |  3|      9|  X|  b|
    |  5|      1|  X|  b|
    |  6|      2|  X|  c|
    |  7|      2|  X|  c|
    |  8|     10|  Y|  a|
    |  9|     45|  Y|  a|
    | 10|      3|  Y|  a|
    | 11|      3|  Y|  b|
    | 12|      6|  Y|  b|
    | 13|     19|  Y|  b|
    +---+-------+---+---+
    
    ss.transform(df4).show()
    
    +---+--------------------+---+---+
    | id|             feature|txt|cat|
    +---+--------------------+---+---+
    |  1|  -0.530595281053646|  X|  a|
    |  2|  -0.530595281053646|  X|  a|
    |  3|  2.0162620680038548|  X|  b|
    |  5|  -0.530595281053646|  X|  b|
    |  6|-0.21223811242145835|  X|  c|
    |  7|-0.21223811242145835|  X|  c|
    |  8| -0.2679871102053074|  Y|  a|
    |  9|  1.8965241645298676|  Y|  a|
    | 10| -0.7008893651523425|  Y|  a|
    | 11| -0.7008893651523425|  Y|  b|
    | 12| -0.5153598273178989|  Y|  b|
    | 13|  0.2886015032980233|  Y|  b|
    +---+--------------------+---+---+
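The numbers in both tables can be checked by hand. This plain-Python sketch (standard library only, no Spark) recomputes the per-group mean and sample standard deviation with the statistics module and applies the class's formula (feature - mean) / (stddev + tol). It also prints the population standard deviation, to show that stddev_samp(feature) in the stats table uses the n - 1 divisor.

```python
import statistics

TOL = 0.000001  # same value as StandardScaler.tol in the class above

# (id, feature) pairs per txt group, copied from df4
groups = {
    'X': [(1, 1), (2, 1), (3, 9), (5, 1), (6, 2), (7, 2)],
    'Y': [(8, 10), (9, 45), (10, 3), (11, 3), (12, 6), (13, 19)],
}

scaled = {}
for txt, pairs in groups.items():
    values = [feature for _, feature in pairs]
    mean = statistics.mean(values)      # avg(feature)
    sd = statistics.stdev(values)       # sample std (divisor n - 1), stddev_samp(feature)
    pop_sd = statistics.pstdev(values)  # population std (divisor n), for comparison only
    print(txt, mean, sd, pop_sd)
    for row_id, feature in pairs:
        scaled[row_id] = (feature - mean) / (sd + TOL)

for row_id in sorted(scaled):
    print(row_id, scaled[row_id])
```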