Despite reading through the really clear explanation in LYAH, and then Haskell Wiki, and some other stuff, I am still confused about how the state monad is implemented. I think I understand what it is though I'm not confident.
So let's say I have some trivial data type:
data Simple a = Top a
    deriving ( Show )
And this:
newtype SimpleState a = SimpleState { applySimple :: Int -> ( a, Int ) }
I then make SimpleState a monad
instance Monad SimpleState where
    return x = SimpleState $ \s -> ( x, s )
    st >>= g = SimpleState $ \s -> let ( x, s' ) = applySimple st s in applySimple ( g x ) s'
Question 1: How is the lambda taking in s ( for state ) as a parameter? How is it passed in?
Question 2: if applySimple is taking in one parameter in its function signature, why do I have applySimple st s inside the lambda? Why is applySimple applied twice?
Even more confusing, this thing changes the state:
tic :: SimpleState Int
tic = SimpleState $ \s -> ( s, s + 1 )
Question 3. What is this? Why is it doing some sort of action on the SimpleState but its signature is not a function?
So now I could pass tic into this function:
incr :: Simple a -> SimpleState ( Simple ( a, Int ) )
incr ( Top a ) = do
    v <- tic
    return ( Top ( a, v ) )
Question 4: could I / how would I use tic with >>=?
And by using it like so:
applySimple ( incr ( Top 1 ) ) 3
I get this:
(Top (1,3),4)
Again, applySimple is applied to two params, which confuses me.
In summary, I'm getting really hung up on the fact that the constructor SimpleState is taking in a function that takes in s as a param, and I have no idea where it's coming from or how it's used in context.
Question 1: How is the lambda taking in s ( for state ) as a parameter? How is it passed in?
Let's use the classic definitions of get and put, defined as:
put :: Int -> SimpleState ()
put n = SimpleState (\_ -> ((), n))
get :: SimpleState Int
get = SimpleState (\s -> (s, s))
When you call applySimple, you unwrap the SimpleState, which exposes a function of type Int -> (a, Int). Then you apply that function to your initial state. Let's try it out using some concrete examples.
First, we'll run the command put 1, with an initial state of 0:
applySimple (put 1) 0
-- Substitute in definition of 'put'
= applySimple (SimpleState (\_ -> ((), 1))) 0
-- applySimple (SimpleState f) = f
= (\_ -> ((), 1)) 0
-- Apply the function
= ((), 1)
Notice how put ignores the initial state and just replaces the right state slot with 1, leaving behind () in the left return value slot.
Now let's try running the get command, using a starting state of 0:
applySimple get 0
-- Substitute in definition of 'get'
= applySimple (SimpleState (\s -> (s, s))) 0
-- applySimple (SimpleState f) = f
= (\s -> (s, s)) 0
-- Apply the function
= (0, 0)
get just copies 0 into the left return value slot, leaving the right state slot unchanged.
So the way you pass your initial state into that lambda function is just by unwrapping the SimpleState newtype to expose the underlying lambda function and directly applying the lambda function to the initial state.
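To try the two evaluations above in a compiler, here is a minimal sketch of a complete module. It assumes GHC 7.10 or later, where a Monad instance also needs Functor and Applicative instances; those two instances and the `main` function are additions for this sketch and are not part of the question's code.

```haskell
import Control.Monad (ap, liftM)

newtype SimpleState a = SimpleState { applySimple :: Int -> (a, Int) }

-- Assumption: boilerplate instances required by modern GHC,
-- derived here from the Monad instance below.
instance Functor SimpleState where
    fmap = liftM

instance Applicative SimpleState where
    pure x = SimpleState $ \s -> (x, s)
    (<*>)  = ap

instance Monad SimpleState where
    st >>= g = SimpleState $ \s ->
        let (x, s') = applySimple st s
        in  applySimple (g x) s'

put :: Int -> SimpleState ()
put n = SimpleState (\_ -> ((), n))

get :: SimpleState Int
get = SimpleState (\s -> (s, s))

main :: IO ()
main = do
    print (applySimple (put 1) 0)         -- ((),1)
    print (applySimple get 0)             -- (0,0)
    -- The state written by 'put' is what 'get' then reads.
    print (applySimple (put 1 >> get) 0)  -- (1,1)
```

In every case the initial state only enters at the very end, as the second argument to applySimple.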
Question 2: if applySimple is taking in one parameter in its function signature, why do I have applySimple st s inside the lambda? Why is applySimple applied twice?
That's because applySimple's type is not Int -> (a, Int). It's actually:
applySimple :: SimpleState a -> Int -> (a, Int)
This is a confusing aspect of Haskell's record syntax. Whenever you have a record field like:
data SomeType = SomeType { field :: FieldType }
... then field's type is actually:
field :: SomeType -> FieldType
I know it's weird.
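As a rough illustration of that rule, the sketch below declares a small record and a hand-written unwrapping function next to the generated one. The names `Point`, `px`, `py`, `runSimple`, and `counter` are hypothetical and exist only for this example.

```haskell
-- Hypothetical record: the compiler generates the accessors
--   px :: Point -> Int
--   py :: Point -> Int
data Point = Point { px :: Int, py :: Int }

newtype SimpleState a = SimpleState { applySimple :: Int -> (a, Int) }

-- Hand-written equivalent of the generated accessor 'applySimple'.
-- It takes the wrapper first and the state second.
runSimple :: SimpleState a -> Int -> (a, Int)
runSimple (SimpleState f) = f

main :: IO ()
main = do
    let p = Point { px = 3, py = 4 }
    print (px p)                     -- 3
    let counter = SimpleState (\s -> (s, s + 1))
    print (applySimple counter 10)   -- (10,11)
    print (runSimple counter 10)     -- (10,11)
```

So applySimple st s is a single accessor call: the first argument selects the wrapped function and the second argument is fed to it.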
Question 3. What is this? Why is it doing some sort of action on the SimpleState but its signature is not a function?
The SimpleState newtype hides the function that it wraps. A newtype can hide absolutely anything until you unwrap it. Your SimpleState does have a function inside of it, but all the compiler sees is the outer newtype until you unwrap it with applySimple.
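A small sketch of that unwrapping, using tic from the question. The name `ticFunction` is hypothetical and is introduced only to give the unwrapped function a visible type signature.

```haskell
newtype SimpleState a = SimpleState { applySimple :: Int -> (a, Int) }

-- A plain value: its type mentions no function arrow.
tic :: SimpleState Int
tic = SimpleState $ \s -> (s, s + 1)

-- Unwrapping 'tic' exposes the ordinary function stored inside it.
ticFunction :: Int -> (Int, Int)
ticFunction = applySimple tic

main :: IO ()
main = do
    print (ticFunction 3)   -- (3,4)
    print (ticFunction 10)  -- (10,11)
```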
Question 4: could I / how would I use tic with >>= ?
You're making a mistake in your definition of incr. The correct way to use tic would be like this:
ticTwice :: SimpleState Int
ticTwice = do
    tic
    tic
... which the compiler translates to:
ticTwice =
tic >>= \_ ->
tic
Using your definition of (>>=) and tic, you can prove that this increments the value by two:
tic >>= \_ -> tic
-- Substitute in definition of `(>>=)`
= SimpleState (\s ->
    let (x, s') = applySimple tic s
    in applySimple ((\_ -> tic) x) s')
-- Apply the (\_ -> tic) function
= SimpleState (\s ->
    let (x, s') = applySimple tic s
    in applySimple tic s')
-- Substitute in definition of `tic`
= SimpleState (\s ->
    let (x, s') = applySimple (SimpleState (\s -> (s, s + 1))) s
    in applySimple (SimpleState (\s -> (s, s + 1))) s')
-- applySimple (SimpleState f) = f
= SimpleState (\s ->
    let (x, s') = (\s -> (s, s + 1)) s
    in (\s -> (s, s + 1)) s')
-- Apply both functions
= SimpleState (\s ->
    let (x, s') = (s, s + 1)
    in (s', s' + 1))
-- Simplify by getting rid of unused 'x'
= SimpleState (\s ->
    let s' = s + 1
    in (s', s' + 1))
-- Simplify some more:
= SimpleState (\s -> (s + 1, s + 2))
So you see that when you chain two tics using (>>=), it combines them into a single stateful function that increments the state by 2, and returns the state plus 1.
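The result of that derivation can be checked by running it. The sketch below assumes GHC 7.10 or later, so it adds Functor and Applicative instances that are not in the question. It also shows one possible way to write the question's incr with an explicit (>>=) instead of do notation.

```haskell
import Control.Monad (ap, liftM)

data Simple a = Top a
    deriving (Show)

newtype SimpleState a = SimpleState { applySimple :: Int -> (a, Int) }

-- Assumption: boilerplate instances required by modern GHC.
instance Functor SimpleState where
    fmap = liftM

instance Applicative SimpleState where
    pure x = SimpleState $ \s -> (x, s)
    (<*>)  = ap

instance Monad SimpleState where
    st >>= g = SimpleState $ \s ->
        let (x, s') = applySimple st s
        in  applySimple (g x) s'

tic :: SimpleState Int
tic = SimpleState $ \s -> (s, s + 1)

-- Two tics chained by hand; the first result is discarded.
ticTwice :: SimpleState Int
ticTwice = tic >>= \_ -> tic

-- 'incr' from the question, desugared to (>>=).
incr :: Simple a -> SimpleState (Simple (a, Int))
incr (Top a) = tic >>= \v -> return (Top (a, v))

main :: IO ()
main = do
    print (applySimple ticTwice 5)        -- (6,7)
    print (applySimple (incr (Top 1)) 3)  -- (Top (1,3),4)
```

Starting from state 5, ticTwice returns 6 (the state plus 1) and leaves the state at 7 (the state plus 2), which matches the final line of the derivation.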