How can I improve the following rolling sum implementation?
import Control.Monad.State

type Buffer = State BufferState (Maybe Double)
type BufferState = ( [Double] , Int, Int )

-- circular buffer
buff :: Double -> Buffer
buff newVal = do
  ( list, ptr, len) <- get
  -- if the list is not full yet just accumulate the new value
  if length list < len
    then do
      put ( newVal : list , ptr, len)
      return Nothing
    else do
      let nptr = (ptr - 1) `mod` len
          (as,(v:bs)) = splitAt ptr list
          nlist = as ++ (newVal : bs)
      put (nlist, nptr, len)
      return $ Just v

-- create initial state for circular buffer
initBuff l = ( [] , l-1 , l)

-- use the circular buffer to calculate a rolling sum
rollSum :: Double -> State (Double,BufferState) (Maybe Double)
rollSum newVal = do
  (acc,bState) <- get
  let (lv , bState' ) = runState (buff newVal) bState
      acc' = acc + newVal
  -- subtract the old value if the circular buffer is full
  case lv of
    Just x  -> put ( acc' - x , bState') >> (return $ Just (acc' - x))
    Nothing -> put ( acc' , bState') >> return Nothing

test :: (Double,BufferState) -> [Double] -> [Maybe Double] -> [Maybe Double]
test state [] acc = acc
test state (x:xs) acc =
  let (a,s) = runState (rollSum x) state
  in test s xs (a:acc)

main :: IO ()
main = print $ test (0,initBuff 3) [1,1,1,2,2,0] []
Buffer uses the State monad to implement a circular buffer. rollSum uses the State monad again to keep track of the rolling sum value and the state of the circular buffer.
Thanks!
EDIT
I forgot to mention I am using a circular buffer as I intend to use this code on-line and process updates as they arrive - hence the need to record state. Something like
newRollingSum = update rollingSum newValue
I haven't managed to decipher all of your code, but here is the plan I would take for solving this problem. First, an English description of the plan:
1. We need windows onto the list of length n starting at each index.
   1. Make arbitrarily long windows.
   2. Truncate them to length n.
   3. Drop the last n-1 of these, which will be too short.
2. For each window, add up the entries.
This was the first idea I had; for windows of length three it's an okay approach because step 2 is cheap on such a short list. For longer windows, you may want an alternate approach, which I will discuss below; but this approach has the benefit that it generalizes smoothly to functions other than sum. The code might look like this:
import Data.List

rollingSums n xs
    = map sum                                -- add up the entries
    . zipWith (flip const) (drop (n-1) xs)   -- drop the last n-1
    . map (take n)                           -- truncate long windows
    . tails                                  -- make arbitrarily long windows
    $ xs
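To illustrate the point about generalizing to functions other than sum, here is a minimal sketch that takes the per-window function as a parameter. The names rollingWith and rollingMaxima are made up for this example and are not library functions; the sketch assumes n is at least 1.

```haskell
import Data.List (tails)

-- Hypothetical generalization: apply any function f to each full window of length n.
rollingWith :: ([a] -> b) -> Int -> [a] -> [b]
rollingWith f n xs
    = map f                                  -- summarize each window
    . zipWith (flip const) (drop (n-1) xs)   -- drop the last n-1
    . map (take n)                           -- truncate long windows
    . tails                                  -- make arbitrarily long windows
    $ xs

-- The rolling sum is the special case f = sum.
rollingSums :: Num a => Int -> [a] -> [a]
rollingSums = rollingWith sum

-- A rolling maximum; every surviving window has exactly n elements,
-- so maximum never sees an empty list as long as n >= 1.
rollingMaxima :: Ord a => Int -> [a] -> [a]
rollingMaxima = rollingWith maximum
```

Only the function passed to map changes; the window-building part of the pipeline is shared.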
If you're familiar with the "equational reasoning" approach to optimization, you might spot a first place we can improve the performance of this function: by swapping the first map and zipWith, we can produce a function with the same behavior but with a map f . map g subterm, which can be replaced by map (f . g) to get slightly less allocation.
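Spelled out, that rewrite might look like the sketch below. The name rollingSumsFused is only for illustration; the behavior is intended to match rollingSums.

```haskell
import Data.List (tails)

-- rollingSums after moving map sum past zipWith and fusing the two maps.
rollingSumsFused :: Num a => Int -> [a] -> [a]
rollingSumsFused n xs
    = zipWith (flip const) (drop (n-1) xs)   -- drop the last n-1
    . map (sum . take n)                     -- truncate each window, then add it up
    . tails                                  -- make arbitrarily long windows
    $ xs
```

The swap is safe because zipWith (flip const) only truncates the list of windows; it never inspects their contents, so it does not matter whether the windows have been summed yet.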
Unfortunately, for large n, this adds n numbers together in the inner loop; we would prefer to simply add the value at the "front" of the window and subtract the one at the "back". So we need to get trickier. Here's a new idea: we'll traverse the list twice in parallel, n positions apart. Then we'll use a simple function for getting the rolling sum (of unbounded window length) of prefixes of a list, namely, scanl (+), to convert this traversal into the actual sums we're interested in.
rollingSumsEfficient n xs = scanl (+) firstSum deltas where
    firstSum = sum (take n xs)
    deltas   = zipWith (-) (drop n xs) xs -- front - back
There's one twist, which is that scanl never returns an empty list. So if it's important that you be able to handle short lists, you'll want another equation that checks for these. Don't use length, as that forces the entire input list into memory before starting the computation -- a potentially lethal performance mistake. Instead add a line like this above the previous definition:
rollingSumsEfficient n xs | null (drop (n-1) xs) = []
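Put together, the definition with the short-list check might look like the following sketch. It folds the extra equation into a guard on a single definition, which is equivalent to writing two separate equations; the type signature is an addition for illustration, and n is assumed to be at least 1.

```haskell
-- Rolling sums of window length n: start from the first full window, then
-- add the value entering at the front and subtract the one leaving at the back.
rollingSumsEfficient :: Num a => Int -> [a] -> [a]
rollingSumsEfficient n xs
    | null (drop (n-1) xs) = []              -- fewer than n elements: no full window
    | otherwise            = scanl (+) firstSum deltas
  where
    firstSum = sum (take n xs)
    deltas   = zipWith (-) (drop n xs) xs    -- front - back
```

The guard only forces the first n cells of the list, so the function still works on infinite input.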
We can try these two out in ghci. You'll notice that they do not quite have the same behavior as yours:
*Main> rollingSums 3 [10^n | n <- [0..5]]
[111,1110,11100,111000]
*Main> rollingSumsEfficient 3 [10^n | n <- [0..5]]
[111,1110,11100,111000]
On the other hand, the implementations are considerably more concise and are fully lazy in the sense that they work on infinite lists:
*Main> take 5 . rollingSums 10 $ [1..]
[55,65,75,85,95]
*Main> take 5 . rollingSumsEfficient 10 $ [1..]
[55,65,75,85,95]
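If you do need the explicit on-line form from your edit (newRollingSum = update rollingSum newValue), one possible shape is a pure update function over an explicit state. This is only a sketch under assumptions: RollingSum, emptyRollingSum, and update are hypothetical names, Data.Sequence from the containers package is swapped in for the list-based circular buffer, and n is assumed to be at least 1.

```haskell
import Data.Sequence (Seq, ViewL(..), viewl, (|>))
import qualified Data.Sequence as Seq

-- Hypothetical state: window length, values currently in the window, their sum.
data RollingSum = RollingSum Int (Seq Double) Double

emptyRollingSum :: Int -> RollingSum
emptyRollingSum n = RollingSum n Seq.empty 0

-- Feed one value; returns the new state, plus the sum once the window is full.
update :: RollingSum -> Double -> (RollingSum, Maybe Double)
update (RollingSum n window acc) x =
    case viewl window of
      -- window already full: drop the oldest value, then add the new one
      old :< rest | Seq.length window >= n -> emit (rest |> x) (acc - old + x)
      -- window still filling up: just add the new value
      _                                    -> emit (window |> x) (acc + x)
  where
    emit w s = (RollingSum n w s, if Seq.length w >= n then Just s else Nothing)
```

Unlike the code in the question, this sketch reports a sum as soon as the n-th value arrives, which matches the behavior of the list versions above.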