Tags: vector, pyspark, pipeline, feature-scaling, standardization

How to implement PySpark StandardScaler on subset of columns?


I want to use pyspark StandardScaler on 6 out of 10 columns in my dataframe. This will be part of a pipeline.

The inputCol parameter seems to expect a vector, which I can pass in after using VectorAssembler on all my features, but this scales all 10 features. I don’t want to scale the other 4 features because they are binary and I want unstandardized coefficients for them.

Am I supposed to use vector assembler on the 6 features, scale them, then use vector assembler again on this scaled features vector and the remaining 4 features? I would end up with a vector within a vector and I’m not sure this will work.

What’s the right way to do this? An example is appreciated.


Solution

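  • You can do this with VectorAssembler. The key is that you have to extract the scaled columns back out of the output vector. Re-assembling is also an option: VectorAssembler accepts vector columns as inputs and concatenates them with numeric columns into one flat vector, so you would not end up with a vector within a vector. See the code below for a working example of the extraction approach: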

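    from pyspark.sql import SparkSession
    from pyspark.ml.feature import StandardScaler, VectorAssembler
    import pandas as pd
    import random
    
    # Reuse the active Spark session if one exists, otherwise start one.
    spark = SparkSession.builder.getOrCreate()
    
    # Toy data: five integer columns of random values (output differs per run).
    df = pd.DataFrame()
    df['a'] = random.sample(range(100), 10)
    df['b'] = random.sample(range(100), 10)
    df['c'] = random.sample(range(100), 10)
    df['d'] = random.sample(range(100), 10)
    df['e'] = random.sample(range(100), 10)
    
    # createDataFrame is a SparkSession method, not a SparkContext one.
    sdf = spark.createDataFrame(df)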
    
    sdf.show()
    
    +---+---+---+---+---+
    |  a|  b|  c|  d|  e|
    +---+---+---+---+---+
    | 51| 13|  6|  5| 26|
    | 18| 29| 19| 81| 28|
    | 34|  1| 36| 57| 87|
    | 56| 86| 51| 52| 48|
    | 36| 49| 33| 15| 54|
    | 87| 53| 47| 89| 85|
    |  7| 14| 55| 13| 98|
    | 70| 50| 32| 39| 58|
    | 80| 20| 25| 54| 37|
    | 40| 33| 44| 83| 27|
    +---+---+---+---+---+
    
    cols_to_scale = ['c', 'd', 'e']
    cols_to_keep_unscaled = ['a', 'b']
    
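    # Assemble only the columns to be scaled into one vector column.
    assembler = VectorAssembler().setInputCols(cols_to_scale).setOutputCol("features")
    # Defaults are withStd=True, withMean=False: values are divided by the
    # column's sample standard deviation and are not centered.
    scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures")
    sdf_transformed = assembler.transform(sdf)
    scaler_model = scaler.fit(sdf_transformed.select("features"))
    sdf_scaled = scaler_model.transform(sdf_transformed)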
    
    sdf_scaled.show()
    
    +---+---+---+---+---+----------------+--------------------+
    |  a|  b|  c|  d|  e|        features|      scaledFeatures|
    +---+---+---+---+---+----------------+--------------------+
    | 51| 13|  6|  5| 26|  [6.0,5.0,26.0]|[0.39358015146628...|
    | 18| 29| 19| 81| 28|[19.0,81.0,28.0]|[1.24633714630991...|
    | 34|  1| 36| 57| 87|[36.0,57.0,87.0]|[2.36148090879773...|
    | 56| 86| 51| 52| 48|[51.0,52.0,48.0]|[3.34543128746345...|
    | 36| 49| 33| 15| 54|[33.0,15.0,54.0]|[2.16469083306459...|
    | 87| 53| 47| 89| 85|[47.0,89.0,85.0]|[3.08304451981926...|
    |  7| 14| 55| 13| 98|[55.0,13.0,98.0]|[3.60781805510765...|
    | 70| 50| 32| 39| 58|[32.0,39.0,58.0]|[2.09909414115354...|
    | 80| 20| 25| 54| 37|[25.0,54.0,37.0]|[1.63991729777620...|
    | 40| 33| 44| 83| 27|[44.0,83.0,27.0]|[2.88625444408612...|
    +---+---+---+---+---+----------------+--------------------+
    
    # Flatten each row: unscaled columns first, then the scaled vector's elements.
    def extract(row):
      return tuple(row[c] for c in cols_to_keep_unscaled) + tuple(row.scaledFeatures.toArray().tolist())
    
    sdf_scaled = sdf_scaled.select(*cols_to_keep_unscaled, "scaledFeatures").rdd \
            .map(extract).toDF(cols_to_keep_unscaled + cols_to_scale)
      
      
    sdf_scaled.show()
    
    
    +---+---+------------------+-------------------+------------------+
    |  a|  b|                 c|                  d|                 e|
    +---+---+------------------+-------------------+------------------+
    | 51| 13|0.3935801514662892|0.16399957083190683|0.9667572801316145|
    | 18| 29| 1.246337146309916|  2.656793047476891|1.0411232247571234|
    | 34|  1|2.3614809087977355| 1.8695951074837378|3.2349185912096337|
    | 56| 86|3.3454312874634584| 1.7055955366518312|1.7847826710122114|
    | 36| 49| 2.164690833064591|0.49199871249572047| 2.007880504888738|
    | 87| 53| 3.083044519819266| 2.9191923608079415|3.1605526465841245|
    |  7| 14|3.6078180551076513| 0.4263988841629578| 3.643931286649932|
    | 70| 50|2.0990941411535426| 1.2791966524888734|2.1566123941397555|
    | 80| 20| 1.639917297776205| 1.7711953649845937| 1.375769975571913|
    | 40| 33|2.8862544440861213| 2.7223928758096534| 1.003940252444369|
    +---+---+------------------+-------------------+------------------+
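
The scaled values above are all positive because StandardScaler, with the default settings used here (withStd=True, withMean=False), only divides each value by the column's sample standard deviation and does not subtract the mean. As a rough cross-check in plain Python (no Spark needed), the sketch below reproduces column c from the values printed in the first table. The names c_values, c_std, c_scaled and c_centered are illustrative and not part of the original answer.

```python
import statistics

# Column 'c' exactly as printed in the first sdf.show() output above.
c_values = [6, 19, 36, 51, 33, 47, 55, 32, 25, 44]

# Sample standard deviation (n - 1 in the denominator).
c_std = statistics.stdev(c_values)

# Scaling without centering: divide by the standard deviation only.
c_scaled = [v / c_std for v in c_values]

# Centered variant for comparison: subtract the mean before dividing.
c_mean = statistics.mean(c_values)
c_centered = [(v - c_mean) / c_std for v in c_values]

# Should agree with the first value of column c in the final table (about 0.3936).
print(c_scaled[0])
```

If you want the scaled columns centered at zero as well, StandardScaler accepts withMean=True. The binary columns left out of the assembler are unaffected either way.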