Tags: scala, enums, scala-macros, scala-3

Is it possible to match a quoted type of a sum type?


I am doing type class derivation in Scala 3. My type class is BarOf[T], which extends Bar. T is expected to be an ADT. Bar has a children method that returns the BarOf[X] instances for the child types X of T.
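To make the intended result concrete, here is a hand-written sketch (no derivation involved) of what derives BarOf is meant to produce for Foo2 and its child C2 from the test code further down. The object name Foo2Instances and the instance names c2Bar and foo2Bar are hypothetical. Since BarOf is invariant, a BarOf[Foo2] is not a BarOf[C2], so each child type needs its own instance.

```scala
// Hand-written illustration only (not derived): the shape BarOf[Foo2] should have.
trait Bar:
  def children: List[Bar]

trait BarOf[A] extends Bar

sealed class Foo2
class C2(x: Int) extends Foo2

// Hypothetical container for the hand-written instances.
object Foo2Instances:
  // Leaf instance for the child type C2: it has no children of its own.
  given c2Bar: BarOf[C2] with
    def children: List[Bar] = Nil

  // Parent instance: children lists the BarOf[X] of each child type X.
  given foo2Bar: BarOf[Foo2] with
    def children: List[Bar] = List(summon[BarOf[C2]])
```

A derived instance should behave like foo2Bar: its children list holds the instance found for each child type.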

Here the derived method is implemented with Quotes rather than Mirrors, because the rest of my code (removed here) needs it.

The essential step for finding the BarOf[X] instance of each child type X is a quoted type match (full code at the end of the post):

A.memberType(childSymbol).asType match {
  case '[f] => Expr.summon[BarOf[f]]
}

This works fine for product types, in particular case classes. But when it is applied to sum types, the compiler fails with an assertion error:

assertion failure for class C1 <:< f, frozen = false
6 |sealed class Foo1 derives BarOf
  |                         ^
  |Exception occurred while executing macro expansion.
  |java.lang.AssertionError: assertion failed: ClassInfo(ThisType(TypeRef(NoPrefix,module class <empty>)), class C0, List(TypeRef(ThisType(TypeRef(NoPrefix,module class <empty>)),class Foo)))
  |     at scala.runtime.Scala3RunTime$.assertFailed(Scala3RunTime.scala:8)
  |     at dotty.tools.dotc.core.Types$TypeBounds.<init>(Types.scala:4918)
  |     at dotty.tools.dotc.core.Types$RealTypeBounds.<init>(Types.scala:4994)
  |     at dotty.tools.dotc.core.Types$TypeBounds$.apply(Types.scala:5038)
  |     at dotty.tools.dotc.core.Types$TypeBounds.derivedTypeBounds(Types.scala:4926)
  |     at dotty.tools.dotc.core.ConstraintHandling.addOneBound(ConstraintHandling.scala:116)
  |     at dotty.tools.dotc.core.ConstraintHandling.addOneBound$(ConstraintHandling.scala:27)
  |     at dotty.tools.dotc.core.ProperGadtConstraint.addOneBound(GadtConstraint.scala:60)
  |     at dotty.tools.dotc.core.ConstraintHandling.addBoundTransitively(ConstraintHandling.scala:170)
  |     at dotty.tools.dotc.core.ConstraintHandling.addBoundTransitively$(ConstraintHandling.scala:27)
  |     at dotty.tools.dotc.core.ProperGadtConstraint.addBoundTransitively(GadtConstraint.scala:60)
  |     at dotty.tools.dotc.core.ProperGadtConstraint.addBound(GadtConstraint.scala:159)
  |     at dotty.tools.dotc.core.TypeComparer.gadtAddLowerBound(TypeComparer.scala:118)
  |     at dotty.tools.dotc.core.TypeComparer.narrowGADTBounds(TypeComparer.scala:1913)
  |     at dotty.tools.dotc.core.TypeComparer.compareGADT$1(TypeComparer.scala:521)
  |     at dotty.tools.dotc.core.TypeComparer.thirdTryNamed$1(TypeComparer.scala:524)
  |     at dotty.tools.dotc.core.TypeComparer.thirdTry$1(TypeComparer.scala:573)
  |     at dotty.tools.dotc.core.TypeComparer.secondTry$1(TypeComparer.scala:504)
  |     at dotty.tools.dotc.core.TypeComparer.compareNamed$1(TypeComparer.scala:313)
  |     at dotty.tools.dotc.core.TypeComparer.firstTry$1(TypeComparer.scala:319)
  |     at dotty.tools.dotc.core.TypeComparer.recur(TypeComparer.scala:1321)
  |     at dotty.tools.dotc.core.TypeComparer.isSubType(TypeComparer.scala:201)
  |     at dotty.tools.dotc.core.TypeComparer.isSubType(TypeComparer.scala:211)
  |     at dotty.tools.dotc.core.TypeComparer.topLevelSubType(TypeComparer.scala:128)
  |     at dotty.tools.dotc.core.TypeComparer$.topLevelSubType(TypeComparer.scala:2729)
  |     at dotty.tools.dotc.core.Types$Type.$less$colon$less(Types.scala:1035)
  |     at scala.quoted.runtime.impl.QuoteMatcher$.$eq$qmark$eq(QuoteMatcher.scala:336)
  |     at scala.quoted.runtime.impl.QuoteMatcher$.treeMatch(QuoteMatcher.scala:129)
  |     at scala.quoted.runtime.impl.QuotesImpl.scala$quoted$runtime$impl$QuotesImpl$$treeMatch(QuotesImpl.scala:3021)
  |     at scala.quoted.runtime.impl.QuotesImpl$TypeMatch$.unapply(QuotesImpl.scala:2991)
  |     at bar.BarOf$.childExpr$1(bar.scala:21)
  |     at bar.BarOf$.$anonfun$6(bar.scala:26)
  |     at scala.collection.immutable.List.map(List.scala:246)
  |     at bar.BarOf$.derivedImpl(bar.scala:26)

The assertion is triggered by the quoted type match case '[f] =>, but I cannot figure out what I am missing. Is my code legal?

  • If not, what is the alternative?
  • If yes, is it a compiler bug?

Full bar package source:

package bar
import scala.quoted.*


trait Bar:
  def children : List[Bar]

trait BarOf[A] extends Bar

object BarOf :
  inline def derived[A]: BarOf[A] =  ${derivedImpl[A]}

  def derivedImpl[A:Type](using Quotes):Expr[BarOf[A]] =
    import quotes.reflect._
    val A           = TypeRepr.of[A]
    val mySymbol    = A.typeSymbol
    val terms       = mySymbol.children

    def childExpr(childSymbol:Symbol) =
      A.memberType(childSymbol).asType match {
        case '[f] => Expr.summon[BarOf[f]]
      } getOrElse {
        report.errorAndAbort(s"ChildType ${childSymbol} of ${mySymbol} does not derive BarOf")
      }

    val termsExpr = Expr.ofList(terms.map(childExpr))
    '{
      new BarOf[A] {
        def children = ${termsExpr}
      }
    }

Test Code:

import bar.BarOf
  
// works
enum Foo0 derives BarOf :
  case C0

// does not work 
enum Foo1 derives BarOf :
  case C1(x:Int)

// does not work either:
sealed class Foo2 derives BarOf
class C2(x:Int) extends Foo2



@main def run:Unit =
  val x = summon[BarOf[Foo1]]
  println(x.children)

Solution

  • The issue was not in the quoted type pattern match, as I thought, but in how the type of the child symbol was looked up.

    The correct lookup is:

    childSymbol.typeRef.asType match {
        case '[f] => Expr.summon[BarOf[f]]
    }
    

    instead of:

    A.memberType(childSymbol).asType match {
        case '[f] => Expr.summon[BarOf[f]]
    }
    

    Complete working code:

    package bar
    import scala.quoted.*
    
    
    trait Bar:
      def children : List[Bar]
    
    trait BarOf[A] extends Bar
    
    object BarOf :
      inline def derived[A]: BarOf[A] =  ${derivedImpl[A]}
    
      def derivedImpl[A:Type](using Quotes):Expr[BarOf[A]] =
        import quotes.reflect._
        val A           = TypeRepr.of[A]
        val mySymbol    = A.typeSymbol
        val terms       = mySymbol.children
    
        def childExpr(childSymbol:Symbol) =
          childSymbol.typeRef.asType match {
            case '[f] => Expr.summon[BarOf[f]]
          } getOrElse {
            report.errorAndAbort(s"ChildType ${childSymbol} of ${mySymbol} does not derive BarOf")
          }
    
        val termsExpr = Expr.ofList(terms.map(childExpr))
        '{
          new BarOf[A] {
            def children = ${termsExpr}
            override def toString = ${Expr(s"BarOf[${mySymbol.fullName}]")}
          }
        }
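A possible refinement, sketched under assumptions rather than taken from the post: the assertion message above shows the ClassInfo of a child class, which suggests that A.memberType(childSymbol) returns a class's declaration info instead of a type reference. Symbol.typeRef fixes this for class children. However, a sealed type can also have singleton children, such as an enum case without parameters (case C0) or a case object. I assume those are reported as term symbols, for which Symbol.termRef is the natural counterpart. The object ChildTypes and its method childType are hypothetical names.

```scala
import scala.quoted.*

// Hypothetical helper (not from the original post): builds a quoted type for a
// child of a sealed type or enum, to be matched with case '[f] => ...
object ChildTypes:
  def childType(using q: Quotes)(child: q.reflect.Symbol): Type[?] =
    import q.reflect.*
    // Class children (case C1(x: Int), class C2 extends Foo2): use typeRef.
    // Singleton children (assumed to be term symbols, e.g. enum case C0): use termRef.
    val ref: TypeRepr = if child.isType then child.typeRef else child.termRef
    ref.asType
```

Inside derivedImpl, the lookup would then read ChildTypes.childType(childSymbol) match { case '[f] => Expr.summon[BarOf[f]] }. For a singleton child, this summons BarOf for the singleton type (for example BarOf[Foo0.C0.type]), which succeeds only if such an instance is in scope. Whether to widen singleton children to their parent type instead is a design choice.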