
Pure functional random number generator - State monad


The book 'Functional Programming in Scala' gives the following example of a purely functional random number generator:

trait RNG {
    def nextInt: (Int, RNG)
}

object RNG {
    def simple(seed: Long): RNG = new RNG {
        def nextInt = {
            val seed2 = (seed*0x5DEECE66DL + 0xBL) &
                        ((1L << 48) - 1)
            ((seed2 >>> 16).toInt,
             simple(seed2))
        }
    }
}

Usage looks like this:

val (randomNumber, nextState) = rng.nextInt
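To make that usage concrete, here is a minimal sketch (assuming Scala 2.13 or 3, run as a script). It copies the RNG above, shows that calling nextInt twice on the same value returns the same pair because nothing is mutated, and then threads the returned state by hand. The helper `ints` is hypothetical, not taken from the book.

```scala
// Self-contained copy of the RNG from the question (.toInt instead of asInstanceOf[Int])
trait RNG {
  def nextInt: (Int, RNG)
}

object RNG {
  def simple(seed: Long): RNG = new RNG {
    def nextInt: (Int, RNG) = {
      val seed2 = (seed * 0x5DEECE66DL + 0xBL) & ((1L << 48) - 1)
      ((seed2 >>> 16).toInt, simple(seed2))
    }
  }
}

// Hypothetical helper: produce `count` ints by passing each returned RNG
// into the next call. Same input RNG => same output list.
def ints(count: Int)(rng: RNG): (List[Int], RNG) =
  if (count <= 0) (Nil, rng)
  else {
    val (x, r1) = rng.nextInt
    val (xs, r2) = ints(count - 1)(r1)
    (x :: xs, r2)
  }

val rng = RNG.simple(42L)

// Calling nextInt twice on the same value gives the same number.
val (a, _) = rng.nextInt
val (b, _) = rng.nextInt

// Threading the returned state gives a sequence of different numbers.
val (xs, _) = ints(3)(rng)
println((a, b, xs))
```

The seed is supplied exactly once, when RNG.simple is called; after that, every new RNG comes from the previous call's result.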

I understand that it is a pure function: it returns the next state and leaves it to the API client to use that state when calling nextInt the next time it needs a random number. What I don't understand is how the first random number is generated, since we must provide a seed at least once.

Should there be another function that lifts a seed into an RNG? And if so, how do we expect the client of this API to know about it? (In the non-functional implementation, the user just calls nextInt and the API maintains the state.)

Can someone give a full example of a purely functional random number generator in Scala, and perhaps relate it to the State monad in general?


Solution

  • That random generator RNG is purely functional: for the same inputs you always get the same outputs. The impure part is left to the user of the API (you).

    To use the RNG in a purely functional way, you always have to initialize it with the same seed (via RNG.simple(seed), which is the function that lifts a seed into an RNG), but then you always get the same sequence of numbers, which is not very useful.

    Otherwise, you have to delegate the initialization of the RNG to an external source (usually the wall-clock time), which introduces a side effect (goodbye, purely functional).

    val state0 = RNG.simple(System.currentTimeMillis)
    
    val (rnd1, state1) = state0.nextInt
    val (rnd2, state2) = state1.nextInt
    val (rnd3, state3) = state2.nextInt
    
    println((rnd1, rnd2, rnd3))
    

    [EDIT]

    Inspired by @Aivean's answer, I created my own version of a Stream of random numbers:

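    // randoms is a def, so each call reseeds from the clock.
    // (Stream is deprecated since Scala 2.13; LazyList is its replacement.)
    def randoms: Stream[Int] = Stream.from(0)
      // carry (lastValue, currentRng) and advance the RNG once per element
      .scanLeft((0, RNG.simple(System.currentTimeMillis)))((st, _) => st._2.nextInt)
      .tail      // drop the dummy initial (0, rng) pair
      .map(_._1) // keep only the random Int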
    
    println(randoms.take(5).toList)
    println(randoms.filter(_ > 0).take(3).toList)
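The question also asks how this relates to the State monad. Every step above has the shape RNG => (A, RNG): take a state, return a value plus the next state. A minimal sketch, assuming a hand-written State type in the spirit of the book's later chapter (the names State, unit, sequence, Rand, int, nInts and below are illustrative here, not a library API), wraps that shape so that map/flatMap thread the RNG for you instead of passing state1, state2, state3 by hand. The only impure step left is reading the clock once to pick the seed.

```scala
// Hypothetical minimal State type (not a library API): a state action is a
// function S => (A, S) that returns a value plus the next state.
final case class State[S, A](run: S => (A, S)) {
  def flatMap[B](f: A => State[S, B]): State[S, B] =
    State((s: S) => {
      val (a, s1) = run(s) // run this action
      f(a).run(s1)         // feed the new state into the next action
    })

  def map[B](f: A => B): State[S, B] =
    flatMap(a => State.unit[S, B](f(a)))
}

object State {
  // Return a value without touching the state.
  def unit[S, A](a: A): State[S, A] = State((s: S) => (a, s))

  // Run actions left to right, threading the state through each one.
  def sequence[S, A](actions: List[State[S, A]]): State[S, List[A]] =
    actions.foldRight(unit[S, List[A]](Nil)) { (action, acc) =>
      for { a <- action; rest <- acc } yield a :: rest
    }
}

trait RNG {
  def nextInt: (Int, RNG)
}

object RNG {
  // Same linear congruential generator as in the question.
  def simple(seed: Long): RNG = new RNG {
    def nextInt: (Int, RNG) = {
      val seed2 = (seed * 0x5DEECE66DL + 0xBL) & ((1L << 48) - 1)
      ((seed2 >>> 16).toInt, simple(seed2))
    }
  }
}

// A random action is a State whose state is the RNG.
type Rand[A] = State[RNG, A]

// nextInt wrapped as a state action.
val int: Rand[Int] = State((rng: RNG) => rng.nextInt)

// Hypothetical helpers built only from the combinators above.
def nInts(n: Int): Rand[List[Int]] = State.sequence(List.fill(n)(int))
def below(bound: Int): Rand[Int] = int.map(i => Math.floorMod(i, bound))

// The only impure step: pick a seed once, at the edge of the program.
val (numbers, _) = nInts(3).run(RNG.simple(System.currentTimeMillis))
println(numbers)
```

Because nInts(3) only describes the computation, running it twice with RNG.simple(42L) yields the same list; whether the numbers differ between program runs depends solely on the seed chosen at the edge.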