
Understanding StateMonad


I'm a Haskell beginner trying to understand this definition of a StateMonad, specifically the bind operation. It's taken from Generalising Monads to Arrows, page 4.

instance Monad (StateMonad s) where
     return a = SM (\s -> (a, s))
     x >>= f = SM (\s -> let
                             SM x' = x
                             (a, s') = x' s
                             SM f' = f a
                             (b, s'') = f' s'
                         in (b, s''))
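A minimal sketch of the same definition made compilable with a modern GHC, for experimenting in GHCi. The newtype declaration is an assumption (it is the shape implied by the SM pattern matches), and the Functor and Applicative instances are additions: since GHC 7.10 every Monad must also be an Applicative, a requirement the paper predates. The bind below is the question's definition, reindented and commented.

```haskell
import Control.Monad (ap, liftM)

-- Assumed declaration, matching the SM constructor used by the instance.
newtype StateMonad s a = SM (s -> (a, s))

-- Superclass instances required by modern GHC, derived from the Monad instance.
instance Functor (StateMonad s) where
  fmap = liftM

instance Applicative (StateMonad s) where
  pure a = SM (\s -> (a, s))  -- the question's return
  (<*>) = ap

instance Monad (StateMonad s) where
  -- return is not defined here: it defaults to pure.
  x >>= f = SM (\s ->
    let SM x'    = x     -- unwrap the first state transformer
        (a, s')  = x' s  -- run it on the initial state
        SM f'    = f a   -- pick the second transformer using a
        (b, s'') = f' s' -- run it on the intermediate state
    in (b, s''))
```

For example, `let SM h = SM (\s -> (s, s + 1)) >>= \a -> SM (\s -> (a * 10, s + 1)) in h 5` evaluates to (50, 7).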

Solution

    First, you need to understand the type of >>=. I'll assume you do, since it's on page 2 of that paper and you got past that.
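To make the type concrete: (>>=) has the general type Monad m => m a -> (a -> m b) -> m b, and with m = StateMonad s it specialises as in the sketch below. bindSM is a hypothetical name, used so the sketch needs no Monad instance; its body is the question's definition.

```haskell
newtype StateMonad s a = SM (s -> (a, s))

-- (>>=) specialised to m = StateMonad s. The state type s is fixed,
-- but the result type can change from a to b.
bindSM :: StateMonad s a -> (a -> StateMonad s b) -> StateMonad s b
bindSM x f = SM (\s ->
  let SM x'    = x
      (a, s')  = x' s
      SM f'    = f a
      (b, s'') = f' s'
  in (b, s''))
```

For example, with x = SM (\s -> (length s, s ++ "!")) and f n = SM (\s -> (replicate n 'x', s ++ "?")), the function inside bindSM x f maps "ab" to ("xx", "ab!?"): here a is an Int while b is a String, and the state is a String throughout.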

    The definition of bind is easier to understand if we first define runState:

    newtype StateMonad s a = SM (s -> (a, s))
    
    runState :: StateMonad s a -> s -> (a, s)
    runState (SM f) s = f s
    -- this is equivalent to
    -- runState (SM f) = f
    

    runState runs a state monad by extracting the function f that transforms the state and applying it to the initial state s. The function f returns a tuple of type (a, s) containing a value of type a, which depends on the state, and a new state of type s. The following two snippets are equivalent:
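A small sketch of runState in action, assuming a hypothetical tick action whose value is the current state and which increments the state:

```haskell
newtype StateMonad s a = SM (s -> (a, s))

runState :: StateMonad s a -> s -> (a, s)
runState (SM f) s = f s

-- Hypothetical example action: the state-dependent value is the
-- current counter, and the new state is the counter plus one.
tick :: StateMonad Int Int
tick = SM (\s -> (s, s + 1))
```

Here runState tick 5 evaluates to (5, 6): the value 5 was read from the initial state, and 6 is the new state.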

    let (a, s') = runState x s
    in ...
    
    let SM x' = x
        (a, s') = x' s
    in ...
    

    Both extract from x the function that transforms the state (called x' in the second snippet), then apply it to the initial state s. The resulting tuple (a, s') holds the state-dependent value a and the new state s'.
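To check the equivalence concretely, here is a sketch that wraps each style in a hypothetical helper, viaRunState and viaPattern; for any x and s they return the same pair.

```haskell
newtype StateMonad s a = SM (s -> (a, s))

runState :: StateMonad s a -> s -> (a, s)
runState (SM f) s = f s

-- Style 1: runState does the unwrapping.
viaRunState :: StateMonad s a -> s -> (a, s)
viaRunState x s =
  let (a, s') = runState x s
  in (a, s')

-- Style 2: pattern-match on SM to name the function x', then apply it.
viaPattern :: StateMonad s a -> s -> (a, s)
viaPattern x s =
  let SM x'   = x
      (a, s') = x' s
  in (a, s')
```

With x = SM (\s -> (s * 2, s + 1)) and s = 3, both return (6, 4).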

    We can replace the SM pattern matching in the definition of >>= with calls to runState:

     x >>= f = SM (\s -> let
                             (a, s')  = runState x     s
                             (b, s'') = runState (f a) s'
                         in  (b, s''))
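A runnable sketch of the runState-based bind as a full Monad instance. It assumes the same Functor/Applicative boilerplate that modern GHC requires, plus a hypothetical tick action so that do-notation can be tried:

```haskell
import Control.Monad (ap, liftM)

newtype StateMonad s a = SM (s -> (a, s))

runState :: StateMonad s a -> s -> (a, s)
runState (SM f) = f

instance Functor (StateMonad s) where
  fmap = liftM

instance Applicative (StateMonad s) where
  pure a = SM (\s -> (a, s))
  (<*>) = ap

-- The runState-based bind from the text.
instance Monad (StateMonad s) where
  x >>= f = SM (\s ->
    let (a, s')  = runState x s      -- run x on the initial state s
        (b, s'') = runState (f a) s' -- run f a on the intermediate state s'
    in (b, s''))

-- Hypothetical example action: returns the counter and increments it.
tick :: StateMonad Int Int
tick = SM (\s -> (s, s + 1))

-- Desugars to: tick >>= \a -> tick >>= \b -> pure (a, b)
twoTicks :: StateMonad Int (Int, Int)
twoTicks = do
  a <- tick
  b <- tick
  pure (a, b)
```

runState twoTicks 0 evaluates to ((0, 1), 2): the first tick sees state 0, the second sees the intermediate state 1, and the final state is 2.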
    

    Now we'll go through it piece by piece.

    Bind constructs a new StateMonad whose function takes some initial state s and returns a state-dependent value b and a final state s'':

     x >>= f = SM (\s -> let
                             ...
                         in  (b, s''))
    

    The state-dependent value a and a new state s' are computed by running the state monad x with the initial state s:

                         let
                             (a, s')  = runState x     s
    

    A new state monad f a is determined from the user-supplied function f and the state-dependent value a. This second state monad is run with the intermediate state s'. It computes another state-dependent value b and a final state s''.

                             (b, s'') = runState (f a) s'   
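Because f receives a, the second state monad can differ depending on what the first one produced. A sketch with hypothetical helpers peek (reads the state without changing it) and decide (playing the user-supplied f), using a plain bindSM function in place of the instance:

```haskell
newtype StateMonad s a = SM (s -> (a, s))

runState :: StateMonad s a -> s -> (a, s)
runState (SM f) = f

-- The runState-based bind as a plain function (hypothetical name).
bindSM :: StateMonad s a -> (a -> StateMonad s b) -> StateMonad s b
bindSM x f = SM (\s ->
  let (a, s')  = runState x s
      (b, s'') = runState (f a) s'
  in (b, s''))

-- Hypothetical helper: the value is the state, and the state is unchanged.
peek :: StateMonad Int Int
peek = SM (\s -> (s, s))

-- Hypothetical f: which state monad it returns depends on a.
decide :: Int -> StateMonad Int String
decide a
  | a > 0     = SM (\s -> ("positive, decrementing", s - 1))
  | otherwise = SM (\_ -> ("not positive, resetting", 10))
```

runState (bindSM peek decide) 3 gives ("positive, decrementing", 2), while starting from 0 gives ("not positive, resetting", 10).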
    

    The second state-dependent value b and the final state s'' are what's returned by the function constructed for the new StateMonad.
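To see every intermediate value at once, here is a hypothetical instrumented variant, traceBind, that runs the same two steps as bind but returns a, s', b and s'' instead of only (b, s''):

```haskell
newtype StateMonad s a = SM (s -> (a, s))

runState :: StateMonad s a -> s -> (a, s)
runState (SM f) = f

-- Same two steps as bind, but exposes the intermediate a and s' too.
traceBind :: StateMonad s a -> (a -> StateMonad s b) -> s -> (a, s, b, s)
traceBind x f s =
  let (a, s')  = runState x s      -- first run, from the initial state
      (b, s'') = runState (f a) s' -- second run, from the intermediate state
  in (a, s', b, s'')
```

With x = SM (\s -> (s, s + 1)), f a = SM (\s -> (a * 100, s * 2)) and initial state 5, traceBind returns (5, 6, 500, 12); bind itself would return only the last two components, (500, 12).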