
Using the same type parameter as both argument type and return type with a match expression


I get errors when compiling the following example code.

abstract class Base
case class A(i: Int)    extends Base
case class B(s: String) extends Base

class Transform {
  def func[T <: Base](arg: T): T = arg match {
    case A(i) => A(i)
    case B(s) => B(s)
  }
}

The errors are:

Example.scala:9: error: type mismatch;
 found   : A
 required: T
    case A(i) => A(i)
                  ^
Example.scala:10: error: type mismatch;
 found   : B
 required: T
    case B(s) => B(s)
                  ^
two errors found

These errors are understandable.
To avoid them, I have to append asInstanceOf[T] to each instantiation, as in A(i).asInstanceOf[T]. However, doing that for every return value is tedious when there are many case patterns.

In addition, I want to use the Transform class as a parent class and override func() to perform a specific operation, as in the code below.

class ExtTransform extends Transform {
  override def func[T <: Base](arg: T): T = arg match {
    case A(i) => A(i + 1)
    case _    => super.func(arg)
  }
}

Is there a better way, or some trick?


Solution

  • To avoid them, I have to append asInstanceOf[T] to each instantiation, as in A(i).asInstanceOf[T]. However, doing that for every return value is tedious when there are many case patterns.

    Well, that part is easy: cast once, at the end of the match, instead of in every branch.

    def func[T <: Base](arg: T): T = (arg match {
      case A(i) => A(i)
      case B(s) => B(s)
    }).asInstanceOf[T] // a single unchecked cast covers every branch
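    For reference, here is a self-contained sketch of the single-cast approach, assuming Scala 2.13 or Scala 3. The class names come from the question; the Main object is only a hypothetical driver. The ExtTransform override uses the same trick, so its A special case needs no per-branch cast either.

```scala
abstract class Base
case class A(i: Int)    extends Base
case class B(s: String) extends Base

class Transform {
  // The match as a whole is typed as Base; one unchecked cast turns it into T.
  def func[T <: Base](arg: T): T = (arg match {
    case A(i) => A(i)
    case B(s) => B(s)
  }).asInstanceOf[T]
}

class ExtTransform extends Transform {
  // Special-case A and delegate everything else to the parent implementation.
  // super.func already returns T, but the A branch does not, so cast once here too.
  override def func[T <: Base](arg: T): T = (arg match {
    case A(i) => A(i + 1)
    case _    => super.func(arg)
  }).asInstanceOf[T]
}

// Hypothetical driver, not part of the original question or answer.
object Main {
  def main(args: Array[String]): Unit = {
    val t = new ExtTransform
    println(t.func(A(1)))   // prints A(2)
    println(t.func(B("x"))) // prints B(x)
    val b: Base = A(5)
    println(t.func(b))      // prints A(6); here T is inferred as Base
  }
}
```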
    

    But note that this design is inherently unsafe: besides Base, A, and B, Base has other subtypes, such as singleton types (a.type), compound types (A with SomeTrait), and Null, and any of them can be chosen as T. For example, with T = a.type the cast claims that a freshly built A(i) is the very object a, which is false. It may be better just to use overloads:

    class Transform {
      def func(arg: Base): Base = arg match {
        case arg: A => func(arg)
        case arg: B => func(arg)
      }
    
      def func(arg: A): A = arg
      def func(arg: B): B = arg
    }
    
    class ExtTransform extends Transform {
      override def func(arg: A): A = A(arg.i + 1)
    }