Haskell State Monad Example


I'm experimenting with Haskell's Control.Monad.State by iterating through a list of Either String Int values, counting the strings and the integers, and replacing each string entry with the integer 0. I have managed the counting part, but failed to build the replaced list. Here is my code, which correctly prints [3,6] to the screen. How can I make it produce the desired list [6,0,3,8,0,2,9,1,0]?

module Main( main ) where

import Control.Monad.State

l = [
    Right 6,
    Left "AAA",
    Right 3,
    Right 8,
    Left "CCC",
    Right 2,
    Right 9,
    Right 1,
    Left "D"]

scanList :: [ Either String Int ] -> State (Int,Int) [ Int ]
scanList [    ] = do
    (ns,ni) <- get
    return (ns:[ni])
scanList (x:xs) = do
    (ns,ni) <- get
    case x of
        Left  _ -> put (ns+1,ni)
        Right _ -> put (ns,ni+1)
    case x of
        Left  _ -> scanList xs -- [0] ++ scanList xs not working ...
        Right i -> scanList xs -- [i] ++ scanList xs not working ...

startState = (0,0)

main = do
    print $ evalState (scanList l) startState

Solution

  • [0] ++ scanList xs doesn't work because scanList xs is not a list but a State (Int,Int) [Int] action that produces one. To fix that, apply the list operation to the action's result with fmap/<$>.

    You will also need to change the base case so that it returns the empty list rather than the state value; otherwise the two counts would end up appended to the result list.

    scanList :: [Either String Int] -> State (Int, Int) [Int]
    scanList []     = return []
    scanList (x:xs) = do
        (ns,ni) <- get
        case x of
            Left  _ -> put (ns+1, ni)
            Right _ -> put (ns, ni+1)
        case x of
            Left  _ -> (0 :) <$> scanList xs
            Right i -> (i :) <$> scanList xs
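
Since the base case no longer returns the counters, they have to be read from the final state instead. The following is a minimal sketch of one way to do that, assuming the recursive definition above; the names input and result are illustrative and not part of the original code.

```haskell
import Control.Monad.State

-- Sample input copied from the question.
input :: [Either String Int]
input =
    [ Right 6, Left "AAA", Right 3, Right 8, Left "CCC"
    , Right 2, Right 9, Right 1, Left "D" ]

-- Recursive version from the answer: the result is the replaced list,
-- while the (strings, ints) counters live only in the state.
scanList :: [Either String Int] -> State (Int, Int) [Int]
scanList []     = return []
scanList (x:xs) = do
    (ns, ni) <- get
    case x of
        Left  _ -> put (ns+1, ni)
        Right _ -> put (ns, ni+1)
    case x of
        Left  _ -> (0 :) <$> scanList xs
        Right i -> (i :) <$> scanList xs

-- runState returns the result paired with the final state, so the
-- replaced list and the counters are both available (illustrative name).
result :: ([Int], (Int, Int))
result = runState (scanList input) (0, 0)
```

Loaded in GHCi, result should evaluate to ([6,0,3,8,0,2,9,1,0],(3,6)). If only one half is needed, evalState keeps the list and execState keeps the counters.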
    

    To simplify the code further, use mapM/traverse together with state, which removes most of the boilerplate of the explicit recursion and the get/put calls.

    scanList :: [Either String Int] -> State (Int, Int) [Int]
    scanList = mapM $ \x -> state $ \(ns, ni) -> case x of
        Left  _ -> (0, (ns+1, ni))
        Right i -> (i, (ns, ni+1))
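
The same idea can be split into a per-element step and a traversal, which may make it easier to see what state contributes. This is only a sketch: step, sample, replaced and counts are hypothetical names introduced here for illustration, and the shortened sample list is an assumption rather than the question's data.

```haskell
import Control.Monad.State

-- Hypothetical helper (not in the original answer): handles one element,
-- returning its replacement together with the updated (strings, ints) counters.
step :: Either String Int -> State (Int, Int) Int
step x = state $ \(ns, ni) -> case x of
    Left  _ -> (0, (ns + 1, ni))
    Right i -> (i, (ns, ni + 1))

-- traverse runs step on every element in order, threading the state
-- through and collecting the replacements into a list.
scanList :: [Either String Int] -> State (Int, Int) [Int]
scanList = traverse step

-- A shortened sample, assumed for illustration only.
sample :: [Either String Int]
sample = [Right 6, Left "AAA", Right 3]

-- evalState keeps only the result; execState keeps only the final state.
replaced :: [Int]
replaced = evalState (scanList sample) (0, 0)

counts :: (Int, Int)
counts = execState (scanList sample) (0, 0)
```

For this sample, replaced should be [6,0,3] and counts should be (1,2). For lists, mapM and traverse behave the same here; traverse only needs the weaker Applicative constraint.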