Tags: haskell, state-monad

How can this haskell rolling sum implementation be improved?


How can I improve the following rolling sum implementation?

import Control.Monad.State  -- from the mtl package

type Buffer  = State BufferState (Maybe Double)
type BufferState = ( [Double] , Int, Int )

-- circular buffer    
buff :: Double -> Buffer 
buff newVal = do
  ( list, ptr, len) <- get
  -- if the list is not full yet just accumulate the new value
  if length list < len
    then do
      put ( newVal : list , ptr, len)
      return Nothing
    else do
      let nptr = (ptr - 1) `mod` len
          (as,(v:bs)) = splitAt ptr list
          nlist = as ++ (newVal : bs)
      put (nlist, nptr, len)
      return $ Just v

-- create initial state for a circular buffer of length l
initBuff :: Int -> BufferState
initBuff l = ( [] , l-1 , l)

-- use the circular buffer to calculate a rolling sum
rollSum :: Double -> State (Double,BufferState) (Maybe Double)
rollSum newVal = do
  (acc,bState) <- get
  let (lv , bState' )  = runState (buff newVal) bState
      acc' = acc + newVal
  -- subtract the old value if the circular buffer is full
  case lv of
    Just x  -> put ( acc' - x , bState') >> (return $ Just (acc' - x))
    Nothing -> put ( acc' , bState')     >> return Nothing

test :: (Double,BufferState) -> [Double] -> [Maybe Double] -> [Maybe Double]
test state [] acc = acc
test state (x:xs) acc = 
  let (a,s) = runState (rollSum x) state
  in test s xs (a:acc)

main :: IO()
main = print $ test (0,initBuff 3) [1,1,1,2,2,0] []

buff uses the State monad to implement a circular buffer. rollSum uses the State monad again to keep track of both the rolling sum and the state of the circular buffer.

  • How could I make this more elegant?
  • I'd like to implement other functions like rolling average or a difference, what could I do to make this easy?

Thanks!

EDIT

I forgot to mention that I am using a circular buffer because I intend to use this code online and process updates as they arrive, hence the need to record state. Something like

newRollingSum = update rollingSum newValue 
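For that online use case, one way to get an update-style interface without nesting one State computation inside another is a plain record holding the window together with its running sum. This is a minimal sketch, assuming Data.Sequence from the containers package as the buffer; Rolling, newRolling, update, currentSum and currentAverage are hypothetical names, not from any library or from the question. Unlike buff, it reports the first complete window as well.

```haskell
import Data.Sequence (Seq, ViewL (..), viewl, (|>))
import qualified Data.Sequence as Seq

-- Hypothetical rolling-window state: the window contents (oldest value
-- at the left), the window length, and the running sum of the contents.
data Rolling = Rolling
  { window   :: Seq Double
  , capacity :: Int
  , total    :: Double
  }

-- Empty state for a window of length n (n is assumed to be >= 1).
newRolling :: Int -> Rolling
newRolling n = Rolling Seq.empty n 0

-- Push one new value. While the window is filling up we only add; once it
-- is full, the oldest value is taken off the left and subtracted, so each
-- update does a constant (amortised) amount of work.
update :: Rolling -> Double -> Rolling
update (Rolling w n s) x
  | Seq.length w < n = Rolling (w |> x) n (s + x)
  | otherwise = case viewl w of
      old :< rest -> Rolling (rest |> x) n (s + x - old)
      EmptyL      -> Rolling w n s -- only reachable if n <= 0

-- The sum of the current window, or Nothing until the window first fills.
currentSum :: Rolling -> Maybe Double
currentSum r
  | Seq.length (window r) == capacity r = Just (total r)
  | otherwise                           = Nothing

-- A rolling average falls out of the same state.
currentAverage :: Rolling -> Maybe Double
currentAverage r = fmap (/ fromIntegral (capacity r)) (currentSum r)

main :: IO ()
main = print (map currentSum (drop 1 (scanl update (newRolling 3) [1, 1, 1, 2, 2, 0])))
```

Each incoming value is handled as newState = update oldState x, which matches the shape in the EDIT; scanl is only used here to replay a fixed list of inputs.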

Solution

  • I haven't managed to decipher all of your code, but here is the plan I would take for solving this problem. First, an English description of the plan:

    1. We need windows into the list of length n starting at each index.
      1. Make windows of arbitrary length.
      2. Truncate long windows to length n.
      3. Drop the last n-1 of these, which will be too short.
    2. For each window, add up the entries.

    This was the first idea I had; for windows of length three it's an okay approach because step 2 is cheap on such a short list. For longer windows, you may want an alternate approach, which I will discuss below; but this approach has the benefit that it generalizes smoothly to functions other than sum. The code might look like this:

    import Data.List (tails)
    
    rollingSums :: Num a => Int -> [a] -> [a]
    rollingSums n xs
        = map sum                              -- add up the entries
        . zipWith (flip const) (drop (n-1) xs) -- drop the last n-1
        . map (take n)                         -- truncate long windows
        . tails                                -- make arbitrarily long windows
        $ xs
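Because only the final map depends on sum, the rolling average or difference from the question can reuse the same pipeline. A minimal sketch, assuming a hypothetical helper rollingWith that takes the per-window function as a parameter; rollingAverages and rollingDiffs are illustrative names, and reading "difference" as newest minus oldest entry of each window is an assumption. The windows are non-empty only when n >= 1, which the code assumes.

```haskell
import Data.List (tails)

-- Hypothetical generalisation: apply f to every complete window of
-- length n, in input order.
rollingWith :: Int -> ([a] -> b) -> [a] -> [b]
rollingWith n f xs
  = map f                                  -- summarise each window
  . zipWith (flip const) (drop (n - 1) xs) -- drop the last n-1 (too short)
  . map (take n)                           -- truncate long windows
  . tails                                  -- make arbitrarily long windows
  $ xs

rollingSums :: Num a => Int -> [a] -> [a]
rollingSums n = rollingWith n sum

rollingAverages :: Fractional a => Int -> [a] -> [a]
rollingAverages n = rollingWith n (\w -> sum w / fromIntegral n)

-- Newest minus oldest entry of each window (assumes n >= 1, so no window
-- is empty and head/last are safe).
rollingDiffs :: Num a => Int -> [a] -> [a]
rollingDiffs n = rollingWith n (\w -> last w - head w)

main :: IO ()
main = do
  print (rollingAverages 3 [1, 1, 1, 2, 2, 0 :: Double])
  print (rollingDiffs 3 [1, 1, 1, 2, 2, 0 :: Double])
```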
    

    If you're familiar with the "equational reasoning" approach to optimization, you might spot a first place to improve the performance of this function: swapping the first map and the zipWith (legitimate because zipWith (flip const) ys only truncates its second argument to the length of ys, and truncation commutes with map) produces a function with the same behavior but with a map f . map g subterm, which can be replaced by map (f . g) to get slightly less allocation.
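Carrying out that rewrite by hand gives roughly the following. It is a sketch of the transformation described (swap, then fuse map sum . map (take n) into map (sum . take n)); rollingSumsFused is an illustrative name.

```haskell
import Data.List (tails)

-- Same results as rollingSums, but summing now happens before truncating
-- the list of windows, and the two maps are fused into one, so there is
-- one intermediate list fewer.
rollingSumsFused :: Num a => Int -> [a] -> [a]
rollingSumsFused n xs
  = zipWith (flip const) (drop (n - 1) xs) -- drop the last n-1 sums
  . map (sum . take n)                     -- truncate and add up in one pass
  . tails                                  -- make arbitrarily long windows
  $ xs

main :: IO ()
main = print (rollingSumsFused 3 [10 ^ k | k <- [0 .. 5 :: Int]] :: [Integer])
```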

    Unfortunately, for large n, this adds n numbers together in the inner loop; we would prefer to simply add the value at the "front" of the window and subtract the one at the "back". So we need to get trickier. Here's a new idea: we'll traverse the list twice in parallel, n positions apart. Then we'll use a simple function for getting the rolling sum (of unbounded window length) of prefixes of a list, namely, scanl (+), to convert this traversal into the actual sums we're interested in.

    rollingSumsEfficient n xs = scanl (+) firstSum deltas where
        firstSum = sum (take n xs)
        deltas   = zipWith (-) (drop n xs) xs -- front - back
    

    There's one twist, which is that scanl never returns an empty list. So if it's important that you be able to handle short lists, you'll want another equation that checks for these. Don't use length: it forces the entire input list into memory before the computation starts (and never returns at all on an infinite list), a potentially lethal performance mistake. Instead, add a line like this above the previous definition:

    rollingSumsEfficient n xs | null (drop (n-1) xs) = []
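Putting the guard and the main equation together gives a self-contained version. The type signatures, the Num constraint and the rollingAveragesEfficient helper (an illustrative name, for the rolling average asked about) are additions under my own assumptions about the intended generality.

```haskell
-- After the first window, each output costs one addition and one
-- subtraction: add the entry entering the window, subtract the one leaving.
rollingSumsEfficient :: Num a => Int -> [a] -> [a]
rollingSumsEfficient n xs
  | null (drop (n - 1) xs) = []           -- fewer than n elements: no full window
  | otherwise              = scanl (+) firstSum deltas
  where
    firstSum = sum (take n xs)
    deltas   = zipWith (-) (drop n xs) xs -- front - back

-- A rolling average built on the same sums.
rollingAveragesEfficient :: Fractional a => Int -> [a] -> [a]
rollingAveragesEfficient n = map (/ fromIntegral n) . rollingSumsEfficient n

main :: IO ()
main = do
  print (rollingSumsEfficient 3 [1, 1, 1, 2, 2, 0 :: Double])
  print (rollingSumsEfficient 3 [1, 2 :: Int])
```

For [1,1,1,2,2,0] with n = 3: firstSum = 3, deltas = zipWith (-) [2,2,0] [1,1,1,2,2,0] = [1,1,-1], and scanl (+) 3 [1,1,-1] = [3,4,5,4]. With Double inputs, the repeated add/subtract can accumulate rounding error over a long stream, which re-summing each window does not.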
    

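    We can try these two out in ghci. You'll notice that they do not quite have the same behavior as yours: they return only the sums of complete windows, with no Nothing padding, in input order (yours builds its result list in reverse), and they include the first complete window, which yours reports as Nothing: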

    *Main> rollingSums 3 [10^n | n <- [0..5]]
    [111,1110,11100,111000]
    *Main> rollingSumsEfficient 3 [10^n | n <- [0..5]]
    [111,1110,11100,111000]
    

    On the other hand, the implementations are considerably more concise and are fully lazy in the sense that they work on infinite lists:

    *Main> take 5 . rollingSums 10 $ [1..]
    [55,65,75,85,95]
    *Main> take 5 . rollingSumsEfficient 10 $ [1..]
    [55,65,75,85,95]
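That laziness also covers the online requirement from the question's EDIT for a simple stream of inputs. A minimal sketch, assuming one well-formed number per line on stdin (read fails on anything else); process is a hypothetical name, split out so the logic can be tested without stdin. With interact's lazy I/O the sums are produced incrementally as lines arrive, though when they actually appear also depends on stdin/stdout buffering.

```haskell
-- Efficient rolling sum, as defined in the answer.
rollingSumsEfficient :: Num a => Int -> [a] -> [a]
rollingSumsEfficient n xs
  | null (drop (n - 1) xs) = []
  | otherwise = scanl (+) (sum (take n xs)) (zipWith (-) (drop n xs) xs)

-- Hypothetical pure core: parse one Double per line, emit one sum per line.
process :: Int -> String -> String
process n = unlines . map show . rollingSumsEfficient n . map (read :: String -> Double) . lines

-- Lazily reads stdin and writes each rolling sum (window of 3) to stdout.
main :: IO ()
main = interact (process 3)
```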