
What is the proper monad or sequence comprehension to both map and carry state across?


I'm writing a programming language interpreter.

I need the right code idiom to evaluate a sequence of expressions into a sequence of their values while propagating state from each evaluation to the next. I'd like a functional programming idiom for this.

It's not a fold, because the results come out like a map. It's not a map, because state has to be propagated across the elements.

What I have is this code which I'm using to try to figure this out. Bear with a few lines of test rig first:

// test rig
class MonadLearning extends JUnit3Suite {

  val d = List("1", "2", "3") // some expressions to evaluate. 

  type ResType = Int 
  case class State(i : ResType) // trivial state for experiment purposes
  val initialState = State(0)

// my stub/dummy "eval" function...obviously the real one will be...real.
  def computeResultAndNewState(s : String, st : State) : (ResType, State) = {
    val State(i) = st
    val res = s.toInt + i
    val newStateInt = i + 1
    (res, State(newStateInt))
  }

My current solution uses a var that is updated as the body of the map is evaluated:

  def testTheVarWay() {
    var state = initialState
    val r = d.map {
      s =>
        {
          val (result, newState) = computeResultAndNewState(s, state)
          state = newState
          result
        }
    }
    println(r)
    println(state)
  }

I also have what I consider an unacceptable solution using foldLeft, which follows what I call the "bag it as you fold" idiom:

def testTheFoldWay() {

// This startFold thing, requires explicit type. That alone makes it muddy.
val startFold : (List[ResType], State) = (Nil, initialState)
val (r, state) = d.foldLeft(startFold) {
  case ((tail, st), s) => {
    val (r, ns) = computeResultAndNewState(s, st)
    (tail :+ r, ns) // we want a constant-time append here, not O(N). Or could Cons on front and reverse later
  }
}

println(r)
println(state)

}

I also have a couple of recursive variations (which are obvious, but also not clear or well motivated), one using streams which is almost tolerable:

def testTheStreamsWay() {
  lazy val states = initialState #:: resultStates // there are states
  lazy val args = d.toStream // there are arguments
  lazy val argPairs = args zip states // put them together
  lazy val resPairs : Stream[(ResType, State)] = argPairs.map{ case (d1, s1) => computeResultAndNewState(d1, s1) } // map across them
  lazy val (results , resultStates) = myUnzip(resPairs)// Note .unzip causes infinite loop. Had to write my own.

  lazy val r = results.toList
  lazy val finalState = resultStates.last

  println(r)
  println(finalState)
}

But I can't figure out anything as compact or clear as the original 'var' solution above. I'm willing to live with it, but I suspect somebody who eats, drinks, and sleeps monad idioms is going to just say "use this" (hopefully!).


Solution

    With the map-with-accumulator combinator (the easy way)

    The higher-order function you want is mapAccumL. It's in Haskell's standard library, but for Scala you'll have to use something like Scalaz.

    First the imports (note that I'm using Scalaz 7 here; for previous versions you'd import Scalaz._):

    import scalaz._, syntax.std.list._
    

    And then it's a one-liner:

    scala> d.mapAccumLeft(initialState, computeResultAndNewState)
    res1: (State, List[ResType]) = (State(3),List(1, 3, 5))
    

    Note that I've had to reverse the order of your evaluator's arguments and of its return value tuple to match the signature expected by mapAccumLeft, which puts the state first in both cases: (State, String) => (State, ResType).
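
For readers without Scalaz on the classpath, the reordering described above can be sketched with the standard library alone. This is only an illustration under stated assumptions: the `mapAccumLeft` below is a hand-written stand-in, not Scalaz's implementation, and the names `MapAccumSketch` and `step` are hypothetical. The `step` adapter wraps the question's original evaluator and swaps both the argument order and the result tuple so that the state comes first.

```scala
object MapAccumSketch {
  type ResType = Int
  final case class State(i: ResType)

  // The question's dummy evaluator, unchanged: (expression, state) => (result, new state).
  def computeResultAndNewState(s: String, st: State): (ResType, State) =
    (s.toInt + st.i, State(st.i + 1))

  // Hypothetical adapter: state first in the arguments and in the result tuple.
  def step(st: State, s: String): (State, ResType) =
    computeResultAndNewState(s, st).swap

  // Hand-written stand-in for a map-with-accumulator combinator (not Scalaz's code).
  // Results are prepended and reversed once at the end, so the whole pass is O(n).
  def mapAccumLeft[A, B, C](xs: List[A])(init: C)(f: (C, A) => (C, B)): (C, List[B]) = {
    val (finalState, reversed) = xs.foldLeft((init, List.empty[B])) {
      case ((st, acc), a) =>
        val (next, b) = f(st, a)
        (next, b :: acc)
    }
    (finalState, reversed.reverse)
  }

  def main(args: Array[String]): Unit = {
    val d = List("1", "2", "3")
    println(mapAccumLeft(d)(State(0))(step)) // prints (State(3),List(1, 3, 5))
  }
}
```

The fold is the same "bag it as you fold" idea from the question, but it is written once behind a reusable signature, so call sites stay as short as the one-liner above.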

    With the state monad (the slightly less easy way)

    As Petr Pudlák points out in another answer, you can also use the state monad to solve this problem. Scalaz actually provides a number of facilities that make working with the state monad much easier than the version in his answer suggests, and they won't fit in a comment, so I'm adding them here.

    First of all, Scalaz does provide a mapM; it's just called traverse (which is a little more general, as Petr Pudlák notes in his comment). So assuming we've got the following (I'm using Scalaz 7 again here):

    import scalaz._, Scalaz._
    
    type ResType = Int
    case class Container(i: ResType)
    
    val initial = Container(0)
    val d = List("1", "2", "3")
    
    def compute(s: String): State[Container, ResType] = State {
      case Container(i) => (Container(i + 1), s.toInt + i)
    }
    

    We can write this:

    d.traverse[({type L[X] = State[Container, X]})#L, ResType](compute).run(initial)
    

    If you don't like the ugly type lambda, you can get rid of it like this:

    type ContainerState[X] = State[Container, X]
    
    d.traverse[ContainerState, ResType](compute).run(initial)
    

    But it gets even better! Scalaz 7 gives you a version of traverse that's specialized for the state monad:

    scala> d.traverseS(compute).run(initial)
    res2: (Container, List[ResType]) = (Container(3),List(1, 3, 5))
    

    And as if that wasn't enough, there's even a version with the run built in:

    scala> d.runTraverseS(initial)(compute)
    res3: (Container, List[ResType]) = (Container(3),List(1, 3, 5))
    

    Still not as nice as the mapAccumLeft version, in my opinion, but pretty clean.
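
To make it clearer what traversing with the state monad does, here is a minimal standard-library sketch. Everything in it is an assumption made for illustration: `StateM` is a hypothetical, stripped-down state monad and `traverseList` is a hypothetical traverse specialised to List; neither is Scalaz's actual code, and Scalaz's State offers far more. `Container` and `compute` mirror the definitions above.

```scala
object StateSketch {
  // Hypothetical minimal state monad: a function from a state to (new state, value).
  final case class StateM[S, A](run: S => (S, A)) {
    def map[B](f: A => B): StateM[S, B] =
      StateM { s => val (s1, a) = run(s); (s1, f(a)) }
    def flatMap[B](f: A => StateM[S, B]): StateM[S, B] =
      StateM { s => val (s1, a) = run(s); f(a).run(s1) }
  }

  // Hypothetical traverse for List: runs f on each element from left to right,
  // feeding the state produced by one step into the next, and collects the values.
  def traverseList[S, A, B](xs: List[A])(f: A => StateM[S, B]): StateM[S, List[B]] =
    xs.foldRight(StateM[S, List[B]](s => (s, Nil))) { (a, acc) =>
      for { b <- f(a); bs <- acc } yield b :: bs
    }

  type ResType = Int
  final case class Container(i: ResType)

  // Same dummy evaluator as above, expressed as a state action.
  def compute(s: String): StateM[Container, ResType] =
    StateM[Container, ResType] { case Container(i) => (Container(i + 1), s.toInt + i) }

  def main(args: Array[String]): Unit = {
    val d = List("1", "2", "3")
    println(traverseList(d)(compute).run(Container(0))) // prints (Container(3),List(1, 3, 5))
  }
}
```

Nothing is evaluated until `run` is called with an initial state. This naive version nests one closure per element, so it is not intended for very long lists.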