
Chaining a number of transitions with the state Monad


I am starting to use the State monad to clean up my code. I have it working for my problem, where I process a transaction called a CDR and modify the state accordingly. It works fine for individual transactions, using this function to perform the state update:

def addTraffic(cdr: CDR): Network => Network = ...

Here is an example:

scala> val processed: (CDR) => State[Network, Long] = cdr =>
 |   for {
 |     m <- init
 |     _ <- modify(Network.addTraffic(cdr))
 |     p <- get
 |   } yield p.count
processed: CDR => scalaz.State[Network,Long] = $$Lambda$4372/1833836780@1258d5c0

scala> val r = processed(("122","celda 1", 3))
r: scalaz.State[Network,Long] = scalaz.IndexedStateT$$anon$13@4cc4bdde

scala> r.run(Network.empty)
res56: scalaz.Id.Id[(Network, Long)] = (Network(Map(122 -> (1,0.0)),Map(celda 1 -> (1,0.0)),Map(1 -> Map(1 -> 3)),1,true),1)

What I want to do now is chain a number of transactions from an iterator. I have found something that works quite well, but there the state transitions take no input (the state changes through the RNG):

  import scalaz._
  import scalaz.std.list.listInstance
  type RNG = scala.util.Random

  val f = (rng:RNG) => (rng, rng.nextInt)
  val intGenerator: State[RNG, Int] = State(f)
  val rng42 = new scala.util.Random(42) // seeded, so the output is reproducible
  val applicative = Applicative[({type l[A] = State[RNG, A]})#l]

  // To generate the first 5 Random integers
  val chain: State[RNG, List[Int]] = applicative.sequence(List.fill(5)(intGenerator))
  val chainResult: (RNG, List[Int]) = chain.run(rng42)
  chainResult._2.foreach(println)

I have tried to adapt this without success: I cannot get the type signatures to match, because my state function requires the CDR (transaction) as input.

Thanks


Solution

  • TL;DR
    You can use traverse from the Traverse type class on a collection (e.g. a List) of CDRs, with a function of type CDR => State[Network, Long]. The result will be a State[Network, List[Long]]. If you don't care about the List[Long], you can use traverse_ instead, which returns a State[Network, Unit]. Finally, if you want to aggregate the results T as they come along, and T forms a Monoid, you can use foldMapM from Foldable (the monadic variant of foldMap), which returns a State[Network, T], where T is the combined (i.e. folded) result of all the Ts in your chain.
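    To make the TL;DR concrete, here is a minimal sketch in plain Scala with no library. It assumes a hand-rolled State class (not the Cats or Scalaz one) and hypothetical, heavily simplified CDR and Network types standing in for yours. It shows that traverse(xs)(f) is just sequence(xs.map(f)): mapping processed over the CDRs gives a List[State[Network, Long]], which is exactly the shape the RNG example passes to sequence.

```scala
// Sketch only: hand-rolled State; CDR and Network are hypothetical simplifications.
object TraverseCdrSketch {
  // State as a plain function from the old state to (new state, result).
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); (s1, f(a)) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); f(a).run(s1) })
  }
  def pure[S, A](a: A): State[S, A] = State((s: S) => (s, a))
  def get[S]: State[S, S] = State((s: S) => (s, s))
  def modify[S](f: S => S): State[S, Unit] = State((s: S) => (f(s), ()))

  // sequence: run the actions left to right, threading the state, collecting results.
  def sequence[S, A](xs: List[State[S, A]]): State[S, List[A]] =
    xs.foldRight(pure[S, List[A]](Nil))((x, acc) => x.flatMap(a => acc.map(a :: _)))

  // traverse is map followed by sequence.
  def traverse[S, A, B](xs: List[A])(f: A => State[S, B]): State[S, List[B]] =
    sequence(xs.map(f))

  // Hypothetical stand-ins, shaped like the tuple used in the question's REPL call.
  type CDR = (String, String, Int)
  final case class Network(callsByCaller: Map[String, Long], count: Long)
  object Network {
    val empty: Network = Network(Map.empty, 0L)
    // Hypothetical addTraffic: counts calls per caller and in total.
    def addTraffic(cdr: CDR): Network => Network = n =>
      Network(
        n.callsByCaller.updated(cdr._1, n.callsByCaller.getOrElse(cdr._1, 0L) + 1L),
        n.count + 1L
      )
  }

  // Same shape as the question's processed: CDR => State[Network, Long].
  val processed: CDR => State[Network, Long] = cdr =>
    for {
      _ <- modify(Network.addTraffic(cdr))
      p <- get[Network]
    } yield p.count

  val cdrs: List[CDR] = List(("122", "celda 1", 3), ("122", "celda 2", 5), ("133", "celda 1", 1))

  // Equivalent to sequence(cdrs.map(processed)).run(Network.empty).
  val (finalNetwork, counts) = traverse(cdrs)(processed).run(Network.empty)
}
```

    Each CDR is processed against the network left by the previous one, so counts comes out as List(1, 2, 3).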

    A code example
    Now some more details, with code examples. I will answer using Cats' State rather than Scalaz's, as I have never used the latter, but the concept is the same and, if you still have problems, I will dig out the correct Scalaz syntax.

    Assume that we have the following data types and imports to work with:

    import cats.implicits._
    import cats.data.State
    
    case class Position(x : Int = 0, y : Int = 0)
    
    sealed trait Move extends Product
    case object Up extends Move
    case object Down extends Move
    case object Left extends Move
    case object Right extends Move
    

    Here, Position represents a point in a 2D plane, and a Move moves such a point up, down, left or right.

    Now, let's create a method that lets us see where we are at a given time:

    def whereAmI : State[Position, String] = State.inspect{ s => s.toString }
    

    and a method to change our position, given a Move:

    def move(m : Move) : State[Position, String] = State{ s => 
      m match {
        case Up => (s.copy(y = s.y + 1), "Up!")
        case Down => (s.copy(y = s.y - 1), "Down!")
        case Left => (s.copy(x = s.x - 1), "Left!")
        case Right => (s.copy(x = s.x + 1), "Right!")
      }
    }
    

    Notice that this returns a String with the name of the move followed by an exclamation mark. This is just to simulate a change of type from Move to something else, and to show how the results are aggregated. More on this in a bit.

    Now let's try to play with our methods:

    val positions : State[Position, List[String]] = for{
      pos1 <- whereAmI 
      _ <- move(Up)
      _ <- move(Right)
      _ <- move(Up)
      pos2 <- whereAmI
      _ <- move(Left)
      _ <- move(Left)
      pos3 <- whereAmI
    } yield List(pos1,pos2,pos3)
    

    And we can feed it an initial Position and see the result:

    positions.runA(Position()).value // List(Position(0,0), Position(1,2), Position(-1,2))
    

    (You can ignore the .value there. It is needed because State[S, A] is just an alias for StateT[Eval, S, A], so running it produces an Eval, and .value extracts the result.)

    As you can see, this behaves as you would expect, and you can create different "blueprints" (i.e. sequences of state modifications), which are applied only once an initial state is provided.
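    To illustrate the blueprint idea, here is a sketch with a hand-rolled State (an assumption; it is not the Cats class) and a hypothetical one-step up action. A for-comprehension such as for { before <- whereAmI; _ <- up; after <- whereAmI } yield List(before, after) desugars into nested flatMap/map calls. Building that value runs nothing, and the same value can then be run against different initial positions.

```scala
// Sketch: a State value is a blueprint that only runs when given an initial state.
object BlueprintSketch {
  // Hand-rolled State (not the Cats class).
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); (s1, f(a)) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); f(a).run(s1) })
  }
  def inspect[S, A](f: S => A): State[S, A] = State((s: S) => (s, f(s)))

  final case class Position(x: Int = 0, y: Int = 0)

  val whereAmI: State[Position, String] = inspect((p: Position) => p.toString)
  // Hypothetical single step, like move(Up) but returning Unit instead of "Up!".
  val up: State[Position, Unit] = State((p: Position) => (p.copy(y = p.y + 1), ()))

  // Desugared form of:
  //   for { before <- whereAmI; _ <- up; after <- whereAmI } yield List(before, after)
  // Defining it runs nothing; it only describes the steps.
  val blueprint: State[Position, List[String]] =
    whereAmI.flatMap(before => up.flatMap(_ => whereAmI.map(after => List(before, after))))

  // The same blueprint applied to two different initial states.
  val fromOrigin: (Position, List[String]) = blueprint.run(Position())
  val fromElsewhere: (Position, List[String]) = blueprint.run(Position(10, 10))
}
```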

    Now, to actually answer your question: say we have a List[Move] and we want to apply the moves sequentially to an initial state and get the result. For that, we use traverse from the Traverse type class:

    val moves = List(Down, Down, Left, Up)
    val result : State[Position, List[String]] = moves.traverse(move)
    result.run(Position()).value // (Position(-1,-1),List(Down!, Down!, Left!, Up!))
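    Under the hood, traverse runs each action on the state left by the previous one and collects the results in order. The following sketch spells this out with a hand-rolled State and a foldLeft-based traverse. Both are assumptions for illustration: Cats implements traverse generically in terms of the Applicative instance for State. The sketch reproduces the result above.

```scala
// Sketch: hand-rolled State and traverse reproducing the moves example.
object TraverseMovesSketch {
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); (s1, f(a)) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); f(a).run(s1) })
  }
  def pure[S, A](a: A): State[S, A] = State((s: S) => (s, a))

  // traverse via foldLeft: each action runs on the state left by the previous one.
  // Results are prepended for efficiency, then reversed to restore the input order.
  def traverse[S, A, B](xs: List[A])(f: A => State[S, B]): State[S, List[B]] =
    xs.foldLeft(pure[S, List[B]](Nil)) { (acc, a) =>
      acc.flatMap(bs => f(a).map(b => b :: bs))
    }.map(_.reverse)

  final case class Position(x: Int = 0, y: Int = 0)
  sealed trait Move
  case object Up extends Move
  case object Down extends Move
  case object Left extends Move
  case object Right extends Move

  def move(m: Move): State[Position, String] = State((s: Position) =>
    m match {
      case Up    => (s.copy(y = s.y + 1), "Up!")
      case Down  => (s.copy(y = s.y - 1), "Down!")
      case Left  => (s.copy(x = s.x - 1), "Left!")
      case Right => (s.copy(x = s.x + 1), "Right!")
    }
  )

  val moves: List[Move] = List(Down, Down, Left, Up)
  val result: (Position, List[String]) = traverse(moves)(move).run(Position())
}
```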
    

    Alternatively, should you not need the A at all (the List, in your case), you can use traverse_ instead of traverse, and the result type will be State[Position, Unit]:

    val result_ : State[Position, Unit] = moves.traverse_(move)
    result_.run(Position()).value // (Position(-1,-1),())
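    Since traverse_ keeps only the final state, running it over modify actions amounts to a plain foldLeft over the pure transitions. This is relevant to your case because addTraffic already has the type CDR => Network => Network. The sketch below checks the equivalence with a hand-rolled State (not the Cats class) and a hypothetical step function playing the role of addTraffic.

```scala
// Sketch: traverse_ over modify versus a plain foldLeft of the same transitions.
object TraverseUnitSketch {
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); (s1, f(a)) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); f(a).run(s1) })
  }
  def pure[S, A](a: A): State[S, A] = State((s: S) => (s, a))
  def modify[S](f: S => S): State[S, Unit] = State((s: S) => (f(s), ()))

  // traverse_: run every action in order, keep the final state, discard the results.
  def traverse_[S, A, B](xs: List[A])(f: A => State[S, B]): State[S, Unit] =
    xs.foldLeft(pure[S, Unit](())) { (acc, a) => acc.flatMap(_ => f(a).map(_ => ())) }

  final case class Position(x: Int = 0, y: Int = 0)
  sealed trait Move
  case object Up extends Move
  case object Down extends Move
  case object Left extends Move
  case object Right extends Move

  // Hypothetical pure transition, analogous to addTraffic: Move => (Position => Position).
  def step(m: Move): Position => Position = s =>
    m match {
      case Up    => s.copy(y = s.y + 1)
      case Down  => s.copy(y = s.y - 1)
      case Left  => s.copy(x = s.x - 1)
      case Right => s.copy(x = s.x + 1)
    }

  val moves: List[Move] = List(Down, Down, Left, Up)

  // Via State: lift each transition with modify, then traverse_.
  val viaState: (Position, Unit) = traverse_(moves)(m => modify(step(m))).run(Position())
  // Via a plain left fold over the same transitions.
  val viaFold: Position = moves.foldLeft(Position())((s, m) => step(m)(s))
}
```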
    

    Finally, if your A type in State[S, A] forms a Monoid, you can also use foldMapM from Foldable (the monadic variant of foldMap) to combine (i.e. fold) all the As as they are calculated. A trivial example (probably not very useful, because it just concatenates all the Strings) would be this:

    val concatenated : State[Position, String] = moves.foldMapM(move)
    concatenated.run(Position()).value // (Position(-1,-1),Down!Down!Left!Up!)
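    A sketch of what foldMapM does, using a hand-rolled State and a minimal stand-in Monoid trait (neither is the Cats type class): it runs the actions in order and combines each result into an accumulator that starts at the Monoid's empty value.

```scala
// Sketch: foldMapM with a hand-rolled State and a minimal stand-in Monoid.
object FoldMapSketch {
  final case class State[S, A](run: S => (S, A)) {
    def map[B](f: A => B): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); (s1, f(a)) })
    def flatMap[B](f: A => State[S, B]): State[S, B] =
      State((s: S) => { val (s1, a) = run(s); f(a).run(s1) })
  }
  def pure[S, A](a: A): State[S, A] = State((s: S) => (s, a))

  // Minimal stand-in for a Monoid type class (not the Cats one).
  trait Monoid[A] { def empty: A; def combine(x: A, y: A): A }
  implicit val stringConcat: Monoid[String] = new Monoid[String] {
    def empty: String = ""
    def combine(x: String, y: String): String = x + y
  }

  // Run each action in order, folding its result into the accumulator.
  def foldMapM[S, A, B](xs: List[A])(f: A => State[S, B])(implicit M: Monoid[B]): State[S, B] =
    xs.foldLeft(pure[S, B](M.empty)) { (acc, a) =>
      acc.flatMap(soFar => f(a).map(b => M.combine(soFar, b)))
    }

  final case class Position(x: Int = 0, y: Int = 0)
  sealed trait Move
  case object Up extends Move
  case object Down extends Move
  case object Left extends Move
  case object Right extends Move

  def move(m: Move): State[Position, String] = State((s: Position) =>
    m match {
      case Up    => (s.copy(y = s.y + 1), "Up!")
      case Down  => (s.copy(y = s.y - 1), "Down!")
      case Left  => (s.copy(x = s.x - 1), "Left!")
      case Right => (s.copy(x = s.x + 1), "Right!")
    }
  )

  val moves: List[Move] = List(Down, Down, Left, Up)
  val result: (Position, String) = foldMapM(moves)(move).run(Position())
}
```

    With a summing Monoid for Long instead of String concatenation, the same shape would total a per-CDR figure across the whole chain.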
    

    Whether this final approach is useful to you really depends on what A you have and whether it makes sense to combine it.

    And this should be all you need in your scenario.