Search code examples
scalafs2

How to group large stream into sub streams


I want to group a large Stream[F, A] into a Stream[F, Stream[F, A]] in which each inner stream has at most n elements.

This is what I did: basically, I pipe the chunks into a Queue[F, Option[Queue[F, Option[Chunk[A]]]]] (a queue of per-group queues) and then emit each dequeued inner queue as one inner stream of the result.

 /** Adds `grouped` / `groupedPipe` to a `Stream[F, A]`: the stream is cut into consecutive
   * inner streams of at most `n` elements each.
   *
   * Mechanics: a background producer copies the source's chunks into per-group inner queues
   * and publishes every inner queue on one outer queue; the foreground turns each dequeued
   * inner queue into an inner stream. `None` is the end-of-data marker on both levels.
   */
 implicit class StreamSyntax[F[_], A](s: Stream[F, A])(
     implicit F: Concurrent[F]) {

   /** Builds the grouping pipe.
     *
     * @param lastQRef cell that is overwritten with the most recently created inner queue,
     *                 so that the queue still open when the input ends can be terminated
     * @param n        maximum number of elements per inner stream; must be positive
     * @return a pipe emitting one inner stream per group of at most `n` input elements
     * @throws IllegalArgumentException if `n` is not positive
     */
   def groupedPipe(
     lastQRef: Ref[F, Queue[F, Option[Chunk[A]]]],
     n: Int): Pipe[F, A, Stream[F, A]] = {
     // A non-positive size can never be filled and would make the re-splitting below spin.
     require(n > 0, s"n must be > 0, got $n")

     in =>
       // Outer queue (unbounded) seeded with the first inner queue (capacity 1).
       val initQs =
         Queue.unbounded[F, Option[Queue[F, Option[Chunk[A]]]]].flatMap { qq =>
           Queue.bounded[F, Option[Chunk[A]]](1).flatMap { q =>
             lastQRef.set(q) *> qq.enqueue1(Some(q)).as(qq -> q)
           }
         }

       Stream.eval(initQs).flatMap {
         case (qq, initQ) =>
           // Creates the next inner queue, publishes it on the outer queue and records it
           // in `lastQRef`.
           def newQueue = Queue.bounded[F, Option[Chunk[A]]](1).flatMap { q =>
             qq.enqueue1(Some(q)) *> lastQRef.set(q).as(q)
           }

           // Feeds chunk `c` into inner queue `q`, which already holds `i` elements, and
           // returns the count and queue to continue with. Whenever the group fills up the
           // queue is closed and the remainder is pushed recursively into a fresh queue, so
           // a chunk spanning several groups is split as often as needed (previously the
           // whole remainder went into one queue and could exceed `n`).
           def push(
             i: Int,
             q: Queue[F, Option[Chunk[A]]],
             c: Chunk[A]): F[(Int, Queue[F, Option[Chunk[A]]])] =
             if (i + c.size >= n) {
               val (l, r) = c.splitAt(n - i)
               // Two `None`s close a queue: one for the consumer's `unNoneTerminate`, one
               // for the `onFinalize` drain below, which also reads up to a `None`.
               q.enqueue1(Some(l)) >> q.enqueue1(None) >> q.enqueue1(None) >>
                 newQueue.flatMap(nq => push(0, nq, r))
             } else {
               q.enqueue1(Some(c)).as((i + c.size, q))
             }

           // Producer: distributes the input, then closes the last inner queue and the
           // outer queue.
           // NOTE(review): `.attempt` turns an upstream failure into an emitted `Left`, and
           // this stream's output is not consumed by `concurrently`, so such a failure
           // appears never to reach the consumer - confirm whether that is intended.
           val evalStream = {
             in.chunks
               .evalMapAccumulate((0, initQ)) {
                 case ((i, q), c) => push(i, q, c).map(next => (next, c))
               }
               .attempt ++ Stream.eval {
               lastQRef.get.flatMap { last =>
                 last.enqueue1(None) *> last.enqueue1(None)
               } *> qq.enqueue1(None)
             }
           }

           // Consumer: each inner queue becomes an inner stream; when an inner stream is
           // finalized, whatever is left in its queue up to the next `None` is drained.
           qq.dequeue.unNoneTerminate
             .map(
               q =>
                 q.dequeue.unNoneTerminate
                   .flatMap(Stream.chunk)
                   .onFinalize(
                     q.dequeueChunk(Int.MaxValue).unNoneTerminate.compile.drain))
             .concurrently(evalStream)
       }
   }

   /** Splits `s` into consecutive inner streams of at most `n` elements each.
     *
     * @param n maximum number of elements per inner stream; must be positive
     * @throws IllegalArgumentException if `n` is not positive
     */
   def grouped(n: Int): Stream[F, Stream[F, A]] = {
     require(n > 0, s"n must be > 0, got $n")
     Stream.eval {
       // Placeholder queue; `groupedPipe` overwrites the ref as soon as it creates the
       // first real inner queue.
       Queue.unbounded[F, Option[Chunk[A]]].flatMap { empty =>
         Ref.of[F, Queue[F, Option[Chunk[A]]]](empty)
       }
     }.flatMap { ref =>
       s.through(groupedPipe(ref, n))
     }
   }
 }

But this is very complicated. Is there a simpler way?


Solution

  • In the end I used a more reliable version (it uses Hotswap to ensure that every inner queue gets terminated), like this:

      /** Splits `s` into consecutive inner streams of at most `innerSize` elements each.
        *
        * A background producer copies the chunks of `s` into per-group inner queues and
        * publishes every inner queue on an outer queue; the foreground turns each published
        * queue into an inner stream. The inner queue currently being filled is held in a
        * `Hotswap`, whose release action offers the terminating `None`, so a queue is closed
        * whenever it is swapped out or cleared.
        *
        * Both queues are created unbounded. A fresh inner queue is opened as soon as a group
        * is exactly full, and one is opened before the first chunk is read, so a source whose
        * length is a multiple of `innerSize` (an empty source included) ends with an empty
        * inner stream.
        *
        * @param innerSize maximum number of elements per inner stream; must be positive
        * @throws IllegalArgumentException if `innerSize` is not positive
        */
      def grouped(
        innerSize: Int
      )(implicit F: Async[F]): Stream[F, Stream[F, A]] = {
        // With a non-positive size `loopChunk` would split off an empty left part and
        // recurse on the unchanged chunk forever, allocating queues without bound.
        require(innerSize > 0, s"innerSize must be > 0, got $innerSize")

        type InnerQueue = Queue[F, Option[Chunk[A]]]
        type OuterQueue = Queue[F, Option[InnerQueue]]

        // Swaps a fresh inner queue into the Hotswap (which releases, i.e. None-terminates,
        // the previous one) and publishes the new queue on the outer queue.
        def swapperInner(
          swapper: Hotswap[F, InnerQueue],
          outer: OuterQueue
        ): F[InnerQueue] = {
          val innerRes =
            Resource.make(Queue.unbounded[F, Option[Chunk[A]]])(_.offer(None))
          swapper.swap(innerRes).flatTap(q => outer.offer(q.some))
        }

        // Offers `chunk` to `curr`, which already holds `gathered` elements, rolling over to
        // new inner queues as often as needed; returns the count and queue to continue with.
        def loopChunk(
          gathered: Int,
          curr: InnerQueue,
          chunk: Chunk[A],
          newInnerQueue: F[InnerQueue]
        ): F[(Int, InnerQueue)] = {
          if (gathered + chunk.size > innerSize) {
            // Chunk overflows the group: fill it up, then continue with the rest.
            val (left, right) = chunk.splitAt(innerSize - gathered)
            curr.offer(left.some) >> newInnerQueue.flatMap { nq =>
              loopChunk(0, nq, right, newInnerQueue)
            }
          } else if (gathered + chunk.size == innerSize) {
            // Chunk fills the group exactly: start the next group right away.
            curr.offer(chunk.some) >> newInnerQueue.tupleLeft(0)
          } else {
            curr.offer(chunk.some).as(gathered + chunk.size -> curr)
          }
        }

        val prepare = for {
          outer   <- Resource.eval(Queue.unbounded[F, Option[InnerQueue]])
          swapper <- Hotswap.create[F, InnerQueue]
        } yield outer -> swapper

        Stream.resource(prepare).flatMap {
          case (outer, swapper) =>
            val newInner = swapperInner(swapper, outer)
            // Producer: distributes the non-empty chunks of `s`; when it finishes, the
            // current inner queue is released and the outer queue is terminated.
            val background = Stream.eval(newInner).flatMap { initQueue =>
              s.chunks
                .filter(_.nonEmpty)
                .evalMapAccumulate(0 -> initQueue) { (state, chunk) =>
                  val (gathered, curr) = state
                  loopChunk(gathered, curr, chunk, newInner).tupleRight(())
                }
                .onFinalize(swapper.clear *> outer.offer(None))
            }
            // Consumer: one inner stream per published inner queue.
            val foreground = Stream
              .fromQueueNoneTerminated(outer)
              .map(i => Stream.fromQueueNoneTerminatedChunk(i))
            foreground.concurrently(background)
        }

      }