Tags: scala, state-monad

Scala and State Monad


I have been trying to understand the State Monad: not so much how it is used (though that is not always easy to find, either), but how it works. Every discussion I find of the State Monad has basically the same information, and there is always something I don't understand.

Take this post, for example. In it the author has the following:

case class State[S, A](run: S => (A, S)) {
...
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      f(a) run t
    })
...
}
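The post elides the rest of the class. For experimenting in the REPL, a minimal self-contained sketch with only run, map and flatMap is enough; tick is a hypothetical example action (not from the post) that returns a counter and increments it.

```scala
// Minimal State sketch: wraps a function from an input state
// to a (result, output state) pair.
case class State[S, A](run: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State(s => {
      val (a, t) = run(s) // run this action once
      (f(a), t)           // transform only the value; pass the state on
    })

  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s) // first run: this action, starting from s
      f(a).run(t)         // second run: the next action, starting from t
    })
}

// Hypothetical action: return the current counter value, then increment it.
val tick: State[Int, Int] = State(n => (n, n + 1))

// Two ticks in sequence; the result pairs up both values.
val twoTicks: State[Int, (Int, Int)] =
  tick.flatMap(a => tick.map(b => (a, b)))

val result = twoTicks.run(10) // ((10, 11), 12)
```

Running twoTicks from 10 performs two transitions, 10 to 11 and 11 to 12: the first run belongs to tick itself, and the second belongs to the State built by f(a).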

I can see that the types line up correctly. However, I don't understand the second run (f(a) run t) at all.

Perhaps I am looking at the whole purpose of this monad incorrectly. I got the impression from the HaskellWiki that the State monad is something like a state machine, with run performing the transitions (though in this case the state machine doesn't have fixed transitions like most state machines do). If that is the case, then in the above code (a, t) would represent a single transition, and applying f would represent a modification of that value and state (producing a new State object). That leaves me completely confused as to what the second run is for. It looks like a second 'transition', which doesn't make any sense to me.
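One way to test the state-machine reading is to write two transitions as State values and compose them with flatMap. A minimal sketch, assuming a hypothetical account balance as the state; deposit and withdraw are illustrative names, not from the post, and only flatMap is defined:

```scala
case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s) // first transition, starting from s
      f(a).run(t)         // second transition, starting from t
    })
}

// Hypothetical transitions on a balance; each returns the new balance as its value.
def deposit(amount: Int): State[Int, Int] =
  State(balance => (balance + amount, balance + amount))

def withdraw(amount: Int): State[Int, Int] =
  State(balance => (balance - amount, balance - amount))

// Composing two transitions yields one bigger transition; nothing runs yet.
val depositThenWithdraw: State[Int, Int] =
  deposit(50).flatMap(_ => withdraw(30))

// Running it performs both steps: 100 -> 150 (first run) -> 120 (second run).
val (finalValue, finalBalance) = depositThenWithdraw.run(100)
```

Here depositThenWithdraw is itself a single State value, that is, one composite transition. Its first run is deposit(50) taking 100 to 150, and its second run is withdraw(30) taking 150 to 120.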

I can see that calling run on the resulting State object produces a new (B, S) pair, which, of course, is required for the types to line up. But I don't really see what this is supposed to be doing.

So what is really going on here? What concept is being modeled?

Edit: 12/22/2015

So, it appears I am not expressing my issue very well. Let me try this.

In the same blog post we see the following code for map:

def map[B](f: A => B): State[S, B] =
  State(s => {
    val (a, t) = run(s)
    (f(a), t)
  })

Obviously there is only a single call to run here.
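A quick way to see map's single run is to compare an action with its mapped version. This sketch assumes a hypothetical pop action standing in for the blog's head[Int] (whose definition is not shown) and defines only map:

```scala
case class State[S, A](run: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State(s => {
      val (a, t) = run(s) // the only run: this action's own transition
      (f(a), t)           // new value, same output state
    })
}

// Hypothetical helper (not from the post): take the head of a non-empty list.
def pop[A]: State[List[A], A] = State(xs => (xs.head, xs.tail))

val input = List(1, 2, 3)
val plain = pop[Int].run(input)              // (1, List(2, 3))
val doubled = pop[Int].map(_ * 2).run(input) // (2, List(2, 3))
```

Both leave the same remaining state, List(2, 3); map changes only the value, never the number of transitions.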

The model I have been trying to reconcile is that a call to run moves the state we are keeping forward by a single state change. This seems to be the case in map. However, flatMap has two calls to run. If my model were correct, that would result in 'skipping over' a state change.
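The two runs in flatMap belong to two different actions: run is the first action's function, and f(a).run is the second action's. A hypothetical instrumented action that logs each transition makes this visible; step and the Vector[String] log are illustrative assumptions, not from the post:

```scala
case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s) // run the first action
      f(a).run(t)         // run the second action on the first one's output state
    })
}

// Hypothetical instrumented action: appends its label to the log (the state)
// and returns the log length as it was before this step.
def step(label: String): State[Vector[String], Int] =
  State(log => (log.length, log :+ label))

// One flatMap, two actions: `this` is step("first"), f(a) is step("second").
val composed: State[Vector[String], Int] =
  step("first").flatMap(_ => step("second"))

val (value, finalLog) = composed.run(Vector.empty)
// finalLog == Vector("first", "second"): each action ran exactly once.
// value == 1: the second step saw the log left behind by the first.
```

The log records exactly one entry per action, so a flatMap of two single-step actions is a two-step action rather than one step with another skipped.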

To make use of the example @Filippo provided below, the first call to run would result in returning (1, List(2,3,4,5)) and the second would result in (2, List(3,4,5)), effectively skipping over the first one. Since, in his example, this was followed immediately by a call to map, this would have resulted in (Map(a->2, b->3), List(4,5)).

Apparently that is not what is happening. So my whole model is incorrect. What is the correct way to reason about this?

2nd Edit: 12/22/2015

I just tried what I described in the REPL, and my instincts were correct, which leaves me even more confused.

scala> val v = State(head[Int]).flatMap { a => State(head[Int]) }
v: State[List[Int],Int] = State(<function1>)

scala> v.run(List(1,2,3,4,5))
res2: (Int, List[Int]) = (2,List(3, 4, 5))

So, this implementation of flatMap does skip over a state. Yet when I run @Filippo's example I get the same answer he does. What is really happening here?
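A sketch for comparing the REPL session with a variant that keeps the first value, assuming a hypothetical pop in place of head[Int] (its definition isn't shown in the question):

```scala
case class State[S, A](run: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      (f(a), t)
    })

  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      f(a).run(t)
    })
}

// Hypothetical stand-in for head[Int]: take the head of a non-empty list.
def pop[A]: State[List[A], A] = State(xs => (xs.head, xs.tail))

val input = List(1, 2, 3, 4, 5)

// As in the REPL session: the lambda ignores `a`, so the 1 is dropped.
val discarding = pop[Int].flatMap(_ => pop[Int])
val r1 = discarding.run(input) // (2, List(3, 4, 5))

// Same two steps, but the first value is kept.
val keeping = pop[Int].flatMap(a => pop[Int].map(b => (a, b)))
val r2 = keeping.run(input)    // ((1, 2), List(3, 4, 5))
```

Both versions perform two steps and leave List(3, 4, 5) behind; in the first, the 1 produced by the first step is discarded because the function passed to flatMap ignores its argument.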


Solution

  • To understand the "second run" let's analyse it "backwards".

    The signature def flatMap[B](f: A => State[S, B]): State[S, B] suggests that we need to run a function f and return its result.

    To execute function f we need to give it an A. Where do we get one?
    Well, we have run, which can give us an A out of an S, so we need an S.

    Because of that we write s => { val (a, t) = run(s); ... }. We read it as "given an S, execute the run function, which produces an A and a new S". This is our "first" run.

    Now we have an A and we can execute f. That's what we wanted: f(a) gives us a new State[S, B]. If we stop there, we have a function that takes an S and returns a State[S, B]:

    (s: S) => {
      val (a, t) = run(s)
      f(a) // State[S, B]
    }
    

    But a function S => State[S, B] isn't what we want to return! We want to return just a State[S, B].

    How do we do that? We can wrap this function in State:

    State(s => ... f(a))
    

    But that doesn't work, because the State constructor takes an S => (B, S), not an S => State[S, B]. So we need to get a (B, S) out of the State[S, B].
    We do that by calling its run method with the state t we just produced in the previous step! This is our "second" run.

    So as a result we have the following transformation performed by a flatMap:

    s => {                 // when a state is provided
      val (a, t) = run(s)  // produce an `A` and a new state value
      val resState = f(a)  // produce a new `State[S, B]`
      resState.run(t)      // return `(B, S)`
    }
    

    This gives us an S => (B, S), and we just wrap it with the State constructor.

    Another way of looking at these "two runs" is:
    first - we transform the state ourselves with "our" run function;
    second - we pass that transformed state to the State returned by f(a) and let it do its own transformation.

    So we are, in effect, "chaining" state transformations one after another. And that's exactly what monads do: they give us the ability to sequence computations.
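Scala's for-comprehensions rely on this chaining: a for over State desugars into nested flatMap calls ending in a map, so each step hands its output state to the next. A minimal sketch, assuming a hypothetical pop helper that takes the head of a non-empty list:

```scala
case class State[S, A](run: S => (A, S)) {
  def map[B](f: A => B): State[S, B] =
    State(s => {
      val (a, t) = run(s)
      (f(a), t)
    })

  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State(s => {
      val (a, t) = run(s) // this step's transition
      f(a).run(t)         // the remaining steps, starting from t
    })
}

// Hypothetical helper (not from the post): take the head of a non-empty list.
def pop[A]: State[List[A], A] = State(xs => (xs.head, xs.tail))

// Desugars to pop[Int].flatMap(a => pop[Int].flatMap(b => pop[Int].map(c => List(a, b, c))))
val firstThree: State[List[Int], List[Int]] =
  for {
    a <- pop[Int]
    b <- pop[Int]
    c <- pop[Int]
  } yield List(a, b, c)

val (taken, rest) = firstThree.run(List(1, 2, 3, 4, 5))
// taken == List(1, 2, 3), rest == List(4, 5)
```

Each <- consumes one element, and no state is skipped: the three steps take the list from five elements down to two.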