How does 'get' actually /get/ the initial state in Haskell?


I have a function:

test :: String -> State String String
test x = 
    get >>= \test ->
    let test' = x ++ test in
    put test' >>
    get >>= \test2 -> put (test2 ++ x) >>
    return "test"

I can pretty much understand what goes on throughout this function, and I'm starting to get the hang of monads. What I don't understand is how, when I run this:

runState (test "testy") "testtest"

the 'get' function in 'test' somehow gets the initial state "testtest". Can someone break this down and explain it to me?

I appreciate any responses!


Solution

  • I was originally going to post this as a comment, but decided to expound a bit more.

    Strictly speaking, get doesn't "take" an argument. I think a lot of what is going on is masked by what you aren't seeing--the instance definitions of the State monad.

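    get is actually a method of the MonadState class. The State monad is an instance of MonadState, providing the following definition of get (this is how older versions of mtl spell it; newer versions build State on top of StateT and expose a lowercase state function instead of the State constructor, but the behaviour is the same):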

    get = State $ \s -> (s,s)
    

    In other words, get is just a very basic State computation (remembering that a monad can be thought of as a "wrapper" for a computation): whatever state s is fed into it, it returns the pair (s, s), that is, s as the result and s, unchanged, as the new state.
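    To see that in isolation, here is a minimal sketch using a hand-rolled State newtype rather than the library's type. The names State and get' are stand-ins I've defined for illustration; they only mirror the definition quoted above.

    ```haskell
    -- Hand-rolled stand-in for the library type (an assumption for illustration).
    newtype State s a = State { runState :: s -> (a, s) }

    -- Named get' to make clear this is not the library's MonadState method.
    get' :: State s s
    get' = State $ \s -> (s, s)

    main :: IO ()
    main = print (runState get' "testtest")
    -- prints ("testtest","testtest")
    ```

    Note that get' on its own holds no state at all; the string only shows up once runState hands it over.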

    The next thing we need to look at is >>=, which State defines thusly:

    m >>= k  = State $ \s -> let
        (a, s') = runState m s
        in runState (k a) s'
    

    So, >>= is going to yield a new computation, which won't be computed until it gets an initial state (this is true of all State computations when they're in their "wrapped" form). When it finally runs, it first runs the computation on the left side of the >>= with that initial state, producing a result and an updated state. It then applies the function on the right side of the >>= to that result and runs the computation it gets back with the updated state.
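    Here is a small self-contained sketch of that threading, again with a hand-rolled State newtype (not the library's). The names get', put' and step are hypothetical helpers for this example; the Functor and Applicative instances are there only because GHC requires them before it accepts a Monad instance.

    ```haskell
    newtype State s a = State { runState :: s -> (a, s) }

    instance Functor (State s) where
      fmap f m = State $ \s ->
        let (a, s') = runState m s
        in (f a, s')

    instance Applicative (State s) where
      pure a = State $ \s -> (a, s)
      mf <*> ma = State $ \s ->
        let (f, s')  = runState mf s
            (a, s'') = runState ma s'
        in (f a, s'')

    instance Monad (State s) where
      m >>= k = State $ \s ->
        let (a, s') = runState m s   -- run the left side on the incoming state
        in runState (k a) s'         -- pass its result and new state to the right side

    get' :: State s s
    get' = State $ \s -> (s, s)

    put' :: s -> State s ()
    put' s = State $ \_ -> ((), s)

    -- Building 'step' runs nothing; it is still waiting for a state.
    step :: State String ()
    step = get' >>= \old -> put' (old ++ "!")

    main :: IO ()
    main = print (runState step "hi")
    -- prints ((),"hi!")
    ```

    The only place "hi" enters is the runState call in main; get' sees it because >>= passes the incoming state to its left-hand side first.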

    I've found it quite useful to "desugar" everything that's going on. Doing so takes a lot more typing, but should make the answer to your question (where get is getting its state from) very clear. Note that the following should be considered pseudocode...

    test x =
        State $ \s -> let
            (a,s') = runState (State (\s -> (s,s))) s  --substituting above defn. of 'get'
            in runState (rightSide a) s'
            where 
              rightSide test = 
                let test' = x ++ test in
                State $ \s2 -> let
                (a2, s2') = runState (State $ \_ -> ((), test')) s2  -- defn. of 'put'
                in runState (rightSide2 a2) s2'
              rightSide2 _ =
                -- etc...
    

    That should make it obvious that the end result of our function is a new State computation that will need an initial value (s) to make the rest of the stuff happen. You supplied s as "testtest" with your runState call. If you substitute "testtest" for s in the above pseudocode, you'll see that the first thing that happens is we run get with "testtest" as the 'initial state'. This yields ("testtest", "testtest") and so on.
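    If you carry the substitution all the way through and drop every State wrapper, the whole function collapses into an ordinary function from the initial state to a (result, final state) pair. The sketch below is my own unrolling of your test, not library code; testUnwrapped and the numbered state names are made up for illustration.

    ```haskell
    -- 'test' with the State wrappers removed: the first String is x,
    -- the second is the initial state that runState would supply.
    testUnwrapped :: String -> String -> (String, String)
    testUnwrapped x s0 =
      let (test, _s1)  = (s0, s0)          -- get: result and state are both s0
          test'        = x ++ test
          ((), s2)     = ((), test')       -- put test': old state is discarded
          (test2, _s3) = (s2, s2)          -- get again, now seeing test'
          ((), s4)     = ((), test2 ++ x)  -- put (test2 ++ x)
      in ("test", s4)                      -- return "test": state passes through

    main :: IO ()
    main = print (testUnwrapped "testy" "testtest")
    -- prints ("test","testytesttesttesty")
    ```

    Seen this way, the first get is simply the first thing that touches s0, which is the argument you passed to runState.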

    So that's where get gets your initial state "testtest". Hope this helps!