Search code examples
Tags: scala, shapeless

Is there a way to retrieve the Nat position of a specific element from a Shapeless HList?


Given any HList, e.g. 1 :: "str" :: true :: HNil, is there a (simple) way to find the Nat corresponding to the position of a specific list element, e.g. f(hlist, true) ==> Nat._2? Can the same be done for the element type rather than the element itself, e.g. f[String](hlist) ==> Nat._1? Assume that each element/type occurs exactly once and that its presence in the list is guaranteed.


Solution

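  • Try introducing a type class. Note that the lookup is by the element's static type; passing a value only lets the compiler infer that type.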

    import shapeless.{::, DepFn0, HList, HNil, Nat, Succ}
    import shapeless.nat._
    
    /** Type class witnessing the zero-based index, as a type-level `Nat`, of
      * the first occurrence of type `A` in the HList type `L`.
      *
      * `apply()` returns the runtime `Nat` value for that index.
      */
    trait Find[L <: HList, A] extends DepFn0 { type Out <: Nat }

    /** Lower-priority instances for [[Find]].
      *
      * `tail` lives in a parent of the companion so that, when `A` occurs more
      * than once in `L`, the `head` instance defined in the companion takes
      * precedence instead of the two instances being ambiguous. The resolved
      * index is then that of the first occurrence.
      */
    trait LowPriorityFind {
      /** `A` is not (preferentially) the head: its index is one more than its index in the tail. */
      implicit def tail[H, T <: HList, A, N <: Nat](implicit
        find: Find.Aux[T, A, N]
      ): Find.Aux[H :: T, A, Succ[N]] = Find.instance(Succ[N]())
    }

    object Find extends LowPriorityFind {
      /** `Find` with its result index exposed as a type parameter. */
      type Aux[L <: HList, A, Out0 <: Nat] = Find[L, A] { type Out = Out0 }

      /** Builds an instance whose `apply()` returns the given `Nat` value `x`. */
      def instance[L <: HList, A, Out0 <: Nat](x: Out0): Aux[L, A, Out0] = new Find[L, A] {
        type Out = Out0
        override def apply(): Out = x
      }

      /** `A` is the head of the list, so its index is `_0`. */
      implicit def head[A, T <: HList]: Aux[A :: T, A, _0] = instance(_0)
    }
    
    /** Returns the position of `a`'s type in `l` as a `Nat`.
      *
      * The value `a` is never read; it exists only so the compiler can infer `A`.
      */
    def f[L <: HList, A](l: L, a: A)(implicit find: Find[L, A]): find.Out =
      find.apply()

    /** Starts a lookup by type: `f[A](l)` returns the position of `A` in `l`. */
    def f[A]: PartiallyApplied[A] =
      new PartiallyApplied[A]

    /** Holds the explicitly supplied `A` so that `L` can be inferred from the argument. */
    class PartiallyApplied[A] {
      def apply[L <: HList](l: L)(implicit find: Find[L, A]): find.Out =
        find.apply()
    }
    
    // Compile-time checks: each line only compiles if `Find` resolves the stated index.
    implicitly[Find.Aux[Int :: String :: Boolean :: HNil, String, _1]]
    implicitly[Find.Aux[Int :: String :: Boolean :: HNil, Boolean, _2]]

    val hlist = 1 :: "str" :: true :: HNil

    // Lookup driven by a value: `A` is inferred from the type of `true`.
    val n: _2 = f(hlist, true)
    implicitly[n.N =:= _2]

    // Lookup driven by an explicit type argument.
    val m: _1 = f[String](hlist)
    implicitly[m.N =:= _1]
    

    Or, using shapeless's standard type classes (ZipWithIndex, Collect, IsHCons):

    import shapeless.ops.hlist.{Collect, IsHCons, ZipWithIndex}
    import shapeless.{HList, HNil, Nat, Poly1, poly}
    
    /** Polymorphic function defined only on `(A, N)` pairs, where `N <: Nat`;
      * it returns the `N` component.
      *
      * When applied via `Collect` to the output of `ZipWithIndex`, pairs whose
      * first component is not `A` have no case here and are skipped.
      */
    trait Second[A] extends Poly1

    object Second {
      implicit def cse[A, N <: Nat]: poly.Case1.Aux[Second[A], (A, N), N] =
        poly.Case1((pair: (A, N)) => pair._2)
    }
    
    /** Returns the `Nat` index of `a`'s type in `l`.
      *
      * Pairs every element with its index, keeps only the indices whose element
      * type is `A`, and returns the first one kept. The value `a` is never read;
      * it exists only so the compiler can infer `A`.
      */
    def f[L <: HList, A, L1 <: HList, L2 <: HList, N <: Nat](l: L, a: A)(implicit
      zipWithIndex: ZipWithIndex.Aux[L, L1],
      collect: Collect.Aux[L1, Second[A], L2],
      isHCons: IsHCons.Aux[L2, N, _]
    ): N = {
      val indexed = zipWithIndex(l)
      val matches = collect(indexed)
      isHCons.head(matches)
    }