
Using a monad for a cumulative sum on a list of pairs (Haskell)


I have a list of pairs: [("oct",1),("nov",1),("dec",1)]

I want to compute a cumulative sum over the numbers in the pairs: [("oct",1),("nov",2),("dec",3)]. I think this is a good case for a monadic implementation, but I can't figure out how to preserve the containers.

I tried writing a function on lists (I already know about scanl1; I'm just showing effort here :)

csm :: [Int] -> [Int]
csm [] = []
csm [x] = [x]
csm (x:y:xs) = x : csm ((x + y) : xs)

And then I tried something like:

sumOnPairs l = do
  pair <- l
  return (csm (snd pair))

My solution doesn't work; please point me in the right direction.
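For comparison, the scanl1 mentioned above can also be applied to the pairs directly, without a monad. This is a minimal sketch; the name sumOnPairsScan is hypothetical. The step function receives the previous *result* pair and the current input pair, so the month comes from the current pair and the number from the running total.

```haskell
-- Hypothetical non-monadic version using scanl1 directly on the pairs.
-- scanl1 keeps the first pair unchanged, then combines each later pair
-- with the previous result, so acc already holds the sum so far.
sumOnPairsScan :: [(String, Int)] -> [(String, Int)]
sumOnPairsScan = scanl1 step
  where
    -- Take the month from the current pair, add its number to the total.
    step (_, acc) (month, n) = (month, acc + n)
```

scanl1 returns [] for an empty list, so no special case is needed.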


Solution

  • The list monad models nondeterminism: it does the same thing to each element of the list independently, then collects the results in a new list.
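    Because the list monad's do block hands its body one pair at a time, csm in the question only ever sees a single Int from snd pair, never a list, which is why it does not typecheck. A minimal non-monadic workaround, assuming the question's csm and a hypothetical name sumOnPairsZip, is to split the months from the numbers, scan the numbers, and zip the months back on:

```haskell
-- The question's csm, with an empty-list case added so it is total.
csm :: [Int] -> [Int]
csm [] = []
csm [x] = [x]
csm (x:y:xs) = x : csm ((x + y) : xs)

-- Hypothetical helper: separate the months from the numbers, run the
-- cumulative sum over the numbers only, then pair the months back up.
sumOnPairsZip :: [(String, Int)] -> [(String, Int)]
sumOnPairsZip l = zip months (csm counts)
  where
    (months, counts) = unzip l
```

    This "preserves the containers" by rebuilding the pairs afterwards rather than threading state through them.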

    For the kind of sequential traversal you want (do something to an element, then use the result to do something to the next element, and so on), you can use the State monad to do something like

    import Control.Monad.Trans.State
    import Data.Bifunctor
    
    type Pair = (String, Int)
    
    foo :: Pair -> State Pair Pair
    foo (month, y) = do
       -- bimap f g (x,y) == (f x, g y)
       -- The new month replaces the old month,
       -- and y is added to the sum.
       modify (bimap (const month) (+y))
       -- Return a snapshot of the state
       get
       
    
    sumOnPairs :: [Pair] -> [Pair]
    sumOnPairs = flip evalState ("", 0) . traverse foo
    

    At each step, the new state is the current month paired with the sum of the old state's number and the current number. traverse collects a snapshot of each of those states into a list while walking the original list.

    > sumOnPairs [("oct",1),("nov",1),("dec",1)]
    [("oct",1),("nov",2),("dec",3)]
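    To see the state being threaded through, one can swap evalState for runState, which also returns the final state. This sketch repeats foo from above so it is self-contained; sumOnPairsWithFinal is a hypothetical name, and it assumes the transformers package (which ships with GHC) is available.

```haskell
import Control.Monad.Trans.State (State, get, modify, runState)
import Data.Bifunctor (bimap)

type Pair = (String, Int)

-- Same step as foo above: overwrite the month, add the number,
-- and return a snapshot of the new state.
foo :: Pair -> State Pair Pair
foo (month, y) = do
  modify (bimap (const month) (+ y))
  get

-- runState returns (collected snapshots, final state); the final state
-- equals the last snapshot that traverse collected.
sumOnPairsWithFinal :: [Pair] -> ([Pair], Pair)
sumOnPairsWithFinal xs = runState (traverse foo xs) ("", 0)
```

    In GHCi, sumOnPairsWithFinal [("oct",1),("nov",1),("dec",1)] gives ([("oct",1),("nov",2),("dec",3)],("dec",3)).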
    

    You can also keep only the running sum in the state, instead of both the sum and a month that just gets overwritten.

    foo' :: Pair -> State Int Pair
    foo' x@(_, count) = do
       modify (+ count)
       fmap (<$ x) get
    
    sumOnPairs' :: [Pair] -> [Pair]
    sumOnPairs' = flip evalState 0 . traverse foo'
    

    In this case, only the current sum is kept in the state; the new pair is built with the <$ operator, which uses the Functor instance of (,) String to replace the number in the current pair with the sum from the state:

    > 6 <$ ("foo", 3)
    ("foo",6)
    

    If you choose this route, I think ($>) from Data.Functor (the flipped version of <$) may be more readable:

    import Data.Functor (($>))

    foo' x@(_, count) = do
       modify (+ count)
       fmap (x $>) get
    

    Visually, it is closer to what you would write if you didn't need to fmap over get: x $> y == (fst x, y).
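    For comparison, the same accumulate-while-traversing pattern exists purely as mapAccumL from Data.Traversable, which threads an accumulator much like traverse does in State Int. This is a sketch of that alternative, not part of the State-based approach above; sumOnPairsAccum is a hypothetical name.

```haskell
import Data.Traversable (mapAccumL)

type Pair = (String, Int)

-- mapAccumL's step returns (new accumulator, output element), so it plays
-- the role of modify (+ count) followed by building the pair from the sum.
sumOnPairsAccum :: [Pair] -> [Pair]
sumOnPairsAccum = snd . mapAccumL step 0
  where
    step acc (month, n) = let acc' = acc + n in (acc', (month, acc'))
```

    snd discards the final accumulator (the grand total); keep the whole result if you need it, as with runState.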