
Transform an HList of Eithers to an Either of an HList


I want to define a function that accepts an HList whose elements are such that, for each element t, there is a type T such that t: Either[String, T]. The function, which we will call validate, should have the following behaviour:

  • If all elements of the parameter are Right, return a Right containing the HList of their right-projected values, in the same order.
  • Otherwise, return a Left[List[String]] whose list contains the left-projected value of each Left in the parameter, in order.

Examples:

validate (Right (42) :: Right (3.14) :: Right (false) :: HNil)
>> Right (42 :: 3.14 :: false :: HNil)
validate (Right (42) :: Left ("qwerty") :: Left ("uiop") :: HNil)
>> Left (List ("qwerty", "uiop"))
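Ignoring the heterogeneous typing for a moment, the intended behaviour can be sketched with the standard library on an ordinary List of Eithers. This only illustrates the semantics: validateList is a hypothetical name, and unlike the HList version it forces every Right to share a single type A.

```scala
// Homogeneous sketch of the desired semantics (hypothetical helper, stdlib only).
// Every Right must share one type A, which is exactly what an HList avoids.
def validateList[A](es: List[Either[String, A]]): Either[List[String], List[A]] = {
  // Collect every Left value, preserving the original order.
  val errors = es.collect { case Left(e) => e }
  if (errors.isEmpty) Right(es.collect { case Right(a) => a })
  else Left(errors)
}

val allGood = validateList[Int](List(Right(42), Right(7)))
// == Right(List(42, 7))
val someBad = validateList[Int](List(Right(42), Left("qwerty"), Left("uiop")))
// == Left(List("qwerty", "uiop"))
```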

An example use case:

case class Result (foo: Foo, bar: Bar, baz: Baz, qux: Qux)

def getFoo: Either[String, Foo] = ???
def getBar: Either[String, Bar] = ???
def getBaz: Either[String, Baz] = ???
def getQux: Either[String, Qux] = ???

def createResult: Either[String, Result] = {
    validate (getFoo :: getBar :: getBaz :: getQux :: HNil) match {
        case Right (foo :: bar :: baz :: qux :: HNil) => Right (Result (foo, bar, baz, qux))
        case Left (errors) => Left ("The following errors occurred:\n" + errors.mkString ("\n"))
    }
}

Solution

  • I'll assume we have some test data like this throughout this answer:

    scala> import shapeless.{::, HNil}
    import shapeless.{$colon$colon, HNil}
    
    scala> type In = Either[String, Int] :: Either[String, String] :: HNil
    defined type alias In
    
    scala> val good: In = Right(123) :: Right("abc") :: HNil
    good: In = Right(123) :: Right(abc) :: HNil
    
    scala> val bad: In = Left("error 1") :: Left("error 2") :: HNil
    bad: In = Left(error 1) :: Left(error 2) :: HNil
    

    Using a custom type class

    There are many ways you could do this. I'd probably use a custom type class that highlights the way instances are built up inductively:

    import shapeless.HList
    
    trait Sequence[L <: HList] {
      type E
      type Out <: HList
      def apply(l: L): Either[List[E], Out]
    }
    
    object Sequence {
      type Aux[L <: HList, E0, Out0 <: HList] = Sequence[L] { type E = E0; type Out = Out0 }
    
      implicit def hnilSequence[E0]: Aux[HNil, E0, HNil] = new Sequence[HNil] {
        type E = E0
        type Out = HNil
        def apply(l: HNil): Either[List[E], HNil] = Right(l)
      }
    
      implicit def hconsSequence[H, T <: HList, E0](implicit
        ts: Sequence[T] { type E = E0 }
      ): Aux[Either[E0, H] :: T, E0, H :: ts.Out] = new Sequence[Either[E0, H] :: T] {
        type E = E0
        type Out = H :: ts.Out
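        def apply(l: Either[E0, H] :: T): Either[List[E0], H :: ts.Out] =
          // Sequence the tail first, then combine the result with the head.
          (l.head, ts(l.tail)) match {
            case (Right(h), Right(t)) => Right(h :: t)   // no errors at all
            case (Left(eh), Left(et)) => Left(eh :: et)  // keep the head's error and the tail's errors
            case (Left(eh), _) => Left(List(eh))         // only the head failed
            case (_, Left(et)) => Left(et)               // only the tail failed
          }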
      }
    }
    

    Then you can write validate like this:

    def validate[L <: HList](l: L)(implicit s: Sequence[L]): Either[List[s.E], s.Out] = s(l)
    

    And use it like this:

    scala> validate(good)
    res0: scala.util.Either[List[String],Int :: String :: shapeless.HNil] = Right(123 :: abc :: HNil)
    
    scala> validate(bad)
    res1: scala.util.Either[List[String],Int :: String :: shapeless.HNil] = Left(List(error 1, error 2))
    

    Note that the static types come out right.

    Using a right fold

    You could also do it a little more concisely by folding with a Poly2.

    import shapeless.Poly2
    
    object combine extends Poly2 {
      implicit def eitherCase[H, T, E, OutT <: HList]:
        Case.Aux[Either[E, H], Either[List[E], OutT], Either[List[E], H :: OutT]] = at {
          case (Right(h), Right(t)) => Right(h :: t)
          case (Left(eh), Left(et)) => Left(eh :: et)
          case (Left(eh), _) => Left(List(eh))
          case (_, Left(et)) => Left(et) 
        }
    }
    

    And then:

    scala> good.foldRight(Right(HNil): Either[List[String], HNil])(combine)
    res2: scala.util.Either[List[String],Int :: String :: shapeless.HNil] = Right(123 :: abc :: HNil)
    
    scala> bad.foldRight(Right(HNil): Either[List[String], HNil])(combine)
    res3: scala.util.Either[List[String],Int :: String :: shapeless.HNil] = Left(List(error 1, error 2))
    

    I guess this is probably the "right" answer, assuming you want to stick to Shapeless alone. The Poly2 approach just relies on some weird details of implicit resolution (we couldn't define combine as a val, for example) that I personally don't really like.
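To see the Poly2 folding mechanism on its own, here is a minimal fold. It assumes Shapeless 2.3.x on Scala 2.13; the scala-cli directives at the top are an assumption about your build setup, and sizeOf is a hypothetical name. foldRight starts from the initial value and, moving from the last element to the first, applies whichever implicit case of the Poly2 matches the (element, accumulator) types. Like combine above, sizeOf is defined as an object.

```scala
//> using scala "2.13"
//> using dep "com.chuusai::shapeless:2.3.10"
import shapeless._

// Hypothetical Poly2 with one implicit case per (element type, accumulator type) pair.
object sizeOf extends Poly2 {
  implicit def stringCase: Case.Aux[String, Int, Int] =
    at[String, Int]((s, acc) => s.length + acc)
  implicit def intCase: Case.Aux[Int, Int, Int] =
    at[Int, Int]((i, acc) => i.toString.length + acc)
}

// 12345 is folded first (5 + 0 = 5), then "abc" (3 + 5 = 8).
val total: Int = ("abc" :: 12345 :: HNil).foldRight(0)(sizeOf)
```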

    Using Kittens's sequence

    Lastly you could use the Kittens library, which supports sequencing and traversing hlists:

    scala> import cats.instances.all._, cats.sequence._
    import cats.instances.all._
    import cats.sequence._
    
    scala> good.sequence
    res4: scala.util.Either[String,Int :: String :: shapeless.HNil] = Right(123 :: abc :: HNil)
    
    scala> bad.sequence
    res5: scala.util.Either[String,Int :: String :: shapeless.HNil] = Left(error 1)
    

    Note that this doesn't accumulate errors, though.
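The fail-fast result comes from combining Eithers through flatMap: when the head is a Left, flatMap never uses the result for the remaining elements, so only the first Left survives. A standard-library sketch of that behaviour on a List (sequenceFailFast is a hypothetical helper, not the Kittens implementation):

```scala
// Fail-fast sequencing of a List of Eithers (hypothetical helper, stdlib only).
def sequenceFailFast[E, A](es: List[Either[E, A]]): Either[E, List[A]] =
  es.foldRight(Right(Nil): Either[E, List[A]]) { (e, acc) =>
    // If e is a Left, the for-comprehension returns it without consulting acc.
    for {
      a  <- e
      as <- acc
    } yield a :: as
  }

val firstOnly = sequenceFailFast[String, Int](List(Left("error 1"), Left("error 2")))
// == Left("error 1"), mirroring bad.sequence above
```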

    If you wanted the most complete possible Typelevel experience, I guess you could add a parSequence operation to Kittens that accumulates errors for an hlist of Eithers by using the Parallel instance that maps them to Validated (see my blog post for more detail about how this works). Kittens doesn't currently include this, though.
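The idea behind the Parallel approach is that a fail-fast Either is temporarily converted to a Validated-like type whose combination always inspects both sides and so keeps every error, and is then converted back. The sketch below hand-rolls a tiny stand-in for that round trip with the standard library only; Checked, Valid, Invalid, fromEither, toEither and map2 are hypothetical names, not the Cats API.

```scala
// Hypothetical, minimal stand-in for cats.data.Validated (not the Cats API).
sealed trait Checked[+E, +A]
final case class Invalid[E](errors: List[E]) extends Checked[E, Nothing]
final case class Valid[A](value: A) extends Checked[Nothing, A]

object Checked {
  // Either -> Checked: the "parallel" view of a value.
  def fromEither[E, A](e: Either[E, A]): Checked[E, A] = e match {
    case Left(err) => Invalid(List(err))
    case Right(a)  => Valid(a)
  }

  // Checked -> Either: back to the "sequential" view.
  def toEither[E, A](c: Checked[E, A]): Either[List[E], A] = c match {
    case Invalid(errs) => Left(errs)
    case Valid(a)      => Right(a)
  }

  // Unlike flatMap on Either, both sides are always inspected,
  // so the errors of both are kept.
  def map2[E, A, B, C](ca: Checked[E, A], cb: Checked[E, B])(f: (A, B) => C): Checked[E, C] =
    (ca, cb) match {
      case (Valid(a), Valid(b))       => Valid(f(a, b))
      case (Invalid(e1), Invalid(e2)) => Invalid(e1 ++ e2)
      case (Invalid(e1), _)           => Invalid(e1)
      case (_, Invalid(e2))           => Invalid(e2)
    }
}

val both = Checked.toEither(
  Checked.map2(
    Checked.fromEither(Left("error 1"): Either[String, Int]),
    Checked.fromEither(Left("error 2"): Either[String, String])
  )((i, s) => (i, s))
)
// == Left(List("error 1", "error 2"))
```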

    Update: parallel sequencing

    If you want parSequence, it's not actually that much of a nightmare to write it yourself:

    import shapeless.HList, shapeless.poly.~>, shapeless.ops.hlist.{Comapped, NatTRel}
    import cats.Parallel, cats.instances.all._, cats.sequence.Sequencer
    
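    def parSequence[L <: HList, M[_], P[_], PL <: HList, Out](l: L)(implicit
      cmp: Comapped[L, M],           // every element of L has the shape M[X], e.g. Either[String, X]
      par: Parallel.Aux[M, P],       // M has a parallel counterpart P (Validated for Either)
      ntr: NatTRel[L, M, PL, P],     // PL is L with each M[X] replaced by P[X]
      seq: Sequencer.Aux[PL, P, Out] // Kittens can turn PL into P[Out]
    ): M[Out] = {
      // Natural transformation that moves each element into the parallel type.
      val nt = new (M ~> P) {
        def apply[A](a: M[A]): P[A] = par.parallel(a)
      }
    
      // Convert every element, sequence in P (accumulating errors), then convert back to M.
      par.sequential(seq(ntr.map(nt, l)))
    }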
    

    And then:

    scala> parSequence(good)
    res0: Either[String,Int :: String :: shapeless.HNil] = Right(123 :: abc :: HNil)
    
    scala> parSequence(bad)
    res1: Either[String,Int :: String :: shapeless.HNil] = Left(error 1error 2)
    

    Note that this does accumulate errors, but by concatenating the strings. The Cats-idiomatic way to accumulate errors in a list would look like this:
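The concatenation happens because the accumulated errors are combined with the error type's Semigroup: for String that means concatenation, while for a list (or a NonEmptyList) it means appending. A standard-library sketch of that dependence, using a hand-written Semigroup trait and a hypothetical zipAccumulating helper that only stand in for the Cats machinery:

```scala
// Hypothetical, minimal Semigroup type class (Cats has its own; this is not it).
trait Semigroup[A] { def combine(x: A, y: A): A }

object Semigroup {
  // Strings combine by concatenation, hence "error 1error 2".
  implicit val stringSemigroup: Semigroup[String] = new Semigroup[String] {
    def combine(x: String, y: String): String = x + y
  }

  // Lists combine by appending, which keeps the errors separate.
  implicit def listSemigroup[A]: Semigroup[List[A]] = new Semigroup[List[A]] {
    def combine(x: List[A], y: List[A]): List[A] = x ++ y
  }
}

// Pair up two Eithers, merging both errors with whatever Semigroup E has.
def zipAccumulating[E, A, B](ea: Either[E, A], eb: Either[E, B])(implicit S: Semigroup[E]): Either[E, (A, B)] =
  (ea, eb) match {
    case (Right(a), Right(b)) => Right((a, b))
    case (Left(x), Left(y))   => Left(S.combine(x, y))
    case (Left(x), _)         => Left(x)
    case (_, Left(y))         => Left(y)
  }

val asString = zipAccumulating(Left("error 1"): Either[String, Int], Left("error 2"): Either[String, String])
// == Left("error 1error 2")
val asList = zipAccumulating(Left(List("error 1")): Either[List[String], Int], Left(List("error 2")): Either[List[String], String])
// == Left(List("error 1", "error 2"))
```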

    scala> import cats.syntax.all._
    import cats.syntax.all._
    
    scala> val good = 123.rightNel[String] :: "abc".rightNel[String] :: HNil
    good: Either[cats.data.NonEmptyList[String],Int] :: Either[cats.data.NonEmptyList[String],String] :: shapeless.HNil = Right(123) :: Right(abc) :: HNil
    
    scala> val bad = "error 1".leftNel[String] :: "error 2".leftNel[Int] :: HNil
    bad: Either[cats.data.NonEmptyList[String],String] :: Either[cats.data.NonEmptyList[String],Int] :: shapeless.HNil = Left(NonEmptyList(error 1)) :: Left(NonEmptyList(error 2)) :: HNil
    
    scala> parSequence(good)
    res3: Either[cats.data.NonEmptyList[String],Int :: String :: shapeless.HNil] = Right(123 :: abc :: HNil)
    
    scala> parSequence(bad)
    res4: Either[cats.data.NonEmptyList[String],String :: Int :: shapeless.HNil] = Left(NonEmptyList(error 1, error 2))
    

    It'd probably be worth opening a PR to add something like this to Kittens.