Tags: python, pyspark, linear-regression

Linear Regression over Window in PySpark


I want to perform a linear regression over a window in PySpark. I have a time series, and for each person (identified by an ID) in the dataset I want the slope of that time series over the preceding 12 months.

My idea was to do it like this:

from pyspark.sql import functions as F
from pyspark.sql import Window
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression

sliding_window = Window.partitionBy('ID').orderBy('date').rowsBetween(-12, 0)
df = df.withColumn("date_integer", F.unix_timestamp(df['date']))
assembler = VectorAssembler(inputCols=['date_integer'], outputCol='features')
vector_df = assembler.transform(df)
lr = LinearRegression(featuresCol='features', labelCol='series')
df = df.withColumn('slope_window', lr.fit(vector_df).coefficients[0].over(sliding_window))

However, after 15 minutes of execution I get this error:

AttributeError: 'numpy.float64' object has no attribute 'over'

Any advice?


Solution

  • Given that only a small number of rows is required for each linear regression, you can switch to sklearn for the actual calculation and run these regressions in parallel using a Spark UDF. (The AttributeError in the question occurs because lr.fit(vector_df).coefficients[0] is a plain numpy.float64, not a Spark Column, so it has no .over method; a Spark ML model cannot be fitted per window frame.)

    Generate some test data

    from datetime import date, timedelta
    import random
    from pyspark.sql import functions as F
    from pyspark.sql import types as T
    
    # 9 daily observations per id, starting about a year ago; value grows
    # roughly linearly with the day, with slope id (plus uniform noise)
    startdate = date.today() - timedelta(days=365)
    data = [[id, startdate + timedelta(days=d), id * (d + random.random() - 0.5)]
            for id in range(1, 5) for d in range(1, 10)]
    df = spark.createDataFrame(data, ['id', 'date', 'value']) \
        .withColumn("date_integer", F.unix_timestamp('date') / 60 / 60 / 24)  # days since the epoch
    
    +---+----------+------------------+------------------+
    | id|      date|             value|      date_integer|
    +---+----------+------------------+------------------+
    |  1|2022-03-31|0.6168826236941936|19081.916666666668|
    |  1|2022-04-01|1.5357778308903614|19082.916666666668|
    |  1|2022-04-08| 9.208155706805703|19089.916666666668|
    |  2|2022-03-31|2.3609080691388877|19081.916666666668|
    |  2|2022-04-08|17.476392355447565|19089.916666666668|
    |  3|2022-03-31|2.4292967172497213|19081.916666666668|
    [...]
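
    Sanity check on what to expect: since value is roughly id * d and date_integer
    advances by exactly one per day, the fitted slope for each id should come out
    close to id, which the result table at the end confirms. A quick check outside
    Spark (np.polyfit stands in for the regression here, purely for illustration):

    import numpy as np

    d = np.arange(1, 10)
    noise = np.random.random(9) - 0.5
    print(np.polyfit(d, 3 * (d + noise), 1)[0])  # ≈ 3.0, the slope for id = 3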
    

    Collect the input values for the linear regression into array columns

    from pyspark.sql import Window
    
    sliding_window = Window.partitionBy('id').orderBy('date').rowsBetween(-12, 0)  # current row plus up to 12 preceding rows
    
    df2 = df.withColumn('last_12_dates', F.collect_list('date_integer').over(sliding_window))\
      .withColumn('last_12_values', F.collect_list('value').over(sliding_window))
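
    Both arrays are collected over the same ordered window, so element i of
    last_12_dates pairs with element i of last_12_values. If you prefer to make
    the pairing explicit, one possible alternative is to collect a single array
    of structs (the UDF would then have to unpack them):

    df2_alt = df.withColumn(
        'last_12_points',
        F.collect_list(F.struct('date_integer', 'value')).over(sliding_window)
    )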
    

    Define the UDF that runs the linear regression

    @F.udf(returnType=T.DoubleType())
    def linear_reg(dates, values):
        # imports inside the function so that the executors import them as well
        import numpy as np
        from sklearn.linear_model import LinearRegression
        x = np.reshape(dates, (-1, 1))  # sklearn expects a 2D feature matrix
        lr = LinearRegression().fit(x, values)
        return float(lr.coef_[0])  # the fitted slope
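
    One edge case worth noting: the first row of each id has only a single point
    in its window, and sklearn returns a zero slope for a single sample (the
    centered x is all zeros), which is why slope_window is 0.0 on those rows in
    the result below. A quick standalone check:

    from sklearn.linear_model import LinearRegression

    print(LinearRegression().fit([[42.0]], [7.0]).coef_[0])  # 0.0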
    

    And finally execute the calculation

    df2.withColumn('slope_window', linear_reg(F.col('last_12_dates'), F.col('last_12_values'))) \
      .drop('date_integer', 'last_12_dates', 'last_12_values') \
      .show(truncate=False, n=50)
    

    Result:

    +---+----------+------------------+------------------+
    |id |date      |value             |slope_window      |
    +---+----------+------------------+------------------+
    |1  |2022-03-31|0.6168826236941936|0.0               |
    |1  |2022-04-01|1.5357778308903614|0.9188952071961676|
    |1  |2022-04-08|9.208155706805703 |1.1178376705639304|
    |2  |2022-03-31|2.3609080691388877|0.0               |
    |2  |2022-04-01|4.069062243789487 |1.7081541746505986|
    |2  |2022-04-07|15.013921583339329|1.8849117914166331|
    |2  |2022-04-08|17.476392355447565|1.8662925262252443|
    |3  |2022-03-31|2.4292967172497213|0.0               |
    |3  |2022-04-01|5.343956838751345 |2.914660121501623 |
    |3  |2022-04-07|23.05250860814219 |2.938836071122241 |
    |3  |2022-04-08|26.5399792695849  |2.942384338095864 |
    |4  |2022-03-31|5.088327258258278 |0.0               |
    |4  |2022-04-01|7.4632476763570015|2.3749204180987222|
    |4  |2022-04-07|30.56238059275477 |3.900228891002952 |
    |4  |2022-04-08|34.760905276327996|3.8133127278310197|
    +---+----------+------------------+------------------+
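
    A final note on the "12 months back" requirement: rowsBetween(-12, 0) spans
    the 12 preceding rows, not 12 calendar months, so it only matches the intent
    if there is exactly one row per id per month. A sketch of a calendar-based
    alternative (months_since_epoch is a helper column introduced here just for
    illustration):

    df3 = df.withColumn(
        'months_since_epoch',
        (F.year('date') - 1970) * 12 + F.month('date') - 1
    )
    months_window = Window.partitionBy('id') \
        .orderBy('months_since_epoch') \
        .rangeBetween(-12, 0)  # current month plus the 12 preceding months

    The collect_list columns can then be built over months_window exactly as
    above.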