f# state-monad

How does the State monad bind into the outer context?


I am trying to understand the State monad and, I must admit, I am mighty confused. I created a computation expression and added a lot of print statements so I can follow who gets called when.

type State<'st,'a> =
    | Ok of  'a * 'st
    | Error of string
and StateMonadBuilder() =
    member b.Return(x) = printfn "100 Return %A" x; fun s -> Ok (x, s)
    member b.ReturnFrom(x) = printfn "100 ReturnFrom %A" x; x
    member b.Bind(p, rest) =
        printfn "100 Bind:: %A %A" p rest
        fun state ->
            printfn "200 Bind:: %A %A" p rest
            let result = p state in
            match result with
            | Ok (value,state2) -> (rest value) state2
            | Error msg -> Error msg  

    member b.Get () = 
        printfn "100 Get"
        fun state -> 
            printfn "200 Get :: %A" state
            Ok (state, state)
    member b.Put s = fun state -> Ok ((), s)

let state = StateMonadBuilder()

let turn () =
    state {
        printfn "100 turn::"
        let! pos1 = state.Get()
        printfn "200 turn:: %A" pos1
        let! pos2 = state.Get()
        printfn "300 turn:: %A" pos1
        return! state.Put(fst pos1, snd pos1 - 1)
    }

let move () =
    state {
        printfn "100 move::"
        let! x = turn()
        printfn "200 move:: %A" x
        let! y = turn()
        printfn "200 move:: %A" y
        return x
    }

let run () =
    state {
        printfn "100 run::"
        do! move()
    }

run () (5,5) |> ignore

The above code prints the following output:

100 run::
100 move::
100 turn::
100 Get
100 Bind:: <fun:Get@301> <fun:turn@312>
100 Bind:: <fun:Bind@292-2> <fun:move@322>
100 Bind:: <fun:Bind@292-2> <fun:run@329>
200 Bind:: <fun:Bind@292-2> <fun:run@329>
200 Bind:: <fun:Bind@292-2> <fun:move@322>
200 Bind:: <fun:Get@301> <fun:turn@312>
200 Get :: (5, 5)
200 turn:: (5, 5)
100 Get
100 Bind:: <fun:Get@301> <fun:turn@314-1>
200 Bind:: <fun:Get@301> <fun:turn@314-1>
200 Get :: (5, 5)
300 turn:: (5, 5)
100 ReturnFrom <fun:Put@304>
200 move:: <null>
100 Return <null>
100 Return <null>

I understand the first 5 lines of this output. Obviously run calls move, which calls turn, which calls Get. Then there is a let! pos1 = ... which triggers the call to Bind. So far so good. But then there are additional calls to Bind. Where do they come from? I understand on a superficial level that binding into those outer contexts must somehow be the magic of the state monad, but how does this mechanism work? And then there is another let! pos2 = ... in the function turn, which also triggers Bind, but this time only once, not 3 times as before!

Looking forward to your explanations.


Solution

  • There's no magic involved, it's all smoke and mirrors.

    The computation you build up in your workflow is one big function of type 'st -> State<'st, 'a>. The place where you call run is in fact where you apply this function to an initial argument - that's what is passed through the binds and in turn, from the "parent" move workflow to turn. So it's not that your nested workflow is accessing anything outside - you pass it there yourself instead.
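    To make the "one big function" point concrete, here is a stripped-down sketch that uses plain functions instead of a builder. It assumes an int counter as the state and leaves out the Error case; the names get, put, bind, decrement and twice are made up for this illustration and are not part of the question's code.

```fsharp
// A state computation is just a function: state in, (value, new state) out.
let get : int -> int * int = fun s -> (s, s)
let put (newState: int) : int -> unit * int = fun _ -> ((), newState)

// bind runs `p` on the incoming state, then hands the resulting
// value and the resulting state to `rest`.
let bind p rest =
    fun s ->
        let (value, s2) = p s
        rest value s2

// "Inner" workflow: decrement the counter.
let decrement = bind get (fun n -> put (n - 1))

// "Outer" workflow: run the inner one twice, then read the state.
// The outer bind is what passes the state into `decrement`.
let twice = bind decrement (fun () -> bind decrement (fun () -> get))

// Nothing touches any state until the function is applied here.
let result = twice 10
// result is (8, 8)
```

    The inner decrement never reaches outside of itself; it only ever sees the state that the enclosing bind hands to it.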

    One non-standard choice that you make - and one that probably doesn't make it easier to understand what is going on - is that your State monad is not a pure state monad; rather, it combines aspects of the State and Either/Maybe monads (through the Error case of the State type). And while you define the State type, your actual monadic type here is the function type I mentioned earlier.

    A typical approach would be to define the type as something like this:

    type State<'st, 'a> = State of ('st -> 'a * 'st)
    

    i.e. you use a single-case union as a wrapper for a function type, or just use the function without the wrapping type. Error handling is typically not a concern the state monad deals with.
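    A minimal sketch of what a builder over that wrapper type could look like. This is one possible shape, not a canonical implementation; runState, getState, putState and StateBuilder are names chosen for this example, and the state is assumed to be an int pair as in the question.

```fsharp
// A plain state monad: a wrapped function, with no Error case.
type State<'st, 'a> = State of ('st -> 'a * 'st)

// Unwrap the function and apply it to an initial state.
let runState (State f) initial = f initial

type StateBuilder() =
    member b.Return(x) = State(fun s -> (x, s))
    member b.ReturnFrom(m: State<_, _>) = m
    member b.Bind(m, rest) =
        State(fun s ->
            // Run the first computation, then feed its value and
            // its resulting state into the continuation.
            let (value, s2) = runState m s
            runState (rest value) s2)

let state = StateBuilder()
let getState () = State(fun s -> (s, s))
let putState newState = State(fun _ -> ((), newState))

// Same idea as `turn` in the question: decrement the second component.
let turn : State<int * int, unit> =
    state {
        let! (a, b) = getState ()
        return! putState (a, b - 1)
    }

let move : State<int * int, string> =
    state {
        do! turn
        do! turn
        return "done"
    }

let result = runState move (5, 5)
// result is ("done", (5, 3))
```

    With the wrapper in place, nothing can be applied to a state by accident; the only way to run a workflow is through runState.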

    As for the second part of the question, you do have three binds on your path - do! move(), let! x = turn() and let! pos1 = state.Get() - and this is what you see in the log. I think the sequence in which things happen might be tricky here.

    Remember how a bind is desugared:

    {| let! pattern = expr in cexpr |} => builder.Bind(expr, (fun pattern -> {| cexpr |}))
    

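    What this means is that expr is evaluated first, only then is Bind called, and the rest of the computation, cexpr, runs last, from inside Bind. In your case, you go "three binds deep" to evaluate the first expr - which is the call to Get() - and then you start resolving your stack of binds, at some point calling another bind as part of the cexpr. That second bind, for let! pos2, is created only while the rest of turn is running, which is why it shows up once and not three times.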
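    As an illustration, here is turn from the question next to a hand-desugared version. The builder is a cut-down copy of the one in the question without the print statements, and turnDesugared is a name invented for this sketch; the real compiler output may differ in detail, but the nesting of Bind calls is the point.

```fsharp
// Minimal copy of the question's types, without the tracing output.
type State<'st, 'a> =
    | Ok of 'a * 'st
    | Error of string

type StateMonadBuilder() =
    member b.ReturnFrom(x) = x
    member b.Bind(p, rest) =
        fun state ->
            match p state with
            | Ok (value, state2) -> (rest value) state2
            | Error msg -> Error msg
    member b.Get () = fun state -> Ok (state, state)
    member b.Put s = fun state -> Ok ((), s)

let state = StateMonadBuilder()

// The computation expression form of `turn`...
let turn () =
    state {
        let! pos1 = state.Get()
        let! pos2 = state.Get()
        return! state.Put(fst pos1, snd pos1 - 1)
    }

// ...and roughly what it is translated into. The arguments of the outer
// Bind are evaluated first; the inner Bind sits inside a lambda and is
// only called once the outer one runs its continuation.
let turnDesugared () =
    state.Bind(state.Get(), fun pos1 ->
        state.Bind(state.Get(), fun pos2 ->
            state.ReturnFrom(state.Put(fst pos1, snd pos1 - 1))))
```

    Both versions behave the same when applied to a state such as (5, 5).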

    It would probably be easier to see what's really going on if you added another print statement in Bind right after let result = p state is computed. That is the point at which the bind is being unwound.
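    A sketch of that extra print statement, using a cut-down version of the question's builder. The numbering of binds and the helpers nextId, trace and logLine are hypothetical additions for this example, there only to tell the three binds apart and to record the order of events.

```fsharp
type State<'st, 'a> =
    | Ok of 'a * 'st
    | Error of string

// Hypothetical helpers: print each line and also record it.
let trace = System.Collections.Generic.List<string>()
let logLine (msg: string) =
    printfn "%s" msg
    trace.Add msg

let mutable nextId = 0

type StateMonadBuilder() =
    member b.Return(x) = fun s -> Ok (x, s)
    member b.Bind(p, rest) =
        // Runs when the workflow is being built.
        nextId <- nextId + 1
        let n = nextId
        logLine (sprintf "build bind %d" n)
        fun state ->
            // Runs when the workflow is applied to a state.
            logLine (sprintf "enter bind %d" n)
            let result = p state
            // The extra print: `p` is done, this bind is unwinding.
            logLine (sprintf "bind %d: p finished" n)
            match result with
            | Ok (value, state2) -> (rest value) state2
            | Error msg -> Error msg
    member b.Get () = fun state -> Ok (state, state)

let state = StateMonadBuilder()

let turn () =
    state {
        let! pos = state.Get()
        return pos
    }

let move () =
    state {
        let! x = turn ()
        return x
    }

let run () =
    state {
        let! y = move ()
        return y
    }

let result = run () (5, 5)
```

    The binds are built innermost first (1, 2, 3), entered outermost first (3, 2, 1), and their p finished lines appear innermost first again (1, 2, 3).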