Search code examples
Tags: scala, type-safety, scala-3, path-dependent-type

Scala 3 collection partitioning with subtypes


In Scala 3, let's say I have a List[Try[String]]. Can I split it into successes and failures such that each list has the appropriate subtype (List[Success[String]] and List[Failure[String]])?

If I do the following:

import scala.util.{Try, Success, Failure}
val tries = List(Success("1"), Failure(Exception("2")))
// `partition` keeps the element type, so both halves are still List[Try[String]].
val (successes, failures) = tries.partition(t => t.isSuccess)

then successes and failures are still of type List[Try[String]]. The same goes if I filter based on the type:

// Filtering by runtime type still yields List[Try[String]]: `filter` cannot narrow the element type.
val successes = tries.filter {
  case _: Success[String] => true
  case _                  => false
}

I could of course cast to Success and Failure respectively, but is there a type-safe way to achieve this?


Solution

  • @Luis Miguel Mejía Suárez:

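    Use tries.partitionMap(_.toEither). Note that this yields the contents rather than the wrappers: a (List[Throwable], List[String]) holding the failures' exceptions and the successes' values.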

    @mitchus:

    @LuisMiguelMejíaSuárez OK, the trick here is that Try has a toEither method that splits into the proper types. What if we have a regular sealed trait?

    In Scala 2 I would do something like

    import shapeless.{:+:, ::, CNil, Coproduct, Generic, HList, HNil, Inl, Inr, Poly0}
    import shapeless.ops.coproduct.ToHList
    import shapeless.ops.hlist.{FillWith, Mapped, Tupler}
    
    /** Inserts the coproduct value `c` into the list of `l` that belongs to its
      * alternative. `L` holds one `List[H]` per alternative `H` of `C`, in the same order.
      */
    trait Loop[C <: Coproduct, L <: HList] {
      def apply(c: C, l: L): L
    }
    object Loop {
      // `Inl(h)`: h belongs to this alternative, so prepend it to the head list.
      // `Inr(ct)`: keep the head list and insert into the tail recursively.
      implicit def recur[H, CT <: Coproduct, HT <: HList](implicit
        loop: Loop[CT, HT]
      ): Loop[H :+: CT, List[H] :: HT] = {
        case (Inl(h), hs :: ht) => (h :: hs) :: ht
        case (Inr(ct), hs :: ht) => hs :: loop(ct, ht)
      }
    
      // No alternatives left: the HList is returned unchanged.
      implicit val base: Loop[CNil, HNil] = (_, l) => l
    }
    
    // Polymorphic constant yielding an empty `List[A]` for any `A`; `FillWith` uses it
    // (see `partition`) to build the initial HList of empty lists.
    object nilPoly extends Poly0 {
      implicit def cse[A]: Case0[List[A]] = at(Nil)
    }
    
    /** Splits `as` into one list per case of the sealed type `A`, returned as a tuple
      * ordered like the coproduct `C`. Elements keep their original relative order.
      *
      * Implicit chain: `A` -> coproduct `C` -> HList of cases `L` -> HList of lists `L1`
      * -> tuple `tupler.Out`.
      */
    def partition[A, C <: Coproduct, L <: HList, L1 <: HList](as: List[A])(implicit
      generic: Generic.Aux[A, C],
      toHList: ToHList.Aux[C, L],
      mapped: Mapped.Aux[L, List, L1],
      loop: Loop[C, L1],
      fillWith: FillWith[nilPoly.type, L1],
      tupler: Tupler[L1]
    ): tupler.Out = {
      // Start from an HList of empty lists and insert elements right-to-left, so
      // prepending keeps each list in input order.
      val partitionHList: L1 = as.foldRight(fillWith())((a, l1) =>
        loop(generic.to(a), l1)
      )
    
      tupler(partitionHList)
    }
    
    // Example ADT to partition. The case classes are final: nothing should extend them.
    sealed trait A
    final case class B(i: Int) extends A
    final case class C(i: Int) extends A
    final case class D(i: Int) extends A
    
    val input = List[A](B(1), B(2), C(1), C(2), D(1), D(2), B(3), C(3))
    partition(input)
    // (List(B(1), B(2), B(3)),List(C(1), C(2), C(3)),List(D(1), D(2))): (List[B], List[C], List[D])
    

    https://scastie.scala-lang.org/DmytroMitin/uQp603sXT7WFYmYntDXmIw/1


    I managed to translate this code into Scala 3, although the translation turned out to be wordy (I reimplemented Generic and Coproduct).

    import scala.annotation.tailrec
    import scala.deriving.Mirror
    
    /** Scala 3 re-implementation of the parts of shapeless used by the Scala 2
      * version: `Generic` (value <-> representation conversion), a minimal
      * `Coproduct` encoding, and a macro that converts a sum-type value to its coproduct.
      */
    object App1 {
      // ============= Generic =====================
      /** Converts between `T` and its generic representation `Repr`: the tuple of
        * fields for a case class, the coproduct of cases for a sealed type.
        */
      trait Generic[T] {
        type Repr
        def to(t: T): Repr
        def from(r: Repr): T
      }
      object Generic {
        /** A `Generic[T]` whose `Repr` is fixed to the type parameter `Repr0`. */
        type Aux[T, Repr0] = Generic[T] { type Repr = Repr0 }
        /** Builds a `Generic` from the conversion functions `f` (to) and `g` (from). */
        def instance[T, Repr0](f: T => Repr0, g: Repr0 => T): Aux[T, Repr0] =
          new Generic[T] {
            override type Repr = Repr0
            override def to(t: T): Repr0 = f(t)
            override def from(r: Repr0): T = g(r)
          }
    
        /** Extension syntax: `a.toRepr` and `repr.to[A]`. */
        object ops {
          extension [A](a: A) {
            def toRepr(using g: Generic[A]): g.Repr = g.to(a)
          }
    
          extension [Repr](a: Repr) {
            def to[A](using g: Generic.Aux[A, Repr]): A = g.from(a)
          }
        }
    
        // Product types: Repr is the tuple of field types. `to` folds the fields from
        // `productIterator` into a tuple (statically only `Tuple`, hence the cast);
        // `from` rebuilds `T` through the mirror.
        given [T <: Product](using
          m: Mirror.ProductOf[T]
        ): Aux[T, m.MirroredElemTypes] = instance(
          _.productIterator
           .foldRight[Tuple](EmptyTuple)(_ *: _)
           .asInstanceOf[m.MirroredElemTypes],
          m.fromProduct(_).asInstanceOf[T]
        )
    
        // Sum types: Repr is the coproduct of the cases, introduced as the type
        // parameter `C` through the `=:=` evidence. `to` uses the `matchExpr` macro to
        // find the value's case; `from` just unwraps the payload.
        inline given [T, C <: Coproduct](using
          m: Mirror.SumOf[T],
          ev: Coproduct.ToCoproduct[m.MirroredElemTypes] =:= C
        ): Generic.Aux[T, C] =
          instance(
            matchExpr[T, C](_).asInstanceOf[C],
            Coproduct.unsafeFromCoproduct(_).asInstanceOf[T]
          )
    
        import scala.quoted.*
    
        /** Converts `ident` to the coproduct `C` by a macro-generated type match. */
        inline def matchExpr[T, C <: Coproduct](ident: T): Coproduct =
          ${matchExprImpl[T, C]('ident)}
    
        /** Macro behind `matchExpr`. For the i-th alternative `Ti` of `C` it emits a
          * case that type-tests `ident` against `Ti` and whose body is
          * `Coproduct.unsafeToCoproduct(i, ident)`; the cases form a `match` on `ident`.
          */
        def matchExprImpl[T: Type, C <: Coproduct : Type](
          ident: Expr[T]
        )(using Quotes): Expr[Coproduct] = {
          import quotes.reflect.*
    
          // Collects H from each `H +: T` layer, stopping at the first type that is not a
          // two-argument applied type (i.e. CNil).
          // NOTE(review): matches any two-argument AppliedType, not only `+:`; this assumes
          // C is a plain `+:` chain, so confirm.
          def unwrapCoproduct(typeRepr: TypeRepr): List[TypeRepr] = typeRepr match {
            case AppliedType(_, List(typ1, typ2)) => typ1 :: unwrapCoproduct(typ2)
            case _  => Nil
          }
    
          val typeReprs = unwrapCoproduct(TypeRepr.of[C])
    
          // Reference to Coproduct.unsafeToCoproduct, called in every generated case.
          val methodIdent =
            Ident(TermRef(TypeRepr.of[Coproduct.type], "unsafeToCoproduct"))
    
          def caseDefs(ident: Term): List[CaseDef] =
            typeReprs.zipWithIndex.map { (typeRepr, i) =>
              CaseDef(
                // Inferred(typeRepr) is used because the commented-out
                // TypeIdent(typeRepr.typeSymbol) does not work for parametric case classes.
                Typed(ident, Inferred(typeRepr) /*TypeIdent(typeRepr.typeSymbol)*/),
                None,
                Block(
                  Nil,
                  Apply(
                    methodIdent,
                    List(Literal(IntConstant(i)), ident)
                  )
                )
              )
            }
    
          def matchTerm(ident: Term): Term = Match(ident, caseDefs(ident))
    
          matchTerm(ident.asTerm).asExprOf[Coproduct]
        }
      }
    
      // ============= Coproduct =====================
      /** Minimal coproduct: a value of the i-th alternative is encoded as `Inl(v)`
        * wrapped in i `Inr` layers; `CNil` terminates the type-level chain.
        */
      sealed trait Coproduct extends Product with Serializable
      sealed trait +:[+H, +T <: Coproduct] extends Coproduct
      final case class Inl[+H, +T <: Coproduct](head: H) extends (H +: T)
      final case class Inr[+H, +T <: Coproduct](tail: T) extends (H +: T)
      sealed trait CNil extends Coproduct
    
      object Coproduct {
        /** Wraps `Inl(value)` in `length` `Inr` layers, i.e. places `value` at index
          * `length`. Despite its name, `length` is a 0-based index. Nothing checks that
          * it agrees with `value`'s actual type.
          */
        def unsafeToCoproduct(length: Int, value: Any): Coproduct =
          (0 until length).foldLeft[Coproduct](Inl(value))((c, _) => Inr(c))
    
        /** Strips `Inr` layers and returns the payload of the `Inl` reached. */
        @tailrec
        def unsafeFromCoproduct(c: Coproduct): Any = c match {
          case Inl(h) => h
          case Inr(c) => unsafeFromCoproduct(c)
          case _: CNil => sys.error("impossible")
        }
    
        /** Type-level map from a tuple `(T0, T1, ...)` to `T0 +: T1 +: ... +: CNil`. */
        type ToCoproduct[T <: Tuple] <: Coproduct = T match {
          case EmptyTuple => CNil
          case h *: t => h +: ToCoproduct[t]
        }
    
    //    type ToTuple[C <: Coproduct] <: Tuple = C match {
    //      case CNil => EmptyTuple
    //      case h +: t => h *: ToTuple[t]
    //    }
    
        /** Inverse of `ToCoproduct`, as a type class: `Out` is the tuple of `C`'s
          * alternatives. (A match-type version is left commented out above.)
          */
        trait ToTuple[C <: Coproduct] {
          type Out <: Tuple
        }
        object ToTuple {
          type Aux[C <: Coproduct, Out0 <: Tuple] = ToTuple[C] { type Out = Out0 }
          def instance[C <: Coproduct, Out0 <: Tuple]: Aux[C, Out0] =
            new ToTuple[C] { override type Out = Out0 }
    
          // `H +: T` maps to `H *: (tuple of T)`; CNil maps to EmptyTuple.
          given [H, T <: Coproduct](using 
            toTuple: ToTuple[T]
          ): Aux[H +: T, H *: toTuple.Out] = instance
          given Aux[CNil, EmptyTuple] = instance
        }
      }
    }
    
    // different file
    import App1.{+:, CNil, Coproduct, Generic, Inl, Inr}
    
    /** Partitions a list of a sealed type into one list per case, built on the
      * `Generic` / `Coproduct` machinery from `App1`.
      */
    object App2 {    
      /** Inserts the coproduct value `c` into the list of `l` for its alternative.
        * `L` holds one `List[H]` per alternative `H` of `C`, in the same order.
        */
      trait Loop[C <: Coproduct, L <: Tuple] {
        def apply(c: C, l: L): L
      }
      object Loop {
        // `Inl`: the value has type H, so prepend it to the head list.
        // `Inr`: keep the head list and insert into the tail via `loop`.
        given [H, CT <: Coproduct, HT <: Tuple](using 
          loop: Loop[CT, HT]
        ): Loop[H +: CT, List[H] *: HT] = {
          case (Inl(h), hs *: ht) => (h :: hs) *: ht
          case (Inr(ct), hs *: ht) => hs *: loop(ct, ht)
        }
    
        // No alternatives left: return the (empty) tuple unchanged.
        given Loop[CNil, EmptyTuple] = (_, l) => l
      }
    
      /** Produces a tuple of empty lists with the shape of `L`. */
      trait FillWithNil[L <: Tuple] {
        def apply(): L
      }
      object FillWithNil {
        given [H, T <: Tuple](using 
          fillWithNil: FillWithNil[T]
        ): FillWithNil[List[H] *: T] = () => Nil *: fillWithNil()
        given FillWithNil[EmptyTuple] = () => EmptyTuple
      }
    
      /** Splits `as` into a tuple of lists, one per case of the sealed type `A`, in
        * the order of `generic.Repr`. Each list keeps the input's relative order,
        * since elements are prepended while folding from the right.
        *
        * `toTuple` and `ev` are not used in the body; they only determine the result
        * type `L1` (the tuple of `List`s of `A`'s cases).
        */
      def partition[A, /*L <: Tuple,*/ L1 <: Tuple](as: List[A])(using
        generic: Generic.Aux[A, _ <: Coproduct],
        toTuple: Coproduct.ToTuple[generic.Repr],
        //ev0: Coproduct.ToTuple[generic.Repr] =:= L, // compile-time NPE
        ev: Tuple.Map[toTuple.Out/*L*/, List] =:= L1,
        loop: Loop[generic.Repr, L1],
        fillWith: FillWithNil[L1]
      ): L1 = as.foldRight(fillWith())((a, l1) => loop(generic.to(a), l1))
    
      // Example ADT. The case classes are final: nothing should extend them.
      sealed trait A
      final case class B(i: Int) extends A
      final case class C(i: Int) extends A
      final case class D(i: Int) extends A
    
      def main(args: Array[String]): Unit = {
        println(partition(List[A](B(1), B(2), C(1), C(2), D(1), D(2), B(3), C(3))))
      // (List(B(1), B(2), B(3)),List(C(1), C(2), C(3)),List(D(1), D(2)))
      }
    }
    

    Scala 3.0.2


    In the macro (which generates the pattern match), Inferred(typeRepr) should be used instead of TypeIdent(typeRepr.typeSymbol); otherwise it doesn't work for parametric case classes. Actually, the macro can be removed altogether if we use mirror.ordinal. A simplified version:

    import scala.deriving.Mirror
    import scala.util.NotGiven
    
    /** Two-way conversion between `T` and a generic representation `Repr`. */
    trait Generic[T] {
      /** Rebuilds a `T` from its representation. */
      def from(r: Repr): T

      /** Converts a `T` into its representation. */
      def to(t: T): Repr

      /** The representation type: a tuple of fields for products, a coproduct for sums. */
      type Repr
    }
    
    object Generic {
      /** A `Generic[T]` whose `Repr` is fixed to the type parameter `Repr0`. */
      type Aux[T, Repr0] = Generic[T] {type Repr = Repr0}
    
      /** Builds a `Generic` from the conversion functions `f` (to) and `g` (from). */
      def instance[T, Repr0](f: T => Repr0, g: Repr0 => T): Aux[T, Repr0] =
        new Generic[T] {
          override type Repr = Repr0
          override def to(t: T): Repr0 = f(t)
          override def from(r: Repr0): T = g(r)
        }
    
      /** Extension syntax: `a.toRepr` and `repr.to[A]`. */
      object ops {
        extension[A] (a: A) {
          def toRepr(using g: Generic[A]): g.Repr = g.to(a)
        }
    
        extension[Repr] (a: Repr) {
          def to[A](using g: Generic.Aux[A, Repr]): A = g.from(a)
        }
      }
    
      // Product types: Repr is the tuple of field types. `to` builds that tuple via the
      // tuple type's own mirror (m1); `from` builds T via T's mirror (m).
      // NOTE(review): the NotGiven guards are commented out, and NotGiven appears to be
      // referenced only from these comments.
      given [T <: Product](using
        // ev: NotGiven[T <:< Tuple],
        // ev1: NotGiven[T <:< Coproduct],
        m: Mirror.ProductOf[T],
        m1: Mirror.ProductOf[m.MirroredElemTypes]
      ): Aux[T, m.MirroredElemTypes] = instance(
        m1.fromProduct(_),
        m.fromProduct(_)
      )
    
      // Sum types: Repr is the coproduct of the cases, named `C` through the `=:=`
      // evidence. `to` places the value at its mirror ordinal (no macro needed);
      // `from` unwraps the payload.
      given[T, C <: Coproduct](using
        // ev: NotGiven[T <:< Tuple],
        // ev1: NotGiven[T <:< Coproduct],
        m: Mirror.SumOf[T],
        ev2: Coproduct.ToCoproduct[m.MirroredElemTypes] =:= C
      ): Generic.Aux[T, C/*Coproduct.ToCoproduct[m.MirroredElemTypes]*/] = {
        instance(
          t => Coproduct.unsafeToCoproduct(m.ordinal(t), t).asInstanceOf[C],
          Coproduct.unsafeFromCoproduct(_).asInstanceOf[T]
        )
      }
    }
    
    /** Minimal coproduct: a value of the i-th alternative is `Inl(v)` under i `Inr`s. */
    sealed trait Coproduct extends Product with Serializable
    /** A coproduct whose value is either an `H` or one of the alternatives of `T`. */
    sealed trait +:[+H, +T <: Coproduct] extends Coproduct
    /** The value is of the head alternative `H`. */
    final case class Inl[+H, +T <: Coproduct](head: H) extends +:[H, T]
    /** The value is one of the alternatives in `tail`. */
    final case class Inr[+H, +T <: Coproduct](tail: T) extends +:[H, T]
    /** End of the type-level chain of alternatives. */
    sealed trait CNil extends Coproduct
    
    object Coproduct {
      /** Places `value` at 0-based position `length` (despite the name, an index) by
        * wrapping `Inl(value)` in that many `Inr`s; a non-positive `length` gives
        * `Inl(value)`. Unsafe: nothing checks the position against `value`'s type.
        */
      def unsafeToCoproduct(length: Int, value: Any): Coproduct = {
        @scala.annotation.tailrec
        def wrap(acc: Coproduct, remaining: Int): Coproduct =
          if (remaining <= 0) acc else wrap(Inr(acc), remaining - 1)
        wrap(Inl(value), length)
      }
    
      /** Peels `Inr` layers off until an `Inl` is found, and returns its payload. */
      @scala.annotation.tailrec
      def unsafeFromCoproduct(c: Coproduct): Any = c match {
        case Inr(rest) => unsafeFromCoproduct(rest)
        case Inl(payload) => payload
        case _: CNil => sys.error("impossible")
      }
    
      /** Tuple of alternatives `(T0, T1, ...)` to `T0 +: T1 +: ... +: CNil`. */
      type ToCoproduct[T <: Tuple] <: Coproduct = T match {
        case EmptyTuple => CNil
        case h *: t => h +: ToCoproduct[t]
      }
    
      /** Inverse of `ToCoproduct`: `T0 +: T1 +: ... +: CNil` to `(T0, T1, ...)`. */
      type ToTuple[C <: Coproduct] <: Tuple = C match {
        case CNil => EmptyTuple
        case h +: t => h *: ToTuple[t]
      }
    }
    

    Replacing type classes with compile-time/inline methods and match types

    import scala.compiletime.erasedValue
    
    /** Inserts the coproduct value `c` into the matching list of `l`, where `L` holds one
      * `List[h]` per alternative `h` of `C`, in the same order. The case analysis on `C`
      * and `L` is done on their static types (inline matches on `erasedValue`); only the
      * `(c, l)` match runs at runtime. The final cast restores the declared type `L`.
      */
    inline def loop[C <: Coproduct, L <: Tuple](c: C, l: L): L = (inline erasedValue[C] match {
      // No alternatives: L must be EmptyTuple.
      case _: CNil => inline erasedValue[L] match {
        case _: EmptyTuple => EmptyTuple
      }
      // Alternative `h`: L must start with List[h].
      case _: (h +: ct) => inline erasedValue[L] match {
        case _: (List[`h`] *: ht) => (c, l) match {
          // `c` is this alternative: prepend its value to the head list.
          case (Inl(h_v: `h`), (hs_v: List[`h`]) *: (ht_v: `ht`)) => 
            (h_v :: hs_v) *: ht_v
          // Otherwise keep the head list and recurse into the tail.
          case (Inr(ct_v: `ct`), (hs_v: List[`h`]) *: (ht_v: `ht`)) => 
            hs_v *: loop[ct, ht](ct_v, ht_v)
        }
      }
    }).asInstanceOf[L]
    
    /** Builds a tuple of empty lists with the shape of `L` (`List[h1] *: List[h2] *: ...`),
      * recursing on `L`'s static type.
      */
    inline def fillWithNil[L <: Tuple]: L = (inline erasedValue[L] match {
      case _: EmptyTuple => EmptyTuple
      case _: (List[h] *: t) => Nil *: fillWithNil[t]
    }).asInstanceOf[L]
    
    // Result type of `partition`: one `List[h]` per alternative `h` of `C`, as a tuple.
    type TupleList[C <: Coproduct] = Tuple.Map[Coproduct.ToTuple[C], List]
    
    /** Splits `as` into a tuple of lists, one per case of the sealed type `A`, in the
      * order of `generic.Repr` (for the `Generic` above, the mirror's ordinal order).
      * Each list keeps the input's relative order, because elements are prepended
      * while folding from the right.
      */
    inline def partition[A](as: List[A])(using
      generic: Generic.Aux[A, _ <: Coproduct]
    ): TupleList[generic.Repr] =
      as.foldRight(fillWithNil[TupleList[generic.Repr]])((a, l1) => loop(generic.to(a), l1))
    
    // Example ADT to partition. Declaration order fixes the mirror ordinals, and hence the
    // order of the lists returned by `partition`. Case classes are final: nothing should
    // extend them.
    sealed trait A
    final case class B(i: Int) extends A
    final case class C(i: Int) extends A
    final case class D(i: Int) extends A
    
    /** Demo: partitions an interleaved list of `A` values and prints the per-case lists. */
    @main def test = {
      val mixed = List[A](B(1), B(2), C(1), C(2), D(1), D(2), B(3), C(3))
      println(partition(mixed))
      // (List(B(1), B(2), B(3)),List(C(1), C(2), C(3)),List(D(1), D(2)))
    }
    

    Tested in 3.2.0 https://scastie.scala-lang.org/DmytroMitin/940QaiqDQQ2QegCyxTbEIQ/1

    How to access parameter list of case class in a dotty macro


    Alternative implementation of loop

    // Type-level counterpart of `loop` below: computes the result type of `loop0`.
    // When L has one List per alternative of C, it is expected to reduce back to L:
    //Loop[C, L] = L
    type Loop[C <: Coproduct, L <: Tuple] <: Tuple = C match {
      case CNil    => CNilLoop[L]
      case h +: ct => CConsLoop[h, ct, L]
    }
    // match types seem not to support nested type matching,
    // so each level of Loop's case analysis is a separate match type.
    // CNil case: only EmptyTuple is accepted.
    type CNilLoop[L <: Tuple] <: Tuple = L match {
      case EmptyTuple => EmptyTuple
    }
    // `H +: CT` case: L must start with List[H]; the head is kept and the tail is
    // processed by Loop[CT, ...].
    type CConsLoop[H, CT <: Coproduct, L <: Tuple] <: Tuple = L match {
      case List[H] *: ht => List[H] *: Loop[CT, ht]
    }
    /** Non-inline variant of the inline `loop`, typed with the match type `Loop[C, L]`.
      * The case analysis is an ordinary runtime match on `c` and `l`.
      * NOTE(review): the type tests on `h +: ct` and `List[`h`] *: ht` are presumably
      * erased at runtime (only the outer classes are checked); correctness relies on
      * `l` having one List per alternative of `C`.
      */
    /*inline*/ def loop0[C <: Coproduct, L <: Tuple](c: C, l: L): Loop[C, L] = /*inline*/ c match {
      case _: CNil => /*inline*/ l match {
        case _: EmptyTuple => EmptyTuple
      }
      case c: (h +: ct) => /*inline*/ l match {
        case l: (List[`h`] *: ht) => (c, l) match {
          // Tail is returned unchanged; the cast relies on Loop[ct, ht] reducing to ht.
          case (Inl(h_v/*: `h`*/), (hs_v/*: List[`h`]*/) *: (ht_v/*: `ht`*/)) =>
            (h_v :: hs_v) *: ht_v.asInstanceOf[Loop[ct, ht]]
          case (Inr(ct_v/*: `ct`*/), (hs_v/*: List[`h`]*/) *: (ht_v/*: `ht`*/)) => 
            hs_v *: loop0[ct, ht](ct_v, ht_v)
        }
      }
    }
    /** Runs `loop0` and casts its `Loop[C, L]` result back to `L`, relying on the
      * `Loop[C, L] = L` correspondence noted above the match type.
      */
    /*inline*/ def loop[C <: Coproduct, L <: Tuple](c: C, l: L): L = {
      val stepped: Loop[C, L] = loop0[C, L](c, l)
      stepped.asInstanceOf[L]
    }
    

    Another implementation for Scala 2: Split list of algebraic date type to lists of branches?