
How does this State monad code work?


This code is from this article.

I've been able to follow it until this part.

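module Test where

import Control.Monad (ap, liftM)

type State = Int

data ST a = S (State -> (a, State))

apply        :: ST a -> State -> (a,State)
apply (S f) x = f x

fresh :: ST Int
fresh =  S (\n -> (n, n+1))

-- Since GHC 7.10, every Monad must also be a Functor and an Applicative
instance Functor ST where
    fmap = liftM

instance Applicative ST where
    -- pure :: a -> ST a   (the Monad's `return` defaults to this)
    pure x     = S (\s -> (x,s))
    (<*>)      = ap

instance Monad ST where
    -- (>>=)  :: ST a -> (a -> ST b) -> ST b
    st >>= f   = S (\s -> let (x,s') = apply st s in apply (f x) s')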

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)


mlabel  :: Tree a -> ST (Tree (a,Int))
-- THIS IS THE PART I DON'T UNDERSTAND:
mlabel (Leaf x) = do n <- fresh
                     return (Leaf (x,n))
mlabel (Node l r) =  do l' <- mlabel l
                        r' <- mlabel r
                        return (Node l' r')

label t = fst (apply (mlabel t) 0)

tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')

And label tree produces:

Node (Node (Leaf ('a',0)) (Leaf ('b',1))) (Leaf ('c',2))

I can see that the >>= operator is the tool to 'chain' functions that return monadic values (or something like that).

And while I think I understand this code in general, I don't understand how this particular part works.

Specifically, do n <- fresh. We haven't passed any argument to fresh, right? What does n <- fresh produce in that case? I don't understand that at all. Maybe it has something to do with currying?


Solution

  • With the monadic "pipelining" inlined, your code becomes

    fresh state = (state, state + 1)
    
    mlabel (Leaf x) state =                   --  do
      let (n, state') = fresh state           --    n <- fresh
      in  (Leaf (x,n), state')                --    return (Leaf (x,n))
    
    mlabel (Node l r) state =                 -- do
      let (l', state') = mlabel l state       --    l' <- mlabel l
      in let (r', state'') = mlabel r state'  --    r' <- mlabel r
         in  (Node l' r', state'')            --    return (Node l' r') 
    
    main = let (result, state') = mlabel tree 0  
           in  print result                         
    
    {- Or with arrows (needs the TupleSections extension and import Control.Arrow),
    
    mlabel (Leaf x)   = Leaf . (x ,)  &&&  (+ 1)
    mlabel (Node l r) = mlabel l >>> second (mlabel r)
                                  >>> (\(a,(b,c)) -> (Node a b,c))
    main              = mlabel tree >>> fst >>> print  $ 0
    -}
    
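To address the "no argument" puzzle directly: a do block is sugar for >>=, so `do n <- fresh; return (Leaf (x,n))` means `fresh >>= \n -> return (Leaf (x,n))`. `fresh` is not a function being called; it is a value of type `ST Int` that wraps a function of the state, and it is >>= (through `apply st s`) that eventually feeds it the current counter. The sketch below repeats the question's definitions (with the Functor/Applicative instances modern GHC requires); `leafStep` and `manual` are hypothetical names used only for illustration, and `leafStep` returns a pair instead of a `Leaf` to keep it short.

```haskell
import Control.Monad (ap, liftM)

type State = Int

-- A state transformer: a function from the current counter to (result, new counter)
data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

-- `fresh` takes no argument itself: it *is* a function, wrapped in S
fresh :: ST Int
fresh = S (\n -> (n, n + 1))

instance Functor ST where
  fmap = liftM

instance Applicative ST where
  pure x = S (\s -> (x, s))
  (<*>) = ap

instance Monad ST where
  -- bind supplies the state: run st on s, then run (f x) on the new state s'
  st >>= f = S (\s -> let (x, s') = apply st s in apply (f x) s')

-- The Leaf case of mlabel, desugared by hand from do-notation into >>=
leafStep :: Char -> ST (Char, Int)
leafStep x = fresh >>= \n -> return (x, n)

main :: IO ()
main = do
  print (apply fresh 0)           -- (0,1): here we supply the state; in a do block >>= does it
  print (apply (leafStep 'a') 5)  -- (('a',5),6)
  -- Unfolding >>= by hand gives the same state-passing code as the inlined version
  let manual s = let (n, s') = apply fresh s in (('a', n), s')
  print (manual 5)                -- (('a',5),6)
```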

    Or in an imperative pseudocode:

    def state = unassigned
    
    def fresh ():
        tmp = state 
        state := state + 1     -- `fresh` alters the global var `state`
        return tmp             -- each time it is called
    
    def mlabel (Leaf x):       -- a language with pattern matching
        n = fresh ()           -- global `state` is altered!
        return (Leaf (x,n))  
    
    def mlabel (Node l r):
        l' = mlabel l          -- affects the global
        r' = mlabel r          --    assignable variable
        return (Node l' r')    --    `state`
    
    def main:
        state := 0             -- set to 0 before the calculation!
        result = mlabel tree
        print result
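Haskell has no global assignable variables, but the imperative pseudocode can still be mirrored fairly literally in IO. In this sketch an IORef (from Data.IORef in base) stands in for the global `state`; passing it as an explicit `ref` parameter is an assumption of the sketch, since a truly global mutable variable would need unsafe tricks. Note that the do blocks are almost identical to the question's mlabel; only the monad changes, from one that passes the counter along to one that mutates a real cell.

```haskell
import Data.IORef

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)

-- Reads the shared mutable counter and bumps it, like the pseudocode's `fresh`
fresh :: IORef Int -> IO Int
fresh ref = do
  tmp <- readIORef ref
  writeIORef ref (tmp + 1)
  return tmp

-- Each call mutates the shared counter, so the left subtree is numbered first
mlabel :: IORef Int -> Tree a -> IO (Tree (a, Int))
mlabel ref (Leaf x) = do
  n <- fresh ref
  return (Leaf (x, n))
mlabel ref (Node l r) = do
  l' <- mlabel ref l
  r' <- mlabel ref r
  return (Node l' r')

main :: IO ()
main = do
  ref <- newIORef 0  -- `state := 0` before the calculation
  result <- mlabel ref (Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c'))
  print result
  readIORef ref >>= print  -- counter after labelling: 3
```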
    

    Calculating the result changes the assignable variable `state`; it corresponds to the snd field of Haskell's (a, State) tuple. The fst field of the tuple is the newly constructed tree, which carries a number alongside the data in each of its leaves.
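Both fields of the tuple can be observed directly in state-passing form. This sketch is a standalone restatement of the inlined version, with `fresh` folded into the Leaf case; it keeps the final counter instead of discarding it, and with three leaves that counter ends at 3.

```haskell
data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)

-- The Int argument is the counter going in; the snd of the result is the counter coming out
mlabel :: Tree a -> Int -> (Tree (a, Int), Int)
mlabel (Leaf x) s = (Leaf (x, s), s + 1)
mlabel (Node l r) s =
  let (l', s')  = mlabel l s
      (r', s'') = mlabel r s'
  in  (Node l' r', s'')

main :: IO ()
main = do
  let (result, final) = mlabel (Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')) 0
  print result  -- the fst field: the labelled tree
  print final   -- the snd field: the final counter, 3 (one per leaf)
```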

    These variants are functionally equivalent.

    Perhaps you've heard the catchphrase that monadic bind is a "programmable semicolon". Here its meaning is clear: bind defines the "function call protocol", so to speak. We use the first returned value as the calculated result and the second returned value as the updated state, which we pass along to the next calculation so that it sees the updated state.
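With the do sugar removed, the "semicolons" become visible: each line of a do block turns into a >>=, and it is the ST instance's >>= that feeds the second component (the updated counter) into the next step. The sketch repeats the question's definitions (plus the Functor/Applicative instances modern GHC needs); `mlabel'` is a hypothetical name for the hand-desugared version of the question's mlabel.

```haskell
import Control.Monad (ap, liftM)

type State = Int

data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

fresh :: ST Int
fresh = S (\n -> (n, n + 1))

instance Functor ST where
  fmap = liftM

instance Applicative ST where
  pure x = S (\s -> (x, s))
  (<*>) = ap

instance Monad ST where
  -- the "semicolon": run st, then hand its result x AND its new state s' to the next step
  st >>= f = S (\s -> let (x, s') = apply st s in apply (f x) s')

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)

-- The question's mlabel with every do-line turned into an explicit >>=
mlabel' :: Tree a -> ST (Tree (a, Int))
mlabel' (Leaf x)   = fresh >>= \n -> return (Leaf (x, n))
mlabel' (Node l r) = mlabel' l >>= \l' ->
                     mlabel' r >>= \r' ->
                     return (Node l' r')

main :: IO ()
main = print (fst (apply (mlabel' tree) 0))
  where tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')
```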

    This is the state-passing style of programming (essential for e.g. Prolog): the state change is explicit, but we have to take care of passing along the correct, updated state manually. Monads let us abstract this "wiring" of state from one calculation to the next, so it is done automatically for us. The price is that we think in an imperative style again and the state becomes hidden and implicit once more (just as state change is implicit in imperative programming, which is what we wanted to eschew in the first place when switching to functional programming...).

    So all the State monad does is maintain this hidden state for us and pass it along, updated, between consecutive calculations. It's nothing major after all.
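For comparison, a library version of the same idea, assuming the transformers package (it ships with GHC): `Control.Monad.Trans.State` provides a ready-made State monad, so the hand-written ST type and its instances are no longer needed. `state` wraps an `s -> (a, s)` function just like the question's S constructor, while `runState` plays the role of `apply` and `evalState` the role of `fst . apply`.

```haskell
import Control.Monad.Trans.State (State, evalState, runState, state)

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)

-- The same counter, now kept by the library's State monad instead of a hand-rolled ST
fresh :: State Int Int
fresh = state (\n -> (n, n + 1))

mlabel :: Tree a -> State Int (Tree (a, Int))
mlabel (Leaf x) = do
  n <- fresh
  return (Leaf (x, n))
mlabel (Node l r) = do
  l' <- mlabel l
  r' <- mlabel r
  return (Node l' r')

main :: IO ()
main = do
  let tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')
  print (evalState (mlabel tree) 0)       -- only the labelled tree, like `label`
  print (snd (runState (mlabel tree) 0))  -- only the final counter: 3
```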