This code is from this article
I've been able to follow it until this part.
module Test where

import Control.Monad (ap, liftM)

type State = Int
data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a,State)
apply (S f) x = f x

fresh :: ST Int
fresh = S (\n -> (n, n+1))

-- Modern GHC (7.10+) also requires Functor and Applicative instances for a Monad
instance Functor ST where
  fmap = liftM

instance Applicative ST where
  pure x = S (\s -> (x,s))
  (<*>) = ap

instance Monad ST where
  -- return :: a -> ST a
  return = pure
  -- (>>=) :: ST a -> (a -> ST b) -> ST b
  st >>= f = S (\s -> let (x,s') = apply st s in apply (f x) s')

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)

mlabel :: Tree a -> ST (Tree (a,Int))
-- THIS IS THE PART I DON'T UNDERSTAND:
mlabel (Leaf x) = do n <- fresh
                     return (Leaf (x,n))
mlabel (Node l r) = do l' <- mlabel l
                       r' <- mlabel r
                       return (Node l' r')

label t = fst (apply (mlabel t) 0)

tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')
And `label tree` produces:
Node (Node (Leaf ('a',0)) (Leaf ('b',1))) (Leaf ('c',2))
I can see that the `>>=` operator is the tool to 'chain' functions that return monads (or something like that).
And while I think I understand this code, I don't understand how this particular code works.
Specifically, `do n <- fresh`. We haven't passed any argument to `fresh`, right? What does `n <- fresh` produce in that case? I absolutely don't understand that. Maybe it has something to do with currying?
With the monadic "pipelining" inlined, your code becomes
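data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)  -- as in the question
tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')             -- as in the question

fresh :: Int -> (Int, Int)
fresh state = (state, state + 1)

mlabel :: Tree a -> Int -> (Tree (a, Int), Int)
mlabel (Leaf x) state =                      -- do
  let (n, state') = fresh state              --   n <- fresh
  in (Leaf (x,n), state')                    --   return (Leaf (x,n))
mlabel (Node l r) state =                    -- do
  let (l', state') = mlabel l state          --   l' <- mlabel l
  in let (r', state'') = mlabel r state'     --   r' <- mlabel r
     in (Node l' r', state'')                --   return (Node l' r')

main = let (result, state') = mlabel tree 0
       in print result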
{- Or with arrows (this needs the TupleSections extension and import Control.Arrow),
mlabel (Leaf x)   = Leaf . (x ,) &&& (+ 1)
mlabel (Node l r) = mlabel l >>> second (mlabel r)
                             >>> (\(a,(b,c)) -> (Node a b,c))
main = mlabel tree >>> fst >>> print $ 0
-}
Or, in imperative pseudocode:
def state = unassigned

def fresh ():
    tmp = state
    state := state + 1      -- `fresh` alters the global var `state`
    return tmp              -- each time it is called

def mlabel (Leaf x):        -- a language with pattern matching
    n = fresh ()            -- global `state` is altered!
    return (Leaf (x,n))

def mlabel (Node l r):
    l' = mlabel l           -- affects the global
    r' = mlabel r           -- assignable variable
    return (Node l' r')     -- `state`

def main:
    state := 0              -- set to 0 before the calculation!
    result = mlabel tree
    print result
Calculating the `result` changes the assignable variable `state`; it corresponds to the `snd` field in Haskell's `(a, State)` tuple. And the `fst` field of the tuple is the newly constructed tree, carrying a numbering alongside its data in its leaves.
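For comparison, here is a hedged sketch of the imperative pseudocode above as real Haskell. Since plain Haskell has no assignable global variables, this assumes the counter lives in a mutable `IORef` (from `Data.IORef` in `base`) that is passed to each function explicitly. That is a substitution for the pseudocode's global `state`, not how the State monad works internally: the State monad passes an immutable value along instead.

```haskell
import Data.IORef (IORef, newIORef, readIORef, writeIORef)

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)

tree :: Tree Char
tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')

-- Read the counter, bump it, return the old value: the pseudocode's `fresh`
fresh :: IORef Int -> IO Int
fresh ref = do
  tmp <- readIORef ref
  writeIORef ref (tmp + 1)   -- the shared counter is altered on every call
  return tmp

mlabel :: IORef Int -> Tree a -> IO (Tree (a, Int))
mlabel ref (Leaf x) = do
  n <- fresh ref
  return (Leaf (x, n))
mlabel ref (Node l r) = do
  l' <- mlabel ref l         -- numbers the left subtree first...
  r' <- mlabel ref r         -- ...so the right subtree sees the updated counter
  return (Node l' r')

main :: IO ()
main = do
  ref <- newIORef 0          -- "state := 0" before the calculation
  result <- mlabel ref tree
  print result
```

Running `main` prints the same labelled tree as `label tree`: `Node (Node (Leaf ('a',0)) (Leaf ('b',1))) (Leaf ('c',2))`.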
These variants are functionally equivalent.
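To check this claim concretely, the following self-contained sketch puts the three Haskell variants side by side (the `ST` monad from the question, explicit state passing, and the arrow version) and compares their full results, including the final state. The names `mlabelM`, `mlabelP` and `mlabelA` are made up here only to keep the three versions apart in one module.

```haskell
{-# LANGUAGE TupleSections #-}
import Control.Arrow (second, (&&&), (>>>))
import Control.Monad (ap, liftM)

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show, Eq)

tree :: Tree Char
tree = Node (Node (Leaf 'a') (Leaf 'b')) (Leaf 'c')

-- Variant 1: the State monad from the question
type State = Int
data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

instance Functor ST where
  fmap = liftM
instance Applicative ST where
  pure x = S (\s -> (x, s))
  (<*>) = ap
instance Monad ST where
  st >>= f = S (\s -> let (x, s') = apply st s in apply (f x) s')

fresh :: ST Int
fresh = S (\n -> (n, n + 1))

mlabelM :: Tree a -> ST (Tree (a, Int))
mlabelM (Leaf x) = do
  n <- fresh
  return (Leaf (x, n))
mlabelM (Node l r) = do
  l' <- mlabelM l
  r' <- mlabelM r
  return (Node l' r')

-- Variant 2: explicit state passing (the inlined version)
mlabelP :: Tree a -> Int -> (Tree (a, Int), Int)
mlabelP (Leaf x) s = (Leaf (x, s), s + 1)
mlabelP (Node l r) s =
  let (l', s')  = mlabelP l s    -- thread the state into the left subtree
      (r', s'') = mlabelP r s'   -- then the updated state into the right one
  in (Node l' r', s'')

-- Variant 3: the same function built from Control.Arrow combinators
mlabelA :: Tree a -> Int -> (Tree (a, Int), Int)
mlabelA (Leaf x) = Leaf . (x ,) &&& (+ 1)
mlabelA (Node l r) =
  mlabelA l >>> second (mlabelA r) >>> (\(a, (b, c)) -> (Node a b, c))

main :: IO ()
main = do
  let viaMonad   = apply (mlabelM tree) 0
      viaPassing = mlabelP tree 0
      viaArrows  = mlabelA tree 0
  print viaMonad
  print (viaMonad == viaPassing && viaPassing == viaArrows)
```

`main` prints `(Node (Node (Leaf ('a',0)) (Leaf ('b',1))) (Leaf ('c',2)),3)` followed by `True`. The `3` in the `snd` position is the final counter value, i.e. the next label that would have been handed out.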
Perhaps you've heard the catch-phrase about monadic bind being a "programmable semicolon". Here its meaning is clear: bind defines the "function call protocol", so to speak, under which the first returned value is used as the calculated result and the second returned value as the updated state, which is passed along to the next calculation so that it sees the updated state.
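This is also why `n <- fresh` needs no argument. In the hedged sketch below, the `do` blocks from the question are written with explicit `>>=`, which is what `do` desugars to. `fresh` is not a function waiting for an argument; it is a value of type `ST Int` wrapping a state-transforming function, and it is `>>=` (through `apply`) that later feeds it the current state. The single-leaf example starting at state 5 is an arbitrary choice for illustration.

```haskell
import Control.Monad (ap, liftM)

type State = Int
data ST a = S (State -> (a, State))

apply :: ST a -> State -> (a, State)
apply (S f) x = f x

instance Functor ST where
  fmap = liftM
instance Applicative ST where
  pure x = S (\s -> (x, s))
  (<*>) = ap
instance Monad ST where
  st >>= f = S (\s -> let (x, s') = apply st s in apply (f x) s')

data Tree a = Leaf a | Node (Tree a) (Tree a) deriving (Show)

-- A value, not a call: a wrapped function that will receive the state later
fresh :: ST Int
fresh = S (\n -> (n, n + 1))

-- The question's do-blocks without sugar: each `x <- m` becomes `m >>= \x -> ...`
mlabel :: Tree a -> ST (Tree (a, Int))
mlabel (Leaf x)   = fresh >>= \n -> return (Leaf (x, n))
mlabel (Node l r) = mlabel l >>= \l' ->
                    mlabel r >>= \r' ->
                    return (Node l' r')

main :: IO ()
main = do
  -- Unfolding (>>=) by hand for a single leaf, starting from state 5:
  --   apply (fresh >>= \n -> return (Leaf ('z', n))) 5
  --   = let (n, s') = apply fresh 5            -- (5, 6)
  --     in apply (return (Leaf ('z', n))) s'   -- (Leaf ('z',5), 6)
  print (apply (mlabel (Leaf 'z')) 5)
  print (apply (mlabel (Node (Leaf 'a') (Leaf 'b'))) 0)
```

This prints `(Leaf ('z',5),6)` and then `(Node (Leaf ('a',0)) (Leaf ('b',1)),2)`: the state argument only appears when `apply` finally runs the whole chain.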
This is the state-passing style of programming (essential for e.g. Prolog): it makes the state change explicit, but we have to take care of manually passing along the correct, updated state. Monads allow us to abstract this "wiring" of state passing from one calculation to the next, so that it is done for us automatically. The price is that we have to think in an imperative style, and the state becomes hidden and implicit again (just as state change is implicit in imperative programming, which we wanted to eschew in the first place when switching to functional programming...).
So all the State monad does is maintain this hidden state for us, passing it along, updated, between consecutive calculations. It's nothing major after all.