
Scala: How to Filter a Shapeless HList Based on Labels


I have the following:

    sealed trait baseData {
      def weight: Int
      def priority: Int
    }

    sealed trait moreData {
      def weight: Int
      def priority: Int
      def t: String
      def id: String
    }

    case class data1(override val weight: Int, override val priority: Int) extends baseData
    case class moreData1(override val weight: Int, override val priority: Int, override val t: String, override val id: String) extends moreData

I am generating HLists from the case classes like this:

    import shapeless._

    val h1 = LabelledGeneric[data1].to(data1(1,2))
    val h2 = LabelledGeneric[moreData1].to(moreData1(3,4,"a","b"))

How can I trim or filter h2 so that it only holds the fields present in h1? I suspect I need something like val filtered = h2.foldRight(HNil)(keepFunc), but I haven't been able to figure out how to write keepFunc. Any ideas?


Solution

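  • You could filter based on the keys: pair each field of h2 with its label, then fold from the right and keep only the fields whose label also appears in h1. The code is not generic, but it is hopefully sufficient to illustrate the concept.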

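    import shapeless._
    import shapeless.ops.record.Keys

    val gen1 = LabelledGeneric[data1]
    val gen2 = LabelledGeneric[moreData1]

    val h1 = gen1.to(data1(1,2))
    val h2 = gen2.to(moreData1(3,4,"a","b"))

    // The record labels of each case class, as an HList of singleton-typed Symbols
    val keys1 = Keys[gen1.Repr].apply()
    val keys2 = Keys[gen2.Repr].apply()

    // Pairs each field of h2 with its key (used by zipWith)
    object pair extends Poly2 {
      implicit def default[T, U] = at[T, U]((_, _))
    }

    // Used by foldRight: prepend a field only if its key also occurs in keys1.
    // The test runs at runtime, so the result is statically only known to be an HList.
    object keep extends Poly2 {
      implicit def keepFunc[T, K, L <: HList] =
        at[(T, K), L] { case ((t, key), l) =>
          if (keys1.toList.contains(key)) t :: l else l
        }
    }

    val filtered = h2.zipWith(keys2)(pair).foldRight(HNil)(keep)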
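The zip-then-fold pattern in the answer can be mirrored in plain Scala, which makes the runtime nature of the label check easier to see. This is a sketch under stated assumptions, not the answer's shapeless code: it assumes Scala 2.13 or later for `Product.productElementNames`, represents each field as a `(label, value)` pair instead of a tagged HList element, and the names `FilterByLabels`, `labelled` and `keepShared` are hypothetical helpers introduced for illustration.

```scala
// Plain Scala sketch (Scala 2.13+), not shapeless: each field becomes a (label, value) pair.
// Same fields as the question's case classes; the sealed traits are left out
// because they do not affect the labels.
case class data1(weight: Int, priority: Int)
case class moreData1(weight: Int, priority: Int, t: String, id: String)

object FilterByLabels {
  // Hypothetical helper: pair every field value with its label,
  // playing the role of h2.zipWith(keys2)(pair) in the answer.
  def labelled(p: Product): List[(String, Any)] =
    p.productElementNames.zip(p.productIterator).toList

  // Hypothetical helper: keep the pairs whose label also occurs in `labels`,
  // folding from the right as foldRight(HNil)(keep) does.
  def keepShared(fields: List[(String, Any)], labels: Set[String]): List[(String, Any)] =
    fields.foldRight(List.empty[(String, Any)]) { case ((label, value), acc) =>
      if (labels.contains(label)) (label, value) :: acc else acc
    }

  // Labels of data1, playing the role of keys1 (the instance only supplies the names).
  val keys1: Set[String] = data1(0, 0).productElementNames.toSet

  val filtered: List[(String, Any)] =
    keepShared(labelled(moreData1(3, 4, "a", "b")), keys1)
  // filtered == List(("weight", 3), ("priority", 4))

  // Rebuilding a data1 needs casts, because the static field types were lost.
  val rebuilt: data1 = {
    val m = filtered.toMap
    data1(m("weight").asInstanceOf[Int], m("priority").asInstanceOf[Int])
  }
}
```

The trade-off is the same as in the shapeless version: because the keep decision is made at runtime, the compiler cannot tell that `filtered` holds exactly the fields of `data1`, so getting a typed `data1` back requires either casts, as here, or resolving the field selection at compile time instead.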