confusion over the passing of State monad in Haskell


In Haskell, the State monad is passed around to extract and store state. Both of the following examples pass the State monad using >>, and a close verification (by function inlining and reduction) confirms that the state is indeed passed to the next step.

Yet this does not seem very intuitive. Does this mean that when I want to pass the State monad I just need >> (or >>= with a lambda expression \s -> a where s is not free in a)? Can anyone provide an intuitive explanation for this fact without having to reduce the function by hand?

-- the first example
tick :: State Int Int 
tick = get >>= \n ->
   put (n+1) >>
   return n

-- the second example
type GameValue = Int 
type GameState = (Bool, Int)

playGame' :: String -> State GameState GameValue 
playGame' []      = get >>= \(on, score) -> return score
playGame' (x: xs) = get >>= \(on, score) ->
    case x of
        'a' | on -> put (on, score+1)
        'b' | on -> put (on, score-1)
        'c'      -> put (not on, score)
        _        -> put (on, score) 
    >> playGame' xs

Thanks a lot!


Solution

  • It really boils down to understanding that State s a is isomorphic to s -> (a, s). So any value "wrapped" in a monadic action is the result of applying a transformation to some state s (a stateful computation producing a).

    Passing a state between two stateful computations

    f :: a -> State s b
    g :: b -> State s c
    

    corresponds to composing them with >=>

    f >=> g
    

    or using >>=

    \a -> f a >>= g
    

    the result here is

    a -> State s c
    

    It is a stateful action that transforms some underlying state s in some way; it is allowed access to some a and it produces some c. So the entire transformation is allowed to depend on a, and the value c is allowed to depend on the state s. This is exactly what you want in order to express a stateful computation. The neat thing (and the sole purpose of expressing this machinery as a monad) is that you do not have to bother with passing the state around yourself. To understand how it is done, refer to the definition of >>= on Hackage (just ignore for a moment that it is defined for the transformer StateT rather than a plain State monad).

    m >>= k  = StateT $ \ s -> do
        ~(a, s') <- runStateT m s
        runStateT (k a) s'
    

    You can disregard the wrapping and unwrapping done by StateT and runStateT. Here m has the form s -> (a, s), k has the form a -> (s -> (b, s)), and you wish to produce a stateful transformation s -> (b, s). So the result is going to be a function of s. To produce b you can use k, but you need an a first. How do you produce an a? You apply m to the state s, which gives you a together with a modified state s'. You then pass that state into (k a) (which is of type s -> (b, s)). It is here that the state s has passed through m to become s', and is then passed to k a to become some final s''.
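    To make the threading concrete, here is a minimal hand-rolled sketch of a State type that uses only base. It is an illustration under the assumption that the lazy-pattern and transformer details can be ignored; it is not the transformers source. The name twoTicks is hypothetical and only there to show two actions being chained.

```haskell
-- A hand-rolled State for illustration only (not the library definition).
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f (State m) = State $ \s ->
    let (a, s') = m s
    in  (f a, s')

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  State mf <*> State ma = State $ \s ->
    let (f, s')  = mf s
        (a, s'') = ma s'
    in  (f a, s'')

instance Monad (State s) where
  m >>= k = State $ \s ->
    let (a, s') = runState m s   -- run m on the incoming state s
    in  runState (k a) s'        -- hand the updated state s' to (k a)

get :: State s s
get = State $ \s -> (s, s)       -- expose the state, leave it unchanged

put :: s -> State s ()
put s = State $ \_ -> ((), s)    -- replace the state, return ()

tick :: State Int Int
tick = get >>= \n -> put (n + 1) >> return n

-- Hypothetical helper: chain tick twice and collect both results.
twoTicks :: State Int (Int, Int)
twoTicks = tick >>= \a -> tick >>= \b -> return (a, b)

-- runState twoTicks 5 == ((5, 6), 7)
```

    The second tick sees 6 rather than 5, even though no state argument appears anywhere in twoTicks; the only place the state is handed over is inside >>=.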

    For you as a user of this mechanism, this remains hidden, and that is the neat thing about monads. If you want a state to evolve along some computation, you build your computation from small steps that you express as State actions, and you let do-notation or bind (>>=) do the chaining/passing.

    The sole difference between >>= and >> is that you either care or don't care about the non-state result.

    a >> b
    

    is in fact equivalent to

    a >>= \_ -> b
    

    So whatever value is output by the action a, you throw it away (keeping only the modified state) and continue (passing the state along) with the other action b.
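    A small check of this equivalence, assuming the State type from the mtl package's Control.Monad.State (the names viaThen and viaBind are hypothetical):

```haskell
import Control.Monad.State (State, get, put, runState)

-- Discard the () from put with >>, then read the state back.
viaThen :: State Int Int
viaThen = put 10 >> get

-- The same thing spelled with >>= and an ignored argument.
viaBind :: State Int Int
viaBind = put 10 >>= \_ -> get

-- runState viaThen 0 == (10, 10)
-- runState viaBind 0 == (10, 10)
```

    In both versions get observes the 10 written by put: the result of put is dropped, the state it produced is not.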


    Regarding your examples

    tick :: State Int Int 
    tick = get >>= \n ->
        put (n+1) >>
        return n
    

    you can rewrite it in do-notation as

    tick = do
        n <- get
        put (n + 1)
        return n
    

    While the first way of writing it perhaps makes it more explicit what is passed and how, the second way nicely shows that you do not have to care about it.

    1. First, get the current state and expose it (get :: s -> (s, s) in a simplified setting). The <- says that you do care about the value and do not want to throw it away. The underlying state is also passed along in the background, unchanged (that is how get works).

    2. Then put :: s -> (s -> ((), s)), which is equivalent after dropping unnecessary parens to put :: s -> s -> ((), s), takes a value to replace the current state with (the first argument), and produces a stateful action whose result is the uninteresting value () which you drop (because you do not use <- or because you use >> instead of >>=). Due to put the underlying state has changed to n + 1 and as such it is passed on.

    3. return does nothing to the underlying state, it only returns its argument.

    To summarise, tick starts with some initial state s, updates it to s+1 internally, and outputs s on the side.
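    The three steps above can also be written out with the state passed by hand, using the simplified types s -> (a, s) directly. This is only a sketch of what the monad does for you; the names getS, putS, returnS and tickS are hypothetical and chosen to avoid clashing with the real get, put and return.

```haskell
-- Simplified setting: a stateful action is just a function s -> (a, s).
getS :: s -> (s, s)
getS s = (s, s)

putS :: s -> s -> ((), s)
putS new _old = ((), new)

returnS :: a -> s -> (a, s)
returnS a s = (a, s)

-- tick with every intermediate state named explicitly.
tickS :: Int -> (Int, Int)
tickS s0 =
  let (n,  s1) = getS s0           -- step 1: expose the state, s1 == s0
      ((), s2) = putS (n + 1) s1   -- step 2: replace the state, drop ()
  in  returnS n s2                 -- step 3: return n, state untouched

-- tickS 5 == (5, 6)
```

    Every s0, s1, s2 here is exactly the plumbing that >>= and >> hide in the monadic version.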

    The other example works exactly the same way: >> is only used there to throw away the () produced by put. The state gets passed along the whole time.
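    To see this for the second example, it can be run with evalState. The sketch below assumes the mtl package's Control.Monad.State and spells the recursive call as playGame' so that it refers to the function being defined; the input string is an arbitrary sample.

```haskell
import Control.Monad.State (State, evalState, get, put)

type GameValue = Int
type GameState = (Bool, Int)

playGame' :: String -> State GameState GameValue
playGame' []       = get >>= \(_, score) -> return score
playGame' (x : xs) = get >>= \(on, score) ->
    case x of
        'a' | on -> put (on, score + 1)   -- count up while the game is on
        'b' | on -> put (on, score - 1)   -- count down while the game is on
        'c'      -> put (not on, score)   -- toggle the game on/off
        _        -> put (on, score)       -- anything else: no change
    >> playGame' xs                       -- drop (), keep the new state

-- evalState (playGame' "abcaaacbbcabbab") (False, 0) == 2
```

    Each recursive call starts with get, and what it reads is whatever the previous put stored, so the score accumulates across the whole string.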