scala shapeless hlist

Split list of algebraic data type to lists of branches?


I'm pretty new to shapeless so the question might be easy.

Here is the ADT:

sealed trait Test

final case class A() extends Test
final case class B() extends Test
final case class C() extends Test
...
final case class Z() extends Test

Is it possible to write a function like the following without extremely cumbersome pattern matching?

def split(lst: List[Test]): List[A] :: List[B] :: ... :: HNil = //

Solution

  • At compile time, all elements of the List have the same static type Test, so there is no way to distinguish the elements A, B, C, ... using compile-time techniques alone (Shapeless, type classes, implicits, macros, compile-time reflection). The elements are distinguishable only at runtime, so you have to use some runtime technique (pattern matching, casting, runtime reflection).

    Why Does This Type Constraint Fail for List[Seq[AnyVal or String]]

    Scala: verify class parameter is not instanceOf a trait at compile time

    flatMap with Shapeless yield FlatMapper not found

    Try splitting into a Map using runtime reflection:

    def split(lst: List[Test]): Map[String, List[Test]]  =
      lst.groupBy(_.getClass.getSimpleName)
    
    split(List(C(), B(), A(), C(), B(), A()))
    // HashMap(A -> List(A(), A()), B -> List(B(), B()), C -> List(C(), C()))
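
    If statically typed lists are needed for only a few branches, the runtime check can also be written as a pattern match. The following is a minimal sketch in plain Scala, not part of the approach above; the helper name `only` is hypothetical, and it assumes a `ClassTag` is available for each branch so that the type test survives erasure.

    ```scala
    import scala.reflect.ClassTag

    sealed trait Test
    final case class A() extends Test
    final case class B() extends Test
    final case class C() extends Test

    // Hypothetical helper (an assumption, not from the answer): keeps the elements
    // whose runtime class is S. The ClassTag context bound makes the `case s: S`
    // type test a real runtime check instead of an erased, unchecked one.
    def only[S <: Test: ClassTag](lst: List[Test]): List[S] =
      lst.collect { case s: S => s }

    val input: List[Test] = List(C(), B(), A(), C(), B(), A())

    // Each call traverses the list once and narrows the element type.
    val listA: List[A] = only[A](input)
    val listB: List[B] = only[B](input)
    val listC: List[C] = only[C](input)
    ```

    This still needs one call per branch, so it does not scale to the full A ... Z case, but it avoids matching on class names.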
    

    Alternatively, split into an HList using Shapeless plus runtime reflection:

    import shapeless.labelled.{FieldType, field}
    import shapeless.{::, Coproduct, HList, HNil, LabelledGeneric, Poly1, Typeable, Witness}
    import shapeless.ops.coproduct.ToHList
    import shapeless.ops.hlist.Mapper
    import shapeless.ops.record.Values
    import shapeless.record._
    import scala.annotation.implicitNotFound
        
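    // Used only at the type level (through Mapper.Aux) to turn each FieldType[K, V]
    // into FieldType[K, List[V]]; the case is never applied at runtime, hence null.
    object listPoly extends Poly1 {
      implicit def cse[K <: Symbol, V]: Case.Aux[FieldType[K, V], FieldType[K, List[V]]] = null
    }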
    
    // modified shapeless.ops.maps.FromMap
    @implicitNotFound("Implicit not found: FromMapWithDefault[${R}]. Maps can only be converted to appropriate Record types.")
    trait FromMapWithDefault[R <: HList] extends Serializable {
      // if no value by this key use default, if can't cast return None
      def apply[K, V](m: Map[K, V], default: V): Option[R]
    }
    object FromMapWithDefault {
      implicit def hnilFromMap[T]: FromMapWithDefault[HNil] =
        new FromMapWithDefault[HNil] {
          def apply[K, V](m: Map[K, V], default: V): Option[HNil] = Some(HNil)
        }
    
    
      implicit def hlistFromMap[K0, V0, T <: HList]
      (implicit wk: Witness.Aux[K0], tv: Typeable[V0], fmt: FromMapWithDefault[T]): FromMapWithDefault[FieldType[K0, V0] :: T] =
        new FromMapWithDefault[FieldType[K0, V0] :: T] {
          def apply[K, V](m: Map[K, V], default: V): Option[FieldType[K0, V0] :: T] = {
            val value = m.getOrElse(wk.value.asInstanceOf[K], default)
            for {
              typed <- tv.cast(value)
              rest <- fmt(m, default)
            } yield field[K0](typed) :: rest
          }
        }
    }
    
    def split[T, C <: Coproduct, L <: HList, L1 <: HList](lst: List[T])(
      implicit
      labelledGeneric: LabelledGeneric.Aux[T, C],
      toHList: ToHList.Aux[C, L],
      mapper: Mapper.Aux[listPoly.type, L, L1],
      fromMapWithDefault: FromMapWithDefault[L1],
      values: Values[L1]
    ): values.Out = {
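      // group by runtime class name and key the map by Symbol to match the record keys
      val grouped = lst.groupBy(_.getClass.getSimpleName).map { case (k, v) => Symbol(k) -> v }
      // .get throws if a value cannot be cast to the expected field type
      fromMapWithDefault(grouped, Nil).get.values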
    }
    

    Testing:

    sealed trait Test
    final case class A() extends Test
    final case class B() extends Test
    final case class C() extends Test
    final case class Z() extends Test
    
    val res = split(List[Test](C(), B(), A(), C(), B(), A()))
    // List(A(), A()) :: List(B(), B()) :: List(C(), C()) :: List() :: HNil
    res: List[A] :: List[B] :: List[C] :: List[Z] :: HNil // compiles, so the static type is as expected
    

    Scala 3 collection partitioning with subtypes (Scala 2/3)