
Converting a stream of coproduct to a stream of HList - Shapeless and FS2


I have an fs2 stream, Stream[F, C], where C <: Coproduct, and I want to transform it into a Stream[F, H] where H <: HList. This HList should contain all the member types of the coproduct C.

So, essentially, a Pipe[F, C, H].

The Pipe should wait until at least one value of each of the coproduct's members has been pulled and then, once it has one of each, combine them into an HList and output it.
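For a fixed two-member case, the intended semantics can be pinned down without Shapeless. This is only an illustrative sketch: it hard-codes Either[String, Int] as the "coproduct" and (String, Int) as the "HList", the names TwoMemberSketch, pairUp, complete and demo are hypothetical, and it uses fs2's mapAccumulate to carry the newest value of each member, emitting and resetting once both are present.

```scala
import fs2.{Pipe, Pure, Stream}

object TwoMemberSketch {
  type Slots = (Option[String], Option[Int])

  // Hypothetical fixed instance: Either[String, Int] stands in for the
  // coproduct and (String, Int) for the HList built from it.
  def pairUp[F[_]]: Pipe[F, Either[String, Int], (String, Int)] =
    _.mapAccumulate[Slots, Option[(String, Int)]]((None, None)) {
      case ((_, int), Left(str))  => complete(Some(str), int)
      case ((str, _), Right(int)) => complete(str, Some(int))
    }.collect { case (_, Some(pair)) => pair }

  // Emit and reset once both slots are filled, otherwise keep waiting.
  // A newer value of a member always replaces the older one.
  private def complete(
      str: Option[String],
      int: Option[Int]
  ): (Slots, Option[(String, Int)]) =
    (str, int) match {
      case (Some(s), Some(i)) => ((None, None), Some((s, i)))
      case _                  => ((str, int), None)
    }

  // Pure example run; evaluates to List(("b", 1), ("c", 2))
  val demo: List[(String, Int)] =
    Stream
      .emits(List[Either[String, Int]](Left("a"), Left("b"), Right(1), Right(2), Left("c")))
      .through(pairUp[Pure])
      .toList
}
```

The generic version has to do the same thing for any number of members, which is where Shapeless comes in.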

It would be used roughly like this:

type MyCoprod = A :+: B :+: C :+: CNil
type MyHList = A :: B :: C :: HNil

val stream: Stream[F, MyHList] = Stream
  .emits(List(a, b, c)) // my coproducts: values of type MyCoprod
  .through(pullAll) // wait until an A, a B and a C have each been pulled at least once, then emit them as one HList
  .map { hlist => ... }

I am very new to Shapeless, and this is as far as I got before hitting a roadblock:

trait WaitFor[F[_], C <: Coproduct] {
  type Out <: HList

  def apply: Pipe[F, C, Out]
}

object WaitFor {
  type Aux[F[_], C <: Coproduct, Out0 <: HList] =
    WaitFor[F, C] { type Out = Out0 }

  implicit def make[F[_], C <: Coproduct, L <: HList](implicit
    toHList: ToHList.Aux[C, L]
  ): Aux[F, C, L] = new WaitFor[F, C] {
    override type Out = L

    override def apply: Pipe[F, C, Out] = {
      def go(s2: Stream[F, C], currHList: L): Pull[F, L, Unit] = {
        s2.pull.uncons1.flatMap {
          case Some((coproduct, s3)) => {
            // add or update coproduct member to currHList

            // if currHList is the same as L (our output type) then output it (Pull.output1(currHList)) and clear currHList

            // if not, keep iterating:

            go(s3, ???)
          }

          case None => Pull.done
        }
      }
      (in: Stream[F, C]) => go(in, ???).stream
    }
  }

  def pullAll[F[_], C <: Coproduct](
    stream: Stream[F, C]
  )(implicit ev: WaitFor[F, C]): Stream[F, ev.Out] = {
    stream.through(ev.apply)
  }
}

My roadblock starts here:

    override def apply: Pipe[F, C, Out] = ???

and that's where my knowledge of Shapeless runs out.

My idea is to keep track of the latest value of each coproduct member in a tuple (Option[C1], Option[C2], ...).

Once every element in the tuple is Some, I'll convert them to an HList and output it in the Stream.

(I'll be using FS2 Pull to keep track of the state recursively, so I'm not worried about that part.)

But my issue is that, at the value level, I have no way of knowing how long the tuple will be, nor of constructing such a tuple.

Any pointers on how to solve this?

Thanks


Solution

  • Let's do it step by step:

    • your input will be A :+: B :+: C :+: CNil
    • you will need to store the newest A, the newest B, etc. somewhere
    • initially there won't be any newest value
    • after finding all values you should emit A :: B :: C :: HNil
    • when you emit a new HList value, you should also reset your intermediate storage
    • this suggests that it would be handy to store these intermediate values as Option[A] :: Option[B] :: Option[C] :: HNil
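    For the three-member input used in the test further down (String :+: Int :+: Double :+: CNil), the storage described in the list can be written out by hand. This fixed-arity sketch (CacheShapeDemo and its members are illustrative names, not part of the answer) shows the cache type, the result type, and the "every slot is Some" check:

    ```scala
    import shapeless._

    object CacheShapeDemo {
      // Cache and result for String :+: Int :+: Double :+: CNil, written by hand
      type Cache  = Option[String] :: Option[Int] :: Option[Double] :: HNil
      type Result = String :: Int :: Double :: HNil

      // Nothing has been collected yet
      val empty: Cache = None :: None :: None :: HNil

      // A String and an Int have arrived, the Double is still missing
      val partial: Cache = Some("a") :: Some(1) :: None :: HNil

      // Every slot is filled
      val full: Cache = Some("a") :: Some(1) :: Some(2.5) :: HNil

      // A Result can only be produced once every slot is Some
      def toResult(cache: Cache): Option[Result] = cache match {
        case Some(s) :: Some(i) :: Some(d) :: HNil => Some(s :: i :: d :: HNil)
        case _                                     => None
      }
    }
    ```

    Writing such a toResult by hand for every arity is exactly what doesn't scale, which is why the type class derives it recursively from the coproduct type instead.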

    So let's write a type class that helps us with this:

    import shapeless._
    
    // A type class for collecting Coproduct elements (last-wins)
    // until they could be combined into an HList element
    
    // Path-dependent types and Aux for better DX, e.g. when one
    // would want Collector[MyType] without manually entering HLists
    trait Collector[Input] {
    
      type Cache
      type Result
    
      // pure computation of an updated cache
      def updateState(newInput: Input, currentState: Cache): Cache
    
      // returns Some if all elements of Cache are Some, None otherwise
      def attemptConverting(updatedState: Cache): Option[Result]
    
      // HLists of Nones
      def emptyCache: Cache
    }
    object Collector {
    
      type Aux[Input, Cache0, Result0] = Collector[Input] {
        type Cache = Cache0
        type Result = Result0
      }
    
      def apply[Input](implicit
          collector: Collector[Input]
      ): Collector.Aux[Input, collector.Cache, collector.Result] =
        collector
    
      // obligatory empty Coproduct/HList case to terminate recursion
      implicit val nilCollector: Collector.Aux[CNil, HNil, HNil] =
        new Collector[CNil] {
    
          type Cache = HNil
          type Result = HNil
    
          override def updateState(newInput: CNil, currentState: HNil): HNil = HNil
    
          override def attemptConverting(updatedState: HNil): Option[HNil] =
            Some(HNil)
    
          override def emptyCache: HNil = HNil
        }
    
      // here we define the actual recursive derivation
      implicit def consCollector[
          Head,
          InputTail <: Coproduct,
          CacheTail <: HList,
          ResultTail <: HList
      ](implicit
          tailCollector: Collector.Aux[InputTail, CacheTail, ResultTail]
      ): Collector.Aux[
          Head :+: InputTail,
          Option[Head] :: CacheTail,
          Head :: ResultTail
      ] = new Collector[Head :+: InputTail] {
    
          type Cache = Option[Head] :: CacheTail
          type Result = Head :: ResultTail
    
          override def updateState(
              newInput: Head :+: InputTail,
              currentState: Option[Head] :: CacheTail
          ): Option[Head] :: CacheTail = newInput match {
            case Inl(head) => Some(head) :: currentState.tail
            case Inr(tail) =>
              currentState.head :: tailCollector.updateState(
                tail,
                currentState.tail
              )
          }
    
          override def attemptConverting(
              updatedState: Option[Head] :: CacheTail
          ): Option[Head :: ResultTail] = for {
            head <- updatedState.head
            tail <- tailCollector.attemptConverting(updatedState.tail)
          } yield head :: tail
    
          override def emptyCache: Option[Head] :: CacheTail =
            None :: tailCollector.emptyCache
        }
    }
    

    This code makes no assumptions about where we store the cache or how that storage gets updated, so we can test it with some impure code:

    import shapeless.ops.coproduct.Inject
    
    type Input = String :+: Int :+: Double :+: CNil
    val collector = Collector[Input]
    
    // dirty, but good enough for demo
    var cache = collector.emptyCache
    
    LazyList[Input](
      Inject[Input, String].apply("test1"),
      Inject[Input, String].apply("test2"),
      Inject[Input, String].apply("test3"),
      Inject[Input, Int].apply(1),
      Inject[Input, Int].apply(2),
      Inject[Input, Int].apply(3),
      Inject[Input, Double].apply(3),
      Inject[Input, Double].apply(4),
      Inject[Input, Double].apply(3),
      Inject[Input, String].apply("test4"),
      Inject[Input, Int].apply(4),
    ).foreach { input =>
      val newCache = collector.updateState(input, cache)
      collector.attemptConverting(newCache) match {
        case Some(value) =>
          println(s"Product computed: $value!")
          cache = collector.emptyCache
        case None =>
          cache = newCache
      }
      println(s"Current cache: $cache")
    }
    

    Running it in Scastie prints what we expect:

    Current cache: Some(test1) :: None :: None :: HNil
    Current cache: Some(test2) :: None :: None :: HNil
    Current cache: Some(test3) :: None :: None :: HNil
    Current cache: Some(test3) :: Some(1) :: None :: HNil
    Current cache: Some(test3) :: Some(2) :: None :: HNil
    Current cache: Some(test3) :: Some(3) :: None :: HNil
    Product computed: test3 :: 3 :: 3.0 :: HNil!
    Current cache: None :: None :: None :: HNil
    Current cache: None :: None :: Some(4.0) :: HNil
    Current cache: None :: None :: Some(3.0) :: HNil
    Current cache: Some(test4) :: None :: Some(3.0) :: HNil
    Product computed: test4 :: 4 :: 3.0 :: HNil!
    Current cache: None :: None :: None :: HNil
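    The Inl/Inr match in updateState lands on the right slot because Inject places a value at a fixed position in the coproduct: Inl for the first member, Inr(Inl(...)) for the second, and so on, so each Inr step in the match also steps one element further into the cache. A small sketch of what the test inputs look like (InjectDemo and its values are illustrative names):

    ```scala
    import shapeless._
    import shapeless.ops.coproduct.Inject

    object InjectDemo {
      type Input = String :+: Int :+: Double :+: CNil

      // First member: a single Inl
      val str: Input = Inject[Input, String].apply("test1")

      // Second member: one Inr to skip String, then Inl
      val int: Input = Inject[Input, Int].apply(1)

      // Third member: two Inr to skip String and Int, then Inl
      val dbl: Input = Inject[Input, Double].apply(3.0)

      // Coproduct[Input](...) is a shorter spelling that relies on Inject
      val intAgain: Input = Coproduct[Input](1)
    }
    ```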
    

    Now it's a matter of threading this intermediate state through the FS2 Stream. One way is to use a Ref:

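    // assumes cats-effect 3 (in cats-effect 2, Ref lives in cats.effect.concurrent)
    import cats.effect.{IO, Ref}
    import fs2.Stream
    
    for {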
      // for easy passing of cache around
      cacheRef <- Stream.eval(Ref[IO].of(collector.emptyCache))
      // source of Coproducts
      input <- Stream[IO, Input](
        Inject[Input, String].apply("test1"),
        Inject[Input, String].apply("test2"),
        Inject[Input, String].apply("test3"),
        Inject[Input, Int].apply(1),
        Inject[Input, Int].apply(2),
        Inject[Input, Int].apply(3),
        Inject[Input, Double].apply(3)
      )
      updateCache = cacheRef.modify[Stream[IO, collector.Result]] { cache =>
        val newCache = collector.updateState(input, cache)
        collector.attemptConverting(newCache) match {
          case Some(value) => collector.emptyCache -> Stream(value)
          case None        => newCache -> Stream.empty
        }
      }
      // emits a new HList only once all of its elements have been gathered
      hlist <- Stream.eval(updateCache).flatten
    } yield hlist
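    The key piece of updateCache is Ref#modify: it atomically stores a new value and returns a second, derived value (here, the Stream to emit). A minimal sketch of that contract, assuming cats-effect 3 on the JVM (RefModifyDemo, program and run are illustrative names):

    ```scala
    import cats.effect.{IO, Ref}
    import cats.effect.unsafe.implicits.global

    object RefModifyDemo {
      // modify takes A => (newA, result): newA is stored, result is returned
      val program: IO[(String, Int)] = for {
        ref    <- Ref[IO].of(0)
        result <- ref.modify(n => (n + 1, s"previous value was $n"))
        now    <- ref.get
      } yield (result, now)

      // Runs the IO; evaluates to ("previous value was 0", 1)
      def run(): (String, Int) = program.unsafeRunSync()
    }
    ```

    In updateCache the stored value is the cache and the returned value is either a one-element Stream or Stream.empty, which Stream.eval(...).flatten then splices into the output.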
    

    One might adapt the Ref-based code above to taste: extract updateCache into a function, use a State monad, and so on. Turning it into a Pipe could look like this:

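    // assumes cats-effect 3 (in cats-effect 2, Ref lives in cats.effect.concurrent);
    // IO could be generalized to any F[_]: Sync (Ref.of needs more than a Monad),
    // and something other than Ref could hold the cache
    import cats.effect.{IO, Ref}
    import fs2.{Pipe, Stream}
    
    def collectCoproductsToHList[Input](
        implicit collector: Collector[Input]
    ): IO[Pipe[IO, Input, collector.Result]] =
      Ref[IO].of(collector.emptyCache).map { cacheRef =>
    
        // note: every stream run through the returned pipe shares this one cacheRef
        val pipe: Pipe[IO, Input, collector.Result] = inputStream => for {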
          input <- inputStream
          updateCache = cacheRef.modify[Stream[IO, collector.Result]] { cache =>
            val newCache = collector.updateState(input, cache)
            collector.attemptConverting(newCache) match {
              case Some(value) => collector.emptyCache -> Stream(value)
              case None        => newCache             -> Stream.empty
            }
          }
          hlist <- Stream.eval(updateCache).flatten
        } yield hlist
          
        pipe
      }
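    Since the question mentions threading the state through FS2 Pull, here is an alternative sketch that needs neither a Ref nor an IO-wrapped Pipe. To stay self-contained it is written against a generic step function rather than Collector; StatefulPipe, accumulate and demo are hypothetical names, and it assumes fs2 2.x or 3.x. The recursion is the usual uncons1-and-recurse Pull pattern:

    ```scala
    import fs2.{Pipe, Pull, Pure, Stream}

    object StatefulPipe {
      // Threads a state S through the stream with Pull recursion.
      // step returns the next state and, optionally, a value to emit.
      def accumulate[F[_], S, I, O](init: S)(
          step: (S, I) => (S, Option[O])
      ): Pipe[F, I, O] = {
        def go(in: Stream[F, I], state: S): Pull[F, O, Unit] =
          in.pull.uncons1.flatMap {
            case Some((i, rest)) =>
              step(state, i) match {
                case (next, Some(o)) => Pull.output1(o) >> go(rest, next)
                case (next, None)    => go(rest, next)
              }
            case None => Pull.done
          }
        // each stream passed through the pipe starts again from init
        in => go(in, init).stream
      }

      // Example: emit the running sum whenever it reaches at least 5, then reset it
      val demo: List[Int] =
        Stream(1, 2, 3, 4, 5)
          .through(accumulate[Pure, Int, Int, Int](0) { (sum, i) =>
            val total = sum + i
            if (total >= 5) (0, Some(total)) else (total, None)
          })
          .toList
    }
    ```

    To plug in the Collector, init would be collector.emptyCache and step would call collector.updateState followed by collector.attemptConverting, returning (collector.emptyCache, Some(hlist)) when a result is ready and (updatedCache, None) otherwise. Unlike the Ref version, the resulting Pipe is a plain value, and every stream run through it starts with a fresh cache.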