Search code examples
Tags: scala, monads, depth-first-search, scala-cats, cats-effect

DFS Graph Traversal using Cats-Effect


I am trying to implement graph DFS using cats-effect. At some point in the function you have to iterate over all neighbors and recursively call DFS on each of them, i.e. [neighborSet.foreach(neighbor => ... DFS(neighbor))]. However, when I naively convert the code, the recursive calls to DFS are not run. How can I solve this problem? My code is below: d holds the pre-order (discovery) times, f holds the post-order (finish) times, and pred is the list of nodes that is returned.

import cats.effect._
// First attempt from the question: a depth-first search over an adjacency map using
// cats-effect `IO` and `Ref`. As the question states, the recursive calls are not run;
// the NOTE(review) comments below mark the places responsible.
object DFS extends IOApp.Simple:
    // Intended to return one (node, d, f) triple per entry of the `pred` list.
    def DFS[T](graph: Map[T, Set[T]], startNode: T): IO[List[(T, Int, Int)]] =
        // NOTE(review): `Ref.of` yields an `IO[Ref[...]]`, so these five vals are
        // descriptions of "allocate a Ref", not Refs. Every `x <- pred` (etc.) below runs
        // the allocation again and gets a fresh, independent Ref, so state written in one
        // DFS0 call is not visible to another call or to `value`.
        val pred = Ref.of[IO, List[T]](List(startNode))
        val d = Ref.of[IO, Map[T, Int]](Map.empty[T, Int])
        val f = Ref.of[IO, Map[T, Int]](Map.empty[T, Int])
        val visitedIO = Ref.of[IO, Set[T]](Set.empty[T])
        val time = Ref.of[IO, Int](0)
        // Marks `node` as visited and is meant to recurse into each of its neighbors.
        def neighborIteration(
            node: T,
            visitedRef: Ref[IO, Set[T]],
            predRef: Ref[IO, List[T]]
        ): IO[Unit] = IO.defer {
            for
                // A node missing from the map is treated as having no neighbors.
                neighbors <- IO.delay(graph.getOrElse(node, Set()))
                visited <- visitedRef.updateAndGet(s => s + node)
                // NOTE(review): `neighbors.map(...)` only builds a Set[IO[Unit]]; wrapping it in
                // `IO.delay` returns that set as a value. The inner IOs are never sequenced into
                // this for-comprehension (`setIOUnit` is unused), so neither the println nor
                // `DFS0(neighbor)` ever executes.
                setIOUnit <- IO.delay(neighbors.map(neighbor => 
                    for {
                        _ <- IO.delay(println("Here is the problem"))
                        _ <- predRef.update(lt => neighbor :: lt)
                        _ <- DFS0(neighbor)
                    } yield ()
                ))
            yield ()

        }
        // Visits `node`: records a timestamp, prepends the node to pred, explores neighbors
        // if the node is not yet in the visited set, then records a second timestamp.
        def DFS0(node: T): IO[Unit] =
            for
                // Each of these five binds allocates a new Ref on every DFS0 call (see note above).
                timeRef <- time
                dRef <- d
                fRef <- f
                predRef <- pred
                visitedRef <- visitedIO
                nextTime <- timeRef.getAndUpdate(i => i + 1)
                _ <- dRef.update(m => m + (node -> nextTime))
                _ <- predRef.update(lt => node :: lt)
                visited <- visitedRef.get
                _ <-
                    if (!(visited contains node)) neighborIteration(node,visitedRef,predRef)
                    else IO(())
                nextTime2 <- timeRef.getAndUpdate(_ + 1)
                // NOTE(review): the second timestamp is stored in dRef, overwriting the first;
                // fRef is bound above but never written. Presumably fRef was intended here.
                _ <- dRef.update(m => m + (node -> nextTime2))
            yield ()
        // Runs the traversal, then reads pred/d/f. These binds allocate fresh Refs again,
        // so the values read here are the initial ones, not what DFS0 wrote.
        val value = for
            _ <- DFS0(startNode)
            predRef <- pred
            dRef <- d
            fRef <- f
            predVal <- predRef.get
            dVal <- dRef.get
            fVal <- fRef.get
        yield (predVal, dVal, fVal)
        // Pairs every node in pred with its d and f entries, defaulting to 0 when absent.
        val result = value.map { (lt, dval, fval) =>
            lt.map(e => (e, dval.getOrElse(e,0), fval.getOrElse(e,0)))
        }
        result
     // Demo entry point: runs DFS on a sample graph from node 1 and prints the result.
     override def run: IO[Unit] =
        val graph2 = Map(
          1 -> Set(2, 3, 4),
          2 -> Set(1),
          3 -> Set(1, 4),
          4 -> Set(1, 3, 7),
          5 -> Set(6),
          6 -> Set(5),
          7 -> Set(4, 8),
          8 -> Set(3, 7)
        )
        for
            dfsResult <- DFS(graph2, 1)
            _ <- IO(println(dfsResult))
        yield ()

UPDATE:

Based on a comment below, I cleaned up the code and used sequence to turn the List[IO[Unit]] into an IO[Unit]. The code below now iterates over the neighbors of the starting node, but it does not continue on to the neighbors of those neighbors.

import cats.effect._
import cats._
import cats.data._
import cats.syntax.all._
// Second attempt from the question: the Refs are now allocated once in `value` and passed
// explicitly, and the per-neighbor IOs are combined with `sequence`. The traversal still
// stops after the start node's direct neighbors; see the NOTE(review) comments.
object DFS extends IOApp.Simple:
    // Returns one (node, d, f) triple per entry of the pred list, where d and f are the
    // timestamps recorded by DFS0 before and after exploring the node.
    def DFS[T](graph: Map[T, Set[T]], startNode: T): IO[List[(T, Int, Int)]] =
        // For each neighbor `e`, in `toList` order: add `e` to the visited set, print the
        // whole neighbor set, then call DFS0 on `e`.
        def neighborIteration(
            neighbors: Set[T],
            visitedRef: Ref[IO, Set[T]],
            predRef: Ref[IO, List[T]],
            dRef: Ref[IO, Map[T, Int]],
            fRef: Ref[IO, Map[T, Int]],
            timeRef: Ref[IO, Int]
        ): IO[Unit] = neighbors.toList.map(e => 
            for 
                // NOTE(review): `e` is marked visited BEFORE DFS0(e) runs, so DFS0's
                // `visited contains node` check is already true for `e` and its neighbors
                // are skipped. The bound `visited` value is not used.
                visited <- visitedRef.updateAndGet(s => s + e)
                // Debug output: prints the full neighbor set once per neighbor.
                _ <- IO(println(neighbors))
                _ <- DFS0(e, visitedRef, predRef, dRef, fRef, timeRef)
            yield ()
        ).sequence.void
        // Visits `node`: stores a timestamp in dRef, prepends the node to predRef, explores
        // its neighbors only if the node is not in the visited set, then stores a second
        // timestamp in fRef.
        def DFS0(
            node: T,
            visitedRef: Ref[IO, Set[T]],
            predRef: Ref[IO, List[T]],
            dRef: Ref[IO, Map[T, Int]],
            fRef: Ref[IO, Map[T, Int]],
            timeRef: Ref[IO, Int]
        ): IO[Unit] =
            for
                nextTime <- timeRef.getAndUpdate(i => i + 1)
                _ <- dRef.update(m => m + (node -> nextTime))
                _ <- predRef.update(lt => node :: lt)
                visited <- visitedRef.get
                // Debug output: prints the node being visited.
                _ <- IO(println(node))
                // NOTE(review): this call never adds `node` to the visited set itself; only
                // neighborIteration does, and it does so before calling DFS0. The start node
                // passes this check (the set starts empty); every other node fails it.
                _ <-  if (!(visited contains node))
                        neighborIteration(
                            graph.getOrElse(node, Set()),
                            visitedRef,
                            predRef,
                            dRef,
                            fRef,
                            timeRef
                          ) else IO(())
                nextTime2 <- timeRef.getAndUpdate(_ + 1)
                _ <- fRef.update(m => m + (node -> nextTime2))
            yield ()
        // Allocates one Ref per piece of state, runs the traversal from startNode, and
        // reads back the final pred list and the d/f maps.
        val value = for
            predRef <- Ref.of[IO, List[T]](List.empty[T])
            dRef <- Ref.of[IO, Map[T, Int]](Map.empty[T, Int])
            fRef <- Ref.of[IO, Map[T, Int]](Map.empty[T, Int])
            visitedRef <- Ref.of[IO, Set[T]](Set.empty[T])
            timeRef <- Ref.of[IO, Int](0)
            _ <- DFS0(startNode, visitedRef, predRef, dRef, fRef, timeRef)
            predVal <- predRef.get
            dVal <- dRef.get
            fVal <- fRef.get
        yield (predVal, dVal, fVal)
        // Pairs every node in pred with its d and f entries, defaulting to 0 when absent.
        val result = value.map { (lt, dval, fval) =>
            lt.map(e => (e, dval.getOrElse(e, 0), fval.getOrElse(e, 0)))
        }
        result

// Demo entry point: runs the search on a small sample graph starting at node 1 and
// prints the resulting list.
override def run: IO[Unit] =
    // Sample adjacency map; nodes 5 and 6 only link to each other.
    val adjacency = Map(
      1 -> Set(2, 3, 4),
      2 -> Set(1),
      3 -> Set(1, 4),
      4 -> Set(1, 3, 7),
      5 -> Set(6),
      6 -> Set(5),
      7 -> Set(4, 8),
      8 -> Set(3, 7)
    )
    DFS(adjacency, 1).flatMap(found => IO(println(found)))

Solution

  • The visited set should be updated in DFS0 instead of in neighborIteration, and neighborIteration should skip neighbors that have already been visited. Hope this helps.

    import cats.effect._
    import cats._
    import cats.data._
    import cats.syntax.all._
    /** Runnable demo of a depth-first search written with cats-effect `IO` and `Ref`. */
    object DFS extends IOApp.Simple:
      /** Depth-first traversal of `graph` starting at `startNode`.
        *
        * A node is added to the visited set when it is entered, and a neighbor is only
        * entered if it is not in that set yet, so every reachable node is visited once.
        *
        * @param graph     adjacency map; a node missing from the map is treated as having no neighbors
        * @param startNode node the traversal starts from
        * @tparam T        node type
        * @return one `(node, discoveryTime, finishTime)` triple per reached node, most recently
        *         discovered node first; both times are taken from one counter that starts at 0
        */
      def DFS[T](graph: Map[T, Set[T]], startNode: T): IO[List[(T, Int, Int)]] =
        // Runs the whole traversal against one set of Refs. The Refs are captured by the
        // local helpers instead of being threaded through every recursive call.
        def search(
            visitedRef: Ref[IO, Set[T]],
            predRef: Ref[IO, List[T]],
            dRef: Ref[IO, Map[T, Int]],
            fRef: Ref[IO, Map[T, Int]],
            timeRef: Ref[IO, Int]
        ): IO[Unit] =
          // Returns the current counter value and advances the counter by one.
          val tick: IO[Int] = timeRef.getAndUpdate(_ + 1)

          // Enters, in order, every neighbor that has not been visited yet. The visited set
          // is re-read for each neighbor because an earlier sibling's subtree may have
          // reached it in the meantime.
          def visitUnseen(neighbors: Set[T]): IO[Unit] =
            neighbors.toList.traverse_ { next =>
              visitedRef.get.flatMap { seen =>
                if seen.contains(next) then IO.unit else visit(next)
              }
            }

          // Enters `node`: stamps discovery, records it in pred and visited, explores its
          // unvisited neighbors, then stamps finish.
          def visit(node: T): IO[Unit] =
            for
              discovered <- tick
              _ <- dRef.update(_ + (node -> discovered))
              _ <- predRef.update(node :: _)
              _ <- visitedRef.update(_ + node)
              // _ <- IO.println(s"node: $node, discovered at: $discovered")
              // _ <- IO.println(graph.getOrElse(node, Set.empty[T]))
              _ <- visitUnseen(graph.getOrElse(node, Set.empty[T]))
              finished <- tick
              _ <- fRef.update(_ + (node -> finished))
            yield ()

          visit(startNode)

        for
          predRef <- Ref.of[IO, List[T]](List.empty[T])
          dRef <- Ref.of[IO, Map[T, Int]](Map.empty[T, Int])
          fRef <- Ref.of[IO, Map[T, Int]](Map.empty[T, Int])
          visitedRef <- Ref.of[IO, Set[T]](Set.empty[T])
          timeRef <- Ref.of[IO, Int](0)
          _ <- search(visitedRef, predRef, dRef, fRef, timeRef)
          order <- predRef.get
          discovery <- dRef.get
          finish <- fRef.get
          // _ <- IO.println(s"This is the value: $order, $discovery, $finish")
        yield order.map(n => (n, discovery.getOrElse(n, 0), finish.getOrElse(n, 0)))

      /** Runs the search on a sample graph from node 1 and prints the resulting list. */
      override def run: IO[Unit] =

        val graph = Map(
          1 -> Set(2, 3, 4),
          2 -> Set(1),
          3 -> Set(1, 4),
          4 -> Set(1, 3, 7),
          5 -> Set(6),
          6 -> Set(5),
          7 -> Set(4, 8),
          8 -> Set(3, 7)
        )
        for
          dfsResult <- DFS[Int](graph, 1)
          _ <- IO(println(dfsResult))
        yield ()
    
    

    The commented-out println statements are only there for debugging: uncomment them to trace the traversal, or delete them once you no longer need them.