Search code examples
scala shapeless

Type-level filtering using shapeless


Does anybody know how to make this test work using Shapeless?

package net.jtownson.swakka.jsonschema

import org.scalatest.FlatSpec
import org.scalatest.Matchers._

/**
  * Specification of the desired behaviour: given only the type of a case class,
  * list the names of its optional and of its non-optional fields.
  *
  * The two extractor methods below are unimplemented placeholders (`???`).
  */
class OptionalFieldSpec extends FlatSpec {

  // Sample type under test: `i` is a plain field, `j` is an Option-typed field.
  case class A(i: Int, j: Option[Int])

  "an extractor of some kind" should "get the (non)optional fields from a case class" in {

    // Only the type A is supplied; no instance of A is created here.
    extractNonOptionalFieldNames[A] shouldBe List("i")

    extractOptionalFieldNames[A] shouldBe List("j")

  }

  /**
    * Placeholder: expected to return the names of the fields of `T` whose type is not an `Option`.
    * Not implemented yet; evaluating it throws `NotImplementedError` (that is what `???` does).
    */
  def extractNonOptionalFieldNames[T <: Product](/* implicit typeclass instances? */): List[String] = ???

  /**
    * Placeholder: expected to return the names of the fields of `T` whose type is an `Option`.
    * Not implemented yet; evaluating it throws `NotImplementedError` (that is what `???` does).
    */
  def extractOptionalFieldNames[T <: Product]: List[String] = ???

}

I have no runtime instance of A (or of its Generic equivalent), because I am creating a JsonSchema for case class A, and the schema is independent of any given instance. The schema has a "required" field, which is a list of the names of the non-optional fields. For example:

{
  "type" -> "object",
  "required" -> ["i"],
  "properties" -> {
    "i" -> {
      "type" -> "integer",
      "format" -> "int32"
     }
   }
}

Solution

  • Something like this:

    trait FieldNameExtractor[T] extends Serializable {
      import shapeless.ops.hlist.{RightFolder, ToTraversable}
      import shapeless.ops.record.Keys
      import shapeless.{HList, HNil, LabelledGeneric, Poly2}
    
      /**
        * Lists the names of the fields of `T` that survive a type-level filter.
        *
        * The labelled generic representation of `T` is right-folded with `op`,
        * starting from the empty HList; the keys of the folded record are then
        * turned into a list of names. Only implicit evidence is consulted, so
        * no instance of `T` has to be passed in.
        *
        * @param op          polymorphic function that decides, field by field, what ends up in the fold result
        * @param lgen        evidence that `L` is the labelled generic representation of `T`
        * @param folder      evidence that right-folding `L` with `op` produces `R`
        * @param keys        provides the keys `O` of the folded record `R`
        * @param traversable converts the key HList `O` into a `List[Symbol]`
        * @return the names of the fields retained by `op`
        */
      def extract[L <: HList, R <: HList, O <: HList](op: Poly2)(
          implicit lgen: LabelledGeneric.Aux[T, L],
          folder: RightFolder.Aux[L, HNil.type, op.type, R],
          keys: Keys.Aux[R, O],
          traversable: ToTraversable.Aux[O, List, Symbol]
      ): List[String] =
        // Each key is a Symbol; its `name` is the field name as a String.
        for (fieldKey <- keys().to[List]) yield fieldKey.name
    }
    
    /** Companion providing a factory for [[FieldNameExtractor]] instances. */
    object FieldNameExtractor {
      /**
        * Creates an extractor for the type `T`.
        *
        * The result type is declared explicitly so that the public signature
        * does not depend on what is inferred for the anonymous class.
        *
        * @tparam T the type whose field names are to be extracted
        * @return a new extractor for `T`
        */
      def apply[T]: FieldNameExtractor[T] = new FieldNameExtractor[T] {}
    }
    

    Usage:

    import org.scalatest.FlatSpec
    import org.scalatest.Matchers._
    
    class Test extends FlatSpec {
      // Type-level filters passed to FieldNameExtractor#extract.
      import shapeless.{HList, Poly2}
      import shapeless.labelled.KeyTag
      import shapeless.tag.Tagged
    
      /** Type of a record field that holds an `Option[A]` and is keyed by the tagged symbol type `T`. */
      type FilterO[A, T] = Option[A] with KeyTag[Symbol with Tagged[T], Option[A]]
    
      /** Base filter whose default case drops the field and returns the accumulator unchanged. */
      trait Ignore extends Poly2 {
        implicit def default[A, L <: HList] = at[A, L]((_, acc) => acc)
      }
    
      /** Base filter whose default case prepends the field to the accumulator. */
      trait Accept extends Poly2 {
        implicit def default[A, L <: HList] = at[A, L]((field, acc) => field :: acc)
      }
    
      /** Drops fields by default, but keeps those matching `FilterO`, i.e. the `Option` fields. */
      object allOptions extends Ignore {
        implicit def option[A, T, L <: HList] =
          at[FilterO[A, T], L]((field, acc) => field :: acc)
      }
    
      /** Keeps fields by default, but drops those matching `FilterO`, i.e. the `Option` fields. */
      object noOptions extends Accept {
        implicit def option[A, T, L <: HList] =
          at[FilterO[A, T], L]((_, acc) => acc)
      }
    
      "an extractor of some kind" should "get the (non)optional fields from a case class" in {
        case class A(i: Int, j: Option[Int], k: String)
    
        val extractor = FieldNameExtractor[A]
    
        // Equivalent of extractNonOptionalFieldNames
        extractor.extract(noOptions) shouldBe List("i", "k")
    
        // Equivalent of extractOptionalFieldNames
        extractor.extract(allOptions) shouldBe List("j")
      }
    }