
Implementing flatMap on aggregate monad


I'm looking to implement generator combinators (analogous to those in ScalaCheck or Haskell's QuickCheck) in which a Gen wraps an instance of Rand, a monad from the breeze library that represents a probability distribution. Because Rand is a monad, it implements map and flatMap. As is common for generators, I also want Gen itself to be a monad. As shown below, implementing map for Gen is straightforward:

// Simplified view of Rand from the breeze library (only map and flatMap shown)
trait Rand[T] {
  def map[U](f: T => U): Rand[U]
  def flatMap[U](f: T => Rand[U]): Rand[U]
}

case class Gen[T](dist: Rand[T]) {
  def map[U](f: T => U): Gen[U] = Gen(dist.map(f))

  def flatMap[U](f: T => Gen[U]): Gen[U] = ??? // How to implement this?
}

However, it is not clear to me how flatMap should be implemented. Is this easily achieved, or does it require a level of indirection, for example via some intermediate datatype?


Solution

A possible implementation is to flatMap over the underlying Rand and unwrap the Gen returned by f through its dist field:

    def flatMap[U](f: T => Gen[U]): Gen[U] =
      Gen(dist.flatMap(f(_).dist))
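Because Gen is only a thin wrapper around Rand, no intermediate datatype is needed: flatMap delegates to Rand.flatMap and strips the Gen off each result with .dist. The sketch below makes this runnable without breeze. It assumes a hypothetical stand-in Rand (a function from a scala.util.Random to a sample) that only mirrors the map/flatMap signatures from the question; it is not breeze's actual Rand. The helpers Rand.always, Rand.uniformInt, Gen.const, Gen.choose and Gen.sample are likewise hypothetical names introduced for illustration.

```scala
import scala.util.Random

// Hypothetical stand-in for breeze's Rand (NOT breeze's API): a sampler that
// draws a value from a supplied RNG. Only the map/flatMap signatures match the question.
final case class Rand[T](sample: Random => T) {
  def map[U](f: T => U): Rand[U] = Rand(rng => f(sample(rng)))
  // Draw a T, build the dependent distribution, then draw from it with the same RNG.
  def flatMap[U](f: T => Rand[U]): Rand[U] = Rand(rng => f(sample(rng)).sample(rng))
}

object Rand {
  // Hypothetical helper: the distribution that always yields t.
  def always[T](t: T): Rand[T] = Rand(_ => t)
  // Hypothetical helper: uniform Int in [lo, hi); requires hi > lo.
  def uniformInt(lo: Int, hi: Int): Rand[Int] = Rand(rng => lo + rng.nextInt(hi - lo))
}

final case class Gen[T](dist: Rand[T]) {
  def map[U](f: T => U): Gen[U] = Gen(dist.map(f))
  // The answer's implementation: flatMap on the wrapped Rand, unwrapping f's Gen via .dist.
  def flatMap[U](f: T => Gen[U]): Gen[U] = Gen(dist.flatMap(f(_).dist))
  def sample(rng: Random): T = dist.sample(rng)
}

object Gen {
  def const[T](t: T): Gen[T] = Gen(Rand.always(t))
  def choose(lo: Int, hi: Int): Gen[Int] = Gen(Rand.uniformInt(lo, hi))
}

object Demo {
  def main(args: Array[String]): Unit = {
    // Dependent generation: the bound of the second draw depends on the first.
    // The for-comprehension desugars to flatMap followed by map.
    val pair: Gen[(Int, Int)] = for {
      n <- Gen.choose(1, 10)
      m <- Gen.choose(0, n)
    } yield (n, m)

    val rng = new Random(42)
    val samples = List.fill(1000)(pair.sample(rng))
    // Every sample respects the dependency: 1 <= n < 10 and 0 <= m < n.
    assert(samples.forall { case (n, m) => n >= 1 && n < 10 && m >= 0 && m < n })
    println(samples.take(3))
  }
}
```

With flatMap in place, Gen works in for-comprehensions, which is how combinators whose later draws depend on earlier ones (as in ScalaCheck) are usually written.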