
Scala - access collection members within map or flatMap


Suppose that I use a sequence of various maps and/or flatMaps to generate a sequence of collections. Is it possible to access information about the "current" collection from within any of those methods? For example, without knowing anything specific about the functions used in the previous maps or flatMaps, and without using any intermediate declarations, how can I get the maximum value (or length, or first element, etc.) of the collection upon which the last map acts?

List(1, 2, 3)
  .flatMap(x => f(x) /* some unknown function */)
  .map(x => x + ??? /* what is the max element of the collection? */)

Edit for clarification:

  1. In the example, I'm not looking for the max (or whatever) of the initial List. I'm looking for the max of the collection after the flatMap has been applied.

  2. By "without using any intermediate declarations" I mean that I do not want to use any temporary collections en route to the final result. So the example by Steve Waldman below, while giving the desired result, is not what I am seeking. (I include this condition mostly for aesthetic reasons.)

Edit for clarification, part 2:

The ideal solution would be some magic keyword or syntactic sugar that lets me reference the current collection:

List(1, 2, 3)
  .flatMap(x => f(x))
  .map(x => x + theCurrentList.max)

I'm prepared to accept the fact, however, that this simply is not possible.


Solution

  • You could define a mapWithSelf (resp. flatMapWithSelf) operation along these lines and add it as an implicit enrichment to the collection. For List it might look like:

    // Scala 2.13 APIs
    object Enrichments {
      implicit class WithSelfOps[A](val lst: List[A]) extends AnyVal {
        // Map each element together with the whole list it belongs to
        def mapWithSelf[B](f: (A, List[A]) => B): List[B] =
          lst.map(f(_, lst))
    
        // Same idea for flatMap: f sees the element and the entire list
        def flatMapWithSelf[B](f: (A, List[A]) => IterableOnce[B]): List[B] =
          lst.flatMap(f(_, lst))
      }
    }
    

    The enrichment fixes the value of the collection before the operation and threads it through to every call of f. It should be possible to generalize this (at least for the strict collections), though the code would look somewhat different in 2.12 than in 2.13+.
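    As a rough illustration of that generalization, here is a minimal sketch for Scala 2.13 only, following the IsIterable/BuildFrom pattern that the 2.13 collections documentation uses for custom operations. The names GenericEnrichments, WithSelfOps and withSelfOps are hypothetical, only mapWithSelf is shown, and 2.12 is not covered (there the corresponding machinery is IsTraversableLike and CanBuildFrom).

    ```scala
    // Scala 2.13 only: relies on IsIterable and BuildFrom
    import scala.collection.BuildFrom
    import scala.collection.generic.IsIterable
    import scala.language.implicitConversions

    object GenericEnrichments {
      // `it` knows how to view a Repr (List[Int], Vector[Int], ...) as an
      // Iterable whose element type is `it.A`
      class WithSelfOps[Repr, I <: IsIterable[Repr]](coll: Repr, val it: I) {
        // Like map, but f also receives the whole original collection;
        // BuildFrom rebuilds a result of the matching collection type
        def mapWithSelf[B, That](f: (it.A, Repr) => B)(implicit bf: BuildFrom[Repr, B, That]): That =
          bf.fromSpecific(coll)(it(coll).iterator.map(a => f(a, coll)))
      }

      // Hypothetical enrichment: applies to any Repr that has an IsIterable instance
      implicit def withSelfOps[Repr](coll: Repr)(implicit it: IsIterable[Repr]): WithSelfOps[Repr, it.type] =
        new WithSelfOps(coll, it)
    }
    ```

    BuildFrom is what keeps the result type aligned with the input (a Vector stays a Vector); if List is all you need, the List-only enrichment is simpler.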

    Usage would look like

    import Enrichments._
    
    // Stand-in for the unknown function from the question
    val someF: Int => IterableOnce[Int] = x => List(x, x * 10)
    
    List(1, 2, 3)
      .flatMap(someF)
      .mapWithSelf { (x, lst) =>
        x + lst.max
      }
    

    So at the usage site, it reads nicely. Note, however, that if the function computes something that traverses the list (as lst.max does), the list is traversed once for every element, which makes the whole operation quadratic. You can avoid that with some mutability (e.g. caching the computed value) or simply by saving the intermediate list after the flatMap and computing the value once.
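If the repeated traversal matters, one option that needs no mutability is to compute the value once and pass it to every element. In this sketch, mapWithSummary is a hypothetical variation on the enrichment (not a library method), someF is an assumed stand-in for the unknown function, and the second value uses pipe from Scala 2.13's scala.util.chaining, which names the intermediate list only as a lambda parameter.

```scala
// Scala 2.13: scala.util.chaining provides `pipe`
import scala.util.chaining._

object SummaryEnrichments {
  implicit class WithSummaryOps[A](private val lst: List[A]) {
    // Hypothetical helper: `summarize` runs once over the whole list, and its
    // result is passed to f for every element (two traversals in total)
    def mapWithSummary[S, B](summarize: List[A] => S)(f: (A, S) => B): List[B] = {
      val s = summarize(lst)
      lst.map(f(_, s))
    }
  }
}

object SummaryExample {
  import SummaryEnrichments._

  // Assumed stand-in for the unknown function in the question
  val someF: Int => List[Int] = x => List(x, x * 10)

  // max is computed once, not once per element
  val viaEnrichment: List[Int] =
    List(1, 2, 3).flatMap(someF).mapWithSummary(_.max)((x, m) => x + m)

  // The intermediate list is bound only inside the lambda passed to pipe
  val viaPipe: List[Int] =
    List(1, 2, 3).flatMap(someF).pipe { xs =>
      val m = xs.max
      xs.map(_ + m)
    }
}
```

Both versions traverse the flattened list twice (once for max, once for the map) rather than once per element.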