
Convert Pandas best fit function into pyspark


I have been using the following function in Pandas to create time-series features. It returns the best-fit (ordinary least squares) slope of a given range of points:

def best_fit(X, Y):
    xbar = sum(X)/len(X)
    ybar = sum(Y)/len(Y)
    n = len(X) 
    numer = sum([xi*yi for xi,yi in zip(X, Y)]) - n * xbar * ybar
    denum = sum([xi**2 for xi in X]) - n * xbar**2
    b = numer / denum
    return b
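For reference, this is the closed-form OLS slope b = (sum(x*y) - n*xbar*ybar) / (sum(x^2) - n*xbar^2), which equals the sample covariance of X and Y divided by the sample variance of X. The sketch below, assuming NumPy is installed and using made-up sample points (not the question's data), checks the function against `numpy.polyfit` with degree 1 and against that covariance ratio:

```python
import numpy as np


def best_fit(X, Y):
    # the question's function, unchanged apart from spacing
    xbar = sum(X) / len(X)
    ybar = sum(Y) / len(Y)
    n = len(X)
    numer = sum([xi * yi for xi, yi in zip(X, Y)]) - n * xbar * ybar
    denum = sum([xi ** 2 for xi in X]) - n * xbar ** 2
    return numer / denum


# hypothetical sample points, chosen only for illustration
X = [0, 1, 2, 3, 4]
Y = [0.2, 1.1, 1.9, 3.2, 3.9]

slope_formula = best_fit(X, Y)
slope_polyfit = np.polyfit(X, Y, 1)[0]               # degree-1 fit returns [slope, intercept]
slope_cov = np.cov(X, Y)[0, 1] / np.var(X, ddof=1)   # sample cov(x, y) / sample var(x)

# all three are 0.95 up to floating-point rounding
print(slope_formula, slope_polyfit, slope_cov)
```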

Here is a simple example showing the results (see final df below):

import pandas as pd
import numpy as np
import random
cols = ['x_vals','y_vals']
df = pd.DataFrame(columns=cols)
for i in range(0,20):
  df.loc[i,'x_vals'] = i
  df.loc[i,'y_vals'] = 0.05 * i**2 + 0.1 * i + random.uniform(-1,1) #some random parabolic points

I then apply the best_fit function to get the slope of the 5 points preceding each row (rows i-5 to i-1, excluding the current row):

for i,row in df.iterrows():
  if i>=5:
    X = df['x_vals'][i-5:i]
    Y = df['y_vals'][i-5:i]
    df.loc[i,'slope'] = best_fit(X, Y)
df

Which gives me this:

x_vals      y_vals      slope
     0   -0.648205        NaN
     1    0.282729        NaN
     2    0.785474        NaN
     3     1.48546        NaN
     4    0.408165        NaN
     5     1.61244   0.331548
     6     2.60868   0.228211
     7     3.77621   0.377338
     8     4.08937   0.678201
     9     4.34625   0.952618
    10     5.47554   0.694832
    11     7.90902   0.630377
    12     8.83912   0.965180
    13     9.01195   1.306227
    14     11.8244   1.269497
    15     13.3199   1.380057
    16     15.2751   1.380692
    17     15.3959   1.717981
    18      18.454   1.621861
    19     20.0773   1.533528
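The loop above can also be written without iterating over rows, which is closer in spirit to a window function. A minimal sketch, assuming float columns: since the slope equals the sample covariance of x and y divided by the sample variance of x, pandas' rolling `cov` and `var` give it directly. `rolling(5)` covers the current row and the 4 before it, so `.shift(1)` moves every result down one row to reproduce the loop's window of rows i-5 to i-1. The data is regenerated with a fixed random seed, so the numbers will not match the table.

```python
import random

import pandas as pd

random.seed(0)  # fixed seed so runs are reproducible
xs = list(range(20))
ys = [0.05 * i**2 + 0.1 * i + random.uniform(-1, 1) for i in xs]  # parabolic points
df = pd.DataFrame({'x_vals': xs, 'y_vals': ys}, dtype=float)

# OLS slope over a window = sample cov(x, y) / sample var(x)
roll_x = df['x_vals'].rolling(5)   # windows of 5 rows ending at the current row
slope = roll_x.cov(df['y_vals']) / roll_x.var()

# shift(1): row i now uses rows i-5 .. i-1, like the loop in the question
df['slope'] = slope.shift(1)
print(df)
```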

I need to get the same slope column from a PySpark DataFrame instead of Pandas, but I am struggling to find a starting point. Should I use a PySpark window, a built-in OLS function, or a UDF?


Solution

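  • Use a PySpark window to collect, for each row, the x and y values of that row and the 5 preceding rows into lists, then pass the lists to best_fit_udf. Note that this frame holds 6 points because it includes the current row, whereas the pandas loop uses only the 5 rows before the current one.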

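    import random

    import pandas as pd
    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F
    from pyspark.sql.functions import udf
    from pyspark.sql.types import DoubleType

    spark = SparkSession.builder.getOrCreate()

    # modified this function to handle division by zero and incomplete windows
    def best_fit(X, Y):
        n = len(X)
        # only score full windows (current row + 5 preceding rows = 6 points);
        # checking n first also keeps an empty list away from the division
        if n < 6:
            return None
        xbar = sum(X) / n
        ybar = sum(Y) / n
        numer = sum([xi*yi for xi, yi in zip(X, Y)]) - n * xbar * ybar
        denum = sum([xi**2 for xi in X]) - n * xbar**2
        if denum == 0:
            return None
        return numer / denum

    best_fit_udf = udf(best_fit, DoubleType())

    cols = ['x_vals','y_vals']
    df = pd.DataFrame(columns=cols)
    for i in range(0,20):
      df.loc[i,'x_vals'] = i
      df.loc[i,'y_vals'] = 0.05 * i**2 + 0.1 * i + random.uniform(-1,1) #some random parabolic points

    spark_df = spark.createDataFrame(df)

    # frame = current row plus the 5 preceding rows, ordered by x_vals;
    # with no partitionBy, Spark moves all rows into a single partition
    w = Window.orderBy("x_vals").rowsBetween(-5, 0)

    # collect each frame's x and y values into arrays for the UDF
    df = spark_df.select(
        "x_vals",
        "y_vals",
        F.collect_list('x_vals').over(w).alias("x_list"),
        F.collect_list('y_vals').over(w).alias("y_list"),
    )

    df.withColumn("slope", best_fit_udf('x_list', 'y_list')).drop('x_list', 'y_list').show()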
    

    Which gives:

    +------+--------------------+------------------+
    |x_vals|              y_vals|             slope|
    +------+--------------------+------------------+
    |     0|-0.05626232194330516|              null|
    |     1|  1.0626613654187942|              null|
    |     2|-0.18870622421238525|              null|
    |     3|  1.7106172105001147|              null|
    |     4|  1.9398571272258158|              null|
    |     5|  2.3632022124308474| 0.475092382628695|
    |     6|  1.7264493731921893|0.3201115790149247|
    |     7|   3.298712278452215|0.5116552596172641|
    |     8|  4.3179382280764305|0.4707547914949186|
    |     9|    4.00691449276564|0.5077645079970263|
    |    10|   6.085792506183289|0.7563877936316236|
    |    11|   7.272669055040746|1.0223232959178614|
    |    12|    8.70598472345308| 1.085126649123283|
    |    13|  10.141576882812515|1.2686365861314373|
    |    14|  11.170519757896672| 1.411962717827295|
    |    15|  11.999868557507794|1.2199864149871311|
    |    16|   14.86294824152797|1.3960568659909833|
    |    17|  16.698964370210007| 1.570238888844051|
    |    18|   18.71951724368806|1.7810890092953742|
    |    19|  20.428078271618062|1.9509358501665701|
    +------+--------------------+------------------+