
Recursively obtain subclass(es) of a class


Main Question: How can you print all the subclasses (recursively) of a class in Scala?

Context:

I am using the CDE library (https://github.com/chipsalliance/cde), which is used in a variety of downstream projects to parameterize classes. As a result, a large number of classes extending a class called Config are created (across multiple Scala files), like so:

    class Cfg0 extends Config(...)
    class Cfg1 extends Cfg0
    class Cfg2 extends Cfg1

Our users normally reference these Config class names by hand in Makefiles, etc., and I would like to print the names of all of these classes so that they are easier to discover (instead of grepping).

Definition of Config class: https://github.com/chipsalliance/cde/blob/384c06b8d45c8184ca2f3fba2f8e78f79d2c1b51/cde/src/chipsalliance/rocketchip/config.scala#L151

I've looked at:

  • How do I use Scala reflection to find all subclasses of a trait (without using third-party tools)?
  • https://gist.github.com/longshorej/1a0a2cf50de8e6ff101c

However, they seem to be oriented around case classes/traits.


Solution

  • I slightly updated my answer at "How do I use Scala reflection to find all subclasses of a trait (without using third-party tools)?".

    Of the options listed there, at runtime you can, for example, use one of the classpath scanners such as Reflections, ClassGraph, or Burningwave. With Reflections:

    // libraryDependencies += "org.reflections" % "reflections" % "0.10.2"
    import org.reflections.Reflections
    import org.reflections.scanners.Scanners.SubTypes
    import scala.jdk.CollectionConverters._
    
    val reflections = new Reflections()
    reflections.get(SubTypes.of(classOf[Config]).asClass()).asScala
    // Set(class Cfg0, class Cfg1, class Cfg2)
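    Reflections returns a flat set. Since the question asks for the subclasses recursively, one option is to rebuild the hierarchy from that set using only java.lang.Class.getSuperclass and print it as an indented tree. This is a minimal sketch, not part of Reflections: Config and the Cfg* classes below are hypothetical stand-ins (the real CDE Config takes constructor arguments), OtherCfg is made up to show branching, and ConfigTree.render is a hypothetical helper.

    ```scala
    // Hypothetical stand-ins for the CDE hierarchy (the real Config takes constructor arguments)
    class Config
    class Cfg0 extends Config
    class Cfg1 extends Cfg0
    class Cfg2 extends Cfg1
    class OtherCfg extends Cfg0 // made up, only to show a branch in the tree

    object ConfigTree {
      // Renders `root`, then recursively every class in `found` whose direct superclass is `root`,
      // indenting two spaces per level; siblings are sorted by name for stable output
      def render(root: Class[_], found: Set[Class[_]], depth: Int = 0): List[String] = {
        val children = found.filter(_.getSuperclass == root).toList.sortBy(_.getName)
        ("  " * depth + root.getSimpleName) :: children.flatMap(render(_, found, depth + 1))
      }

      def main(args: Array[String]): Unit = {
        // In practice `found` would be the set returned by the classpath scanner
        val found: Set[Class[_]] = Set(classOf[Cfg0], classOf[Cfg1], classOf[Cfg2], classOf[OtherCfg])
        render(classOf[Config], found).foreach(println)
        // Config
        //   Cfg0
        //     Cfg1
        //       Cfg2
        //     OtherCfg
      }
    }
    ```

    Because it follows getSuperclass, this only rebuilds class inheritance (traits/interfaces are ignored), and a class whose direct superclass is missing from the set is never reached.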
    

    Alternatively, at compile time you can use Shapeless together with the macro-based type class KnownSubclasses from "Shapeless - How to derive LabelledGeneric for Coproduct". That macro traverses the ASTs of all compilation units in the current compiler run instead of relying on scala-reflect's .knownDirectSubclasses, so it also finds subclasses of classes that are not sealed:

    // libraryDependencies += "com.chuusai" %% "shapeless" % "2.3.10"
    import shapeless.{Coproduct, HList, Poly0, Poly1, Typeable}
    import shapeless.ops.coproduct.ToHList
    import shapeless.ops.hlist.{FillWith, Mapper, ToList}
    
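    // Maps a (placeholder) value of type A to the name of A, using Shapeless's Typeable
    object typeablePoly extends Poly1 {
      implicit def cse[A](implicit typeable: Typeable[A]): Case.Aux[A, String] =
        at(_ => typeable.describe)
    }
    
    // Produces a null placeholder of any type, so FillWith can build an HList
    // with one element per subclass without instantiating anything
    object nullPoly extends Poly0 {
      implicit def cse[A]: Case0[A] = at(null.asInstanceOf[A])
    }
    
    def getSubclasses[A] = new PartiallyAppliedGetSubclasses[A]
    
    class PartiallyAppliedGetSubclasses[A] {
      def apply[C <: Coproduct, L <: HList, L1 <: HList]()(implicit
        // KnownSubclasses comes from the linked answer; it is not part of Shapeless itself
        knownSubclasses: KnownSubclasses.Aux[A, C],
        toHList: ToHList.Aux[C, L],
        mapper: Mapper.Aux[typeablePoly.type, L, L1],
        fillWith: FillWith[nullPoly.type, L],
        toList: ToList[L1, String]
      ): List[String] =
        // HList of null placeholders -> HList of type names -> List[String]
        toList(mapper(fillWith()))
    }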
    
    getSubclasses[Config]()
    // List(Cfg0, Cfg1, Cfg2)
    

    For both options above, it does not matter whether the traits/classes are sealed or whether the classes are case classes.
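    On the runtime side, one way to see why sealed/case status is irrelevant: a classpath scanner reads the superclass and interfaces recorded in each class file, and that same subtype relationship is what java.lang.Class.isAssignableFrom checks. The sketch below applies that check to a hypothetical hierarchy (all names made up) mixing a case class, a plain class with a trait, a final class, and a sealed class. It does not cover the Shapeless macro, and SubtypeFilter.subtypesOf is a hypothetical helper.

    ```scala
    // Hypothetical hierarchy mixing the kinds of types mentioned above
    trait Marker
    class Base
    case class CaseChild(n: Int) extends Base
    class PlainChild extends Base with Marker
    final class GrandChild extends PlainChild
    sealed abstract class SealedChild extends Base
    class Unrelated

    object SubtypeFilter {
      // Keeps every candidate that is a proper (direct or transitive) subtype of `base`
      def subtypesOf(base: Class[_], candidates: Seq[Class[_]]): Seq[String] =
        candidates.filter(c => c != base && base.isAssignableFrom(c)).map(_.getSimpleName)

      def main(args: Array[String]): Unit = {
        val candidates: Seq[Class[_]] = Seq(
          classOf[Base], classOf[CaseChild], classOf[PlainChild], classOf[GrandChild],
          classOf[SealedChild], classOf[Unrelated], classOf[Marker])
        println(subtypesOf(classOf[Base], candidates))   // List(CaseChild, PlainChild, GrandChild, SealedChild)
        println(subtypesOf(classOf[Marker], candidates)) // List(PlainChild, GrandChild)
      }
    }
    ```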