I'm using this code to apply a function (funcX) to my data-frame using a rolling window. The main issue is that this data-frame (data) is very large, and I'm looking for a faster way to do this task.
import numpy as np

def funcX(x):
    x = np.sort(x)
    xd = np.delete(x, 25)
    med = np.median(xd)
    return (np.abs(x - med)).mean() + med

med_out = data.var1.rolling(window = 51, center = True).apply(funcX, raw = True)
The only reason for using this function is that the calculated median is the median after removing the middle value, so it is different from simply adding .median() at the end of the rolling window.
To be efficient, a window algorithm must link the results of two overlapping windows.
Here, with med0 the median of the window, med the median of x \ med0, xl the elements before med and xg the elements after med in the sorted window (med0 excluded from both), funcX(x) can be seen as:
<|x-med|> + med = [sum(xg) - sum(xl) + |med0-med|] / windowsize + med
So the idea is to maintain a buffer that holds the sorted current window, together with sum(xg) and sum(xl). With Numba just-in-time compilation, this gives very good performance.
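Before building the incremental version, the identity can be checked on a single window with plain NumPy. This is only a sketch: funcX_from_sums is a name made up for this check, not part of the solution, and it recomputes the sums from scratch instead of maintaining them.

```python
import numpy as np

def funcX(x):
    # Original function from the question (window of 51 values).
    x = np.sort(x)
    xd = np.delete(x, 25)
    med = np.median(xd)
    return (np.abs(x - med)).mean() + med

def funcX_from_sums(x):
    # Hypothetical helper: same value, computed from the two partial sums.
    n = x.size                        # window size, assumed odd and > 1
    h = n // 2
    s = np.sort(x)
    med0 = s[h]                       # median of the full window
    med = (s[h - 1] + s[h + 1]) / 2   # median once med0 is removed
    xls = s[:h].sum()                 # sum of the elements below med0
    xgs = s[h + 1:].sum()             # sum of the elements above med0
    return (xgs - xls + abs(med0 - med)) / n + med

rng = np.random.default_rng(0)
window = rng.random(51)
print(np.isclose(funcX(window), funcX_from_sums(window)))
```

The two sums are the only quantities that depend on the whole window, which is why keeping them up to date is enough to evaluate the function in constant time.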
First, the buffer management: init sorts the first window and computes the left (xls) and right (xgs) sums.
import numpy as np
import numba

windowsize = 51 #odd, >1
halfsize = windowsize//2

@numba.njit
def init(firstwindow):
    buffer = np.sort(firstwindow)
    xls = buffer[:halfsize].sum()
    xgs = buffer[-halfsize:].sum()
    return buffer,xls,xgs
shift is the linear part: moving a slice of the buffer costs O(windowsize). It updates the buffer while keeping it sorted. np.searchsorted computes the insertion and deletion positions in O(log(windowsize)). It's technical since xin<xout and xout<xin are not symmetrical situations.
@numba.njit
def shift(buffer,xin,xout):
    i_in = np.searchsorted(buffer,xin)
    i_out = np.searchsorted(buffer,xout)
    if xin <= xout :
        buffer[i_in+1:i_out+1] = buffer[i_in:i_out]
        buffer[i_in] = xin
    else:
        buffer[i_out:i_in-1] = buffer[i_out+1:i_in]
        buffer[i_in-1] = xin
    return i_in, i_out
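To see what shift has to produce, here is a slower but simpler reference in plain NumPy. It is a sketch for checking only: shift_reference is a hypothetical helper, it allocates a new array instead of working in place, and it assumes xout is actually present in the buffer.

```python
import numpy as np

def shift_reference(buffer, xin, xout):
    # Hypothetical helper: returns a new sorted array, the input is left untouched.
    i_out = np.searchsorted(buffer, xout)      # position of the value leaving
    without_out = np.delete(buffer, i_out)     # drop one copy of xout
    i_in = np.searchsorted(without_out, xin)   # position that keeps the order
    return np.insert(without_out, i_in, xin)

buffer = np.array([1.0, 3.0, 5.0, 7.0, 9.0])
print(shift_reference(buffer, 4.0, 7.0))   # xin < xout
print(shift_reference(buffer, 8.0, 3.0))   # xin > xout
```

Deleting first and inserting afterwards removes the asymmetry between the two cases, at the price of two array copies per step. The in-place version avoids those copies, which is why it needs the two branches.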
update updates the buffer and the sums of the left and right parts. Here again it's technical, since the position of xin and xout relative to the middle of the buffer decides which value enters or leaves each sum.
@numba.njit
def update(buffer,xls,xgs,xin,xout):
    # values around the middle, read before the buffer is modified
    xl,x0,xg = buffer[halfsize-1:halfsize+2]
    i_in,i_out = shift(buffer,xin,xout)

    # left sum
    if i_out < halfsize:
        xls -= xout
        if i_in <= halfsize:
            xls += xin
        else:
            xls += x0
    elif i_in < halfsize:
        xls += xin - xl

    # right sum
    if i_out > halfsize:
        xgs -= xout
        if i_in > halfsize:
            xgs += xin
        else:
            xgs += x0
    elif i_in > halfsize+1:
        xgs += xin - xg

    return buffer, xls, xgs
func is equivalent to the original funcX on the buffer, but runs in O(1).
@numba.njit
def func(buffer,xls,xgs):
    med0 = buffer[halfsize]
    med = (buffer[halfsize-1] + buffer[halfsize+1])/2
    if med0 > med:
        return (xgs-xls+med0-med) / windowsize + med
    else:
        return (xgs-xls+med-med0) / windowsize + med
med is the global function. It runs in O(data.size * windowsize).
@numba.njit
def med(data):
    res = np.full_like(data, np.nan)
    state = init(data[:windowsize])
    res[halfsize] = func(*state)
    for i in range(windowsize, data.size):
        xin,xout = data[i], data[i - windowsize]
        state = update(*state, xin, xout)
        res[i-halfsize] = func(*state)
    return res
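If pandas is not at hand, the result can also be cross-checked against a brute-force version in plain NumPy. This is a sketch under some assumptions: med_reference is a hypothetical name, data is a 1-D float array at least windowsize long, and sliding_window_view needs NumPy 1.20 or later. It sorts every window, so it materializes data.size * windowsize values and is only meant for validation.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def med_reference(data, windowsize=51):
    # Hypothetical brute-force check: sort every window independently.
    halfsize = windowsize // 2
    windows = np.sort(sliding_window_view(data, windowsize), axis=1)
    # median of each window once its middle value is removed
    med = (windows[:, halfsize - 1] + windows[:, halfsize + 1]) / 2
    values = np.abs(windows - med[:, None]).mean(axis=1) + med
    # same layout as a centered rolling window: NaN on both edges
    res = np.full(data.shape, np.nan)
    res[halfsize:data.size - halfsize] = values
    return res

rng = np.random.default_rng(0)
sample = rng.random(1000)
ref = med_reference(sample)
print(np.isnan(ref[:25]).all(), np.isfinite(ref[25:-25]).all())
```

Comparing med(sample)[halfsize:-halfsize] with ref[halfsize:-halfsize] through np.allclose should then confirm that the incremental sums stay in step with the recomputed ones.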
Performance:
import pandas
data=pandas.DataFrame(np.random.rand(10**5))
%time res1=data[0].rolling(window = windowsize, center = True).apply(funcX, raw = True)
Wall time: 10.8 s
res2=med(data[0].values)
np.allclose((res1-res2)[halfsize:-halfsize],0)
Out[112]: True
%timeit res2=med(data[0].values)
40.4 ms ± 462 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
That's about 250x faster with a window size of 51: an hour becomes 15 seconds.