
How to constrain a type parameter to be an `enum` and get the number of cases of the enum it represents, in Scala 3?


I want to create an enum map collection similar to java.util.EnumMap, but in Scala 3.4. Something like this:

import scala.reflect.ClassTag

class EnumMap[K <: scala.reflect.Enum, V](using ctV: ClassTag[V]) {
    val capacity: Int = ??? // How to get the number of cases of the enum represented by K?

    private val backingArray: Array[V] = ctV.newArray(capacity)

    def get(k: K): V = backingArray(k.ordinal)

    def put(k: K, v: V): Unit = backingArray(k.ordinal) = v
}

I ran into two problems:

  1. How do I constrain the type parameter K to accept only enum types and not their individual cases, so that it accepts Color but not Color.red.type?
  2. How do I obtain the number of cases of the enum represented by K without runtime operations? If I am correct, the compiler has all the information it needs.
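To illustrate that the compiler does carry this information, here is a minimal sketch built on Scala 3's `scala.deriving.Mirror` and `scala.compiletime.constValue`. `caseCount` and `Color` are hypothetical names used only for illustration. `Mirror.SumOf[K]` is only synthesized for sum types, so a single case such as `Color.Red.type` is rejected at compile time (note that sealed traits are accepted as well, not only enums), and `constValue[Tuple.Size[...]]` turns the length of the `MirroredElemTypes` tuple into an `Int` without any reflection.

```scala
import scala.compiletime.constValue
import scala.deriving.Mirror

// Hypothetical helper: the number of cases of a sum type, read from the
// MirroredElemTypes tuple that the compiler synthesizes for Mirror.SumOf[K].
// After inlining, constValue yields the tuple size as a literal Int.
inline def caseCount[K](using m: Mirror.SumOf[K]): Int =
  constValue[Tuple.Size[m.MirroredElemTypes]]

// Hypothetical enum used only for illustration.
enum Color:
  case Red, Green, Blue

// caseCount[Color] evaluates to 3. caseCount[Color.Red.type] does not compile:
// a single enum case is not a sum type, so no Mirror.SumOf is available for it.
```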

Solution

  • Something like this?

    import scala.reflect.ClassTag
    
    class EnumMap[K: EnumMap.Ordinals, V: ClassTag] {
      val capacity: Int = summon[EnumMap.Ordinals[K]].size
    
      private val backingArray: Array[V] =
        summon[ClassTag[V]].newArray(capacity)
    
      def get(k: K): V = 
        backingArray(summon[EnumMap.Ordinals[K]].ordinal(k))
    
      def put(k: K, v: V): Unit =
        backingArray(summon[EnumMap.Ordinals[K]].ordinal(k)) = v
    }
    object EnumMap {
    
      trait Ordinals[K] {
    
        def size: Int
    
        def ordinal(key: K): Int
      }
    
      // SumOf makes sure K is a whole enum (or a sealed trait), not a single case...
      inline given [K](using k: scala.deriving.Mirror.SumOf[K]): Ordinals[K] =
        new Ordinals[K] {
    
          // ...and ValueOf ensures that all of its values are
          // parameterless (case Baz(a: Int) would not work)
          def size: Int = scala.compiletime
            .summonAll[Tuple.Map[k.MirroredElemTypes, ValueOf]]
            .productArity
    
          def ordinal(key: K): Int = k.ordinal(key)
        }
    }
    
    
    enum Foo:
      case Bar, Baz
    
    val map = new EnumMap[Foo, String]
    
    map.put(Foo.Bar, "test")
    map.get(Foo.Bar)
    

    (Scastie)

    BTW, this is rather non-idiomatic: get should return Option rather than null (or the array's default value) for a key that was never put, and it would take a bit more work to make this interoperate with other collections (e.g. supporting to via some Factory[(K, V), EnumMap[K, V]]) and with for comprehensions (map, flatMap, foreach). But I understand it's just a minimized example.
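Following the note about `get`, this is a minimal sketch of a variant whose `get` returns `Option[V]`. It keeps the answer's two compile-time checks (`Mirror.SumOf` for the key type, `ValueOf` for every case) but moves them into a factory method. `OptionEnumMap`, `empty` and `remove` are hypothetical names, not an existing library API. Storing `Option[V]` in the array also means no `ClassTag[V]` is required, at the cost of one `Some` allocation per `put`.

```scala
import scala.compiletime.summonAll
import scala.deriving.Mirror

// Hypothetical EnumMap variant: unset keys read back as None instead of the
// array's default element (null for reference types, 0 for Int, ...).
final class OptionEnumMap[K, V](ordinalOf: K => Int, capacity: Int):
  // Array[Option[V]] only needs a ClassTag for Option, not for V.
  private val slots: Array[Option[V]] = Array.fill[Option[V]](capacity)(None)

  def get(k: K): Option[V] = slots(ordinalOf(k))

  def put(k: K, v: V): Unit = slots(ordinalOf(k)) = Some(v)

  def remove(k: K): Unit = slots(ordinalOf(k)) = None

object OptionEnumMap:
  // Mirror.SumOf rejects a single case such as Foo.Bar.type; requiring a
  // ValueOf for every case rejects cases with parameters such as Baz(a: Int).
  // The arity of the summoned tuple is the number of cases.
  inline def empty[K, V](using m: Mirror.SumOf[K]): OptionEnumMap[K, V] =
    val size = summonAll[Tuple.Map[m.MirroredElemTypes, ValueOf]].productArity
    new OptionEnumMap[K, V](k => m.ordinal(k), size)

// Same enum as in the answer, used for the usage example.
enum Foo:
  case Bar, Baz
```

Usage: after `val fooMap = OptionEnumMap.empty[Foo, String]` and `fooMap.put(Foo.Bar, "test")`, `fooMap.get(Foo.Bar)` is `Some("test")` while `fooMap.get(Foo.Baz)` is `None`.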