
Implement DISTINCT in for comprehension


In the penultimate lecture of his Coursera course, Prof. Odersky offered the following for comprehension as the final step in a lovely case study:

def solutions(target: Int): Stream[Path] =
  for {
    pathSet <- pathSets
    path <- pathSet
    if path.endState contains target
  } yield path

In an earlier lecture he drew some analogies between for comprehensions and SQL.
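The SQL analogy can be made concrete with a minimal, self-contained sketch. It assumes a hypothetical `Path` (only a name and an `endState`) and made-up `pathSets`, standing in for the case study's richer types. The for comprehension plays the role of `SELECT ... FROM ... WHERE`. The standard `distinct` method is the closest analogue to `SELECT DISTINCT`, but it compares whole `Path` values, not just their `endState`s, so it does not answer the question by itself.

```scala
object SqlAnalogy {
  // Hypothetical stand-in for the case study's Path.
  case class Path(name: String, endState: Vector[Int])

  // Hypothetical data: each inner Set plays the role of one "pathSet".
  val pathSets: Stream[Set[Path]] = Stream(
    Set(Path("a", Vector(4, 0))),
    Set(Path("b", Vector(0, 4)), Path("c", Vector(4, 0))),
    Set(Path("a", Vector(4, 0)))
  )

  // Roughly: SELECT path FROM pathSets, pathSet WHERE endState contains target
  def solutions(target: Int): Stream[Path] =
    for {
      pathSet <- pathSets
      path <- pathSet
      if path.endState contains target
    } yield path

  // Roughly: SELECT DISTINCT path ... (drops only exact duplicate Paths;
  // "b" and "c" are both kept even though "c" repeats a's endState)
  def distinctSolutions(target: Int): Stream[Path] = solutions(target).distinct
}
```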

What I'm looking for is a way to yield only those paths that have a DISTINCT endState.

Is there a way to refer back from within a filter clause of the same comprehension to the items that have already been yielded?

Another approach might be to convert pathSets to a Map from endState to path before the for statement, then convert it back to a Stream before returning it. However, this would seem to lose the lazy computation benefits of using a Stream.
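A small experiment supports that concern. This sketch uses a hypothetical stream of `(endState, pathName)` pairs and a counter (instrumentation only) that records how many elements have actually been computed. Building a `Map` with `toMap` has to consume the whole stream before returning anything, so on an infinite stream it never returns. It also keeps the *last* path for each key rather than the first. A lazy pipeline such as `filter` followed by `headOption` computes only what it needs.

```scala
object MapLosesLaziness {
  // Instrumentation only (not part of either technique): counts computed elements.
  var generated = 0

  // Hypothetical stream of (endState, pathName) pairs; each element bumps the
  // counter when it is actually computed.
  def pairs(n: Int): Stream[(Int, String)] =
    Stream.from(0).take(n).map { i =>
      generated += 1
      (i % 3, s"path$i")
    }

  // The Map route: toMap consumes the entire stream before returning, and for
  // duplicate keys a later path overwrites an earlier one.
  def viaMap(n: Int): Stream[String] = {
    generated = 0
    pairs(n).toMap.values.toStream
  }

  // A lazy pipeline only computes as many elements as the caller demands.
  def firstWithState(n: Int, state: Int): Option[String] = {
    generated = 0
    pairs(n).filter(_._1 == state).map(_._2).headOption
  }
}
```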

An earlier method from the same case study accomplished similar goals, but it was already a recursive function, while this one doesn't (seem to) need to be recursive.

It looks like I could use a mutable Set to track the endStates that get yielded, but that feels unsatisfying, since the course has successfully avoided using mutability so far.
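For comparison, here is a minimal sketch of the mutable-Set idea. `Path` and `distinctByEndState` are hypothetical names. It does stay lazy, because `Stream.filter` only tests elements as they are demanded. The price is a predicate with a side effect: its correctness depends on the set being created fresh for every call and on each element being tested exactly once.

```scala
import scala.collection.mutable

object MutableSeen {
  // Hypothetical stand-in for the case study's Path.
  case class Path(name: String, endState: Int)

  // Keeps the first Path for each endState.
  def distinctByEndState(paths: Stream[Path]): Stream[Path] = {
    val seen = mutable.Set.empty[Int] // fresh set per call
    // Set.add returns true only if the element was not already present.
    paths.filter(p => seen.add(p.endState))
  }
}
```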


Solution

  • Is there a way to refer back from within a filter clause of the same comprehension to the items that have already been yielded?

    Your for comprehension desugars to something more or less like this:

    pathSets flatMap {
      pathSet => pathSet withFilter {
        path => path.endState contains target
      } map {path => path}
    }
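One way to check the desugaring is to write both forms side by side and compare their results. This sketch uses a hypothetical `Path` and made-up data. Per the Scala language spec, the translation nests the `map` inside the outer `flatMap` and uses `withFilter` for the guard. The results are the same either way.

```scala
object Desugar {
  // Hypothetical stand-in for the case study's Path.
  case class Path(name: String, endState: Vector[Int])

  val pathSets: Stream[Set[Path]] = Stream(
    Set(Path("a", Vector(4, 0)), Path("b", Vector(1, 2))),
    Set(Path("c", Vector(0, 4)))
  )

  // The for comprehension as written in the question.
  def withFor(target: Int): Stream[Path] =
    for {
      pathSet <- pathSets
      path <- pathSet
      if path.endState contains target
    } yield path

  // Hand-desugared: the guard becomes withFilter, and the yield becomes an
  // identity map nested inside the outer flatMap.
  def desugared(target: Int): Stream[Path] =
    pathSets.flatMap { pathSet =>
      pathSet.withFilter(path => path.endState contains target).map(path => path)
    }
}
```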
    

    The map with an identity function is your yield. I can't remember whether the spec allows that map to be elided when it's an identity function.

    Anyway, I hope this shows more clearly why there's no "reaching back" with that structure: each filter predicate sees only the current path, never the elements that have already been yielded.
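If you do want "have I already seen this endState?" information without a mutable Set, one alternative (not from the course) is to thread an immutable set through the stream with `scanLeft`, which is lazy on a `Stream`. Each step carries the endStates seen so far plus an optional element to emit. A final `collect` keeps only the emitted elements. `Path` is again a hypothetical stand-in.

```scala
object ScanLeftDistinct {
  // Hypothetical stand-in for the case study's Path.
  case class Path(name: String, endState: Int)

  // State = (endStates seen so far, Some(path) if this path is new, else None).
  // scanLeft also emits the initial (empty, None) state; collect drops it.
  def distinctByEndState(paths: Stream[Path]): Stream[Path] =
    paths
      .scanLeft((Set.empty[Int], Option.empty[Path])) {
        case ((seen, _), p) =>
          if (seen contains p.endState) (seen, None)
          else (seen + p.endState, Some(p))
      }
      .collect { case (_, Some(p)) => p }
}
```

The state is rebuilt immutably at each step, so nothing outside the stream is mutated. The trade-off is a slightly noisier pipeline than a dedicated helper.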

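    You can write a lazy, recursive distinctBy function:

    // Implicit classes must be nested inside an object, class, or trait.
    implicit class DistinctStream[T](s: Stream[T]) {
      // Keeps the first element for each distinct key f(x).
      def distinctBy[V](f: T => V): Stream[T] = {
        // seen holds the keys emitted so far and is passed along immutably.
        def loop(remainder: Stream[T], seen: Set[V]): Stream[T] =
          remainder match {
            case head #:: tail =>
              val value = f(head)
              // Skip keys already seen. Stream.cons takes its tail by name,
              // so the rest of the stream is only computed on demand.
              if (seen contains value) loop(tail, seen)
              else Stream.cons(head, loop(tail, seen + value))
            case _ => Stream.empty
          }

        loop(s, Set.empty)
      }
    }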
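To check that this approach is genuinely lazy, here is the same technique written as a plain function and applied to an infinite stream that has only three distinct keys. The name `distinctByLazy` is hypothetical. It avoids clashing with the built-in `distinctBy` that Scala 2.13+ sequences provide, because a member method of the same name would take precedence over the implicit class. Taking three results terminates because the tail after the third result is never forced.

```scala
object LazyDistinct {
  // Same idea as the DistinctStream answer, as a plain (hypothetically named) function.
  def distinctByLazy[T, V](s: Stream[T])(f: T => V): Stream[T] = {
    def loop(rest: Stream[T], seen: Set[V]): Stream[T] =
      if (rest.isEmpty) Stream.empty
      else {
        val value = f(rest.head)
        // Already-seen keys are skipped; new ones are emitted with a by-name tail.
        if (seen contains value) loop(rest.tail, seen)
        else Stream.cons(rest.head, loop(rest.tail, seen + value))
      }
    loop(s, Set.empty)
  }

  // Infinite input, three distinct keys (0, 1, 2): take(3) still terminates.
  val firstThree: List[Int] = distinctByLazy(Stream.from(0))(_ % 3).take(3).toList
}
```

Asking for a fourth element would loop forever, since no fourth distinct key exists. That is the expected behaviour for any lazy distinct on an infinite input.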
    

    And use it like so:

    def solutions(target: Int): Stream[Path] =
      (for {
        pathSet <- pathSets
        path <- pathSet
        if path.endState contains target
      } yield path) distinctBy (_.endState)
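On Scala 2.13 or later, `Stream` is deprecated in favour of `LazyList`, and the standard library already provides `distinctBy` on sequences, so the implicit class is no longer needed. This sketch assumes Scala 2.13+ and uses a hypothetical `Path` and made-up `pathSets`:

```scala
object ModernSolutions {
  // Hypothetical stand-in for the case study's Path.
  case class Path(name: String, endState: Vector[Int])

  // Hypothetical data; in the course this would be computed lazily.
  val pathSets: LazyList[Set[Path]] = LazyList(
    Set(Path("a", Vector(4, 0))),
    Set(Path("b", Vector(0, 4)), Path("c", Vector(4, 0)))
  )

  // Requires Scala 2.13+: LazyList and the built-in distinctBy.
  def solutions(target: Int): LazyList[Path] =
    (for {
      pathSet <- pathSets
      path <- pathSet
      if path.endState contains target
    } yield path).distinctBy(_.endState)
}
```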
    

    Yeah, now there's recursion. But there already was, because Stream's map, flatMap, and filter are all lazy recursive functions.
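A quick check of that claim: the same shape of for comprehension over an infinite `Stream` still terminates, as long as you only demand a finite prefix. That would be impossible if `flatMap` and `filter` were strict. The data below is made up.

```scala
object LazyCombinators {
  // Hypothetical infinite "pathSets": level n holds n*10, n*10 + 1, n*10 + 2.
  val levels: Stream[List[Int]] = Stream.from(0).map(n => List.range(n * 10, n * 10 + 3))

  // flatMap and the guard build the result lazily; only the levels needed
  // to produce four elements are ever computed.
  val firstEvens: List[Int] =
    (for {
      level <- levels
      x <- level
      if x % 2 == 0
    } yield x).take(4).toList
}
```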