
Monads - Purpose of Flatten


So I've been doing some reading about monads in Scala, and all the syntax relating to their flatMap functions and for comprehensions. Intuitively, I understand why monads need the map part of flatMap: when I use map, it's usually on a container of zero, one, or many elements, and it returns the same kind of container with the passed-in function applied to every element. A monad is similarly a container of zero, one, or many elements.

However, what is the purpose of the flatten part of flatMap? I can't see the intuition behind it. To me it looks like extra boilerplate: every function passed to flatMap has to wrap its return value in a container (a monad), only for that container to be thrown away immediately by the flatten step. I cannot think of a single example where flatMap can't be simplified by replacing it with map. For example:

    val monad: Option[Int] = Option(5)
    val monad2: Option[Int] = None

    def flatAdder(i: Int): Option[Int] = Option(i + 1)
    def adder(i: Int): Int = i + 1

    // What I see in all the examples
    monad.flatMap(flatAdder)
    // Option[Int] = Some(6)
    monad.flatMap(flatAdder).flatMap(flatAdder)
    // Option[Int] = Some(7)
    monad2.flatMap(flatAdder)
    // Option[Int] = None
    monad2.flatMap(flatAdder).flatMap(flatAdder)
    // Option[Int] = None

    // Isn't this a lot easier?
    monad.map(adder)
    // Option[Int] = Some(6)
    monad.map(adder).map(adder)
    // Option[Int] = Some(7)
    monad2.map(adder)
    // Option[Int] = None
    monad2.map(adder).map(adder)
    // Option[Int] = None

To me, using map by itself seems far more intuitive and simple than flatMap, and the flatten part does not seem to add any value. However, Scala places a large emphasis on flatMap rather than map, to the point that for comprehensions are built around it, so clearly I must be missing something. My question is: in what cases is the flatten part of flatMap actually useful, and what other advantages does flatMap have over map?
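A small sketch of where map alone falls short. It assumes a hypothetical helper, `halveIfEven` (not part of the question), which, unlike `flatAdder`, can actually fail by returning `None`. It shows that `map` nests the result as `Option[Option[Int]]`, while `flatMap` (equivalently `map` followed by `flatten`) yields a single `Option[Int]`.

```scala
object MapVsFlatMap {
  // Hypothetical helper (not from the question): halves even numbers, fails on odd ones.
  def halveIfEven(i: Int): Option[Int] =
    if (i % 2 == 0) Some(i / 2) else None

  // map wraps the Option that halveIfEven returns inside a second Option.
  val nestedOk: Option[Option[Int]] = Option(8).map(halveIfEven)   // Some(Some(4))
  val nestedFail: Option[Option[Int]] = Option(5).map(halveIfEven) // Some(None)

  // flatten removes one layer; flatMap does map and flatten in a single step,
  // so a failure inside the function becomes a plain None.
  val viaFlatten: Option[Int] = Option(8).map(halveIfEven).flatten // Some(4)
  val flatOk: Option[Int] = Option(8).flatMap(halveIfEven)         // Some(4)
  val flatFail: Option[Int] = Option(5).flatMap(halveIfEven)       // None

  // Chained steps: if any step returns None, the whole chain is None.
  val chainOk: Option[Int] = Option(8).flatMap(halveIfEven).flatMap(halveIfEven)   // Some(2)
  val chainFail: Option[Int] = Option(6).flatMap(halveIfEven).flatMap(halveIfEven) // None: 6 -> 3, and 3 is odd

  def main(args: Array[String]): Unit = {
    println(s"map:     $nestedOk, $nestedFail")
    println(s"flatMap: $flatOk, $flatFail")
    println(s"chained: $chainOk, $chainFail")
  }
}
```

The question's `flatAdder` never returns `None`, which is why `map(adder)` can stand in for `flatMap(flatAdder)` there. The difference only shows up once the function itself can produce an empty result.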


Solution

  • What if each element being processed could result in zero or more elements?

    // assuming primeFactors(n: Int): List[Int], with primeFactors(0) == Nil
    List(4, 0, 15).flatMap(n => primeFactors(n))
    // List(2, 2, 3, 5)
    

    What if you have a name that might be misspelled, and you want that person's office assignment, if they have one?

    def getID(name:String): Option[EmployeeID] = ???
    def getOffice(id:EmployeeID): Option[Office] = ???
    
    val office: Option[Office] =
      getID(nameAttempt).flatMap(id => getOffice(id))
    

    You use flatMap() when you have a monad and need to feed its contents to a function that itself produces a monad (a "monad producer"). You want the result to be Monad[X], not Monad[Monad[X]].
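Both scenarios can be sketched end to end as follows. `primeFactors`, `EmployeeID`, `Office`, `getID`, and `getOffice` are hypothetical stand-ins (the answer uses them without defining them), backed here by trial division and small in-memory maps. The for-comprehension version relies on Scala's standard rewriting of `<-` generators into `flatMap`/`map` calls.

```scala
object FlatMapExamples {
  // Hypothetical helper: prime factors by trial division; Nil for n < 2 (so 0 yields nothing).
  def primeFactors(n: Int): List[Int] = {
    @annotation.tailrec
    def loop(m: Int, d: Int, acc: List[Int]): List[Int] =
      if (m < 2) acc.reverse
      else if (d * d > m) (m :: acc).reverse
      else if (m % d == 0) loop(m / d, d, d :: acc)
      else loop(m, d + 1, acc)
    loop(n, 2, Nil)
  }

  // One input element can produce zero or more outputs.
  val nestedFactors: List[List[Int]] = List(4, 0, 15).map(primeFactors) // List(List(2, 2), List(), List(3, 5))
  val flatFactors: List[Int] = List(4, 0, 15).flatMap(primeFactors)     // List(2, 2, 3, 5)

  // Hypothetical types and lookup tables standing in for real data sources.
  final case class EmployeeID(value: Int)
  final case class Office(room: String)

  private val ids = Map("Alice" -> EmployeeID(1), "Bob" -> EmployeeID(2))
  private val offices = Map(EmployeeID(1) -> Office("3B")) // Bob has no office

  def getID(name: String): Option[EmployeeID] = ids.get(name)
  def getOffice(id: EmployeeID): Option[Office] = offices.get(id)

  // map alone leaves a nested Option.
  def officeNested(name: String): Option[Option[Office]] = getID(name).map(getOffice)

  // flatMap gives a single Option: None if the name is unknown or there is no office.
  def officeOf(name: String): Option[Office] = getID(name).flatMap(id => getOffice(id))

  // The same lookup as a for comprehension; the compiler rewrites it into flatMap/map calls.
  def officeOfFor(name: String): Option[Office] =
    for {
      id     <- getID(name)
      office <- getOffice(id)
    } yield office

  def main(args: Array[String]): Unit = {
    println(flatFactors)           // List(2, 2, 3, 5)
    println(officeNested("Alice")) // Some(Some(Office(3B)))
    println(officeOf("Alice"))     // Some(Office(3B))
    println(officeOf("Bob"))       // None
    println(officeOfFor("Alcie"))  // None (misspelled name)
  }
}
```

Note that a misspelled name and a missing office both collapse to the same `None`. If the caller needs to tell those cases apart, a type that carries an error (such as `Either`) would be needed instead of `Option`.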