
Using the same type parameter as both the argument type and the return type with a match expression

I get errors when compiling the following example code.

abstract class Base
case class A(i: Int)    extends Base
case class B(s: String) extends Base

class Transform {
  def func[T <: Base](arg: T): T = arg match {
    case A(i) => A(i)
    case B(s) => B(s)
  }
}

The errors are:

Example.scala:9: error: type mismatch;
 found   : A
 required: T
    case A(i) => A(i)
Example.scala:10: error: type mismatch;
 found   : B
 required: T
    case B(s) => B(s)
two errors found

These errors are reasonable.
To avoid them, I have to append asInstanceOf[T] to each instantiation, e.g. A(i).asInstanceOf[T]. However, doing that for every return value is tedious when there are many case patterns.
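For reference, here is a sketch of that per-branch workaround. It assumes Scala 2.13 or Scala 3 and is written to be pasted into the REPL or run as a scala-cli script; the two calls at the bottom are illustrative only.

```scala
abstract class Base
case class A(i: Int)    extends Base
case class B(s: String) extends Base

class Transform {
  // Compiles, but every branch needs its own cast to T.
  // The cast is unchecked: T is erased to Base at runtime.
  def func[T <: Base](arg: T): T = arg match {
    case A(i) => A(i).asInstanceOf[T]
    case B(s) => B(s).asInstanceOf[T]
  }
}

val a: A = new Transform().func(A(1))
val b: B = new Transform().func(B("x"))
```

With many case patterns, the repeated `.asInstanceOf[T]` is exactly the noise the question wants to get rid of.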

In addition, I want to use the Transform class as a parent class and override func() to perform specific operations, as in the code below.

class ExtTransform extends Transform {
  override def func[T <: Base](arg: T): T = arg match {
    case A(i) => A(i + 1)
    case _    => super.func(arg)
  }
}

Is there a better way, or some trick?


  • To avoid them, I have to append asInstanceOf[T] to each instantiation, e.g. A(i).asInstanceOf[T]. However, doing that for every return value is tedious when there are many case patterns.

    Well, that problem is an easy one: put the cast in one place, around the whole match, instead of in every branch.

    def func[T <: Base](arg: T): T = (arg match {
      case A(i) => A(i)
      case B(s) => B(s)
    }).asInstanceOf[T]
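A runnable sketch of the single-cast approach, applied to both the question's Transform and its ExtTransform override, might look like this (assuming Scala 2.13 or Scala 3, REPL or scala-cli script; the names `t`, `incremented`, `unchanged`, `original`, and `result` are illustrative). The last two lines also show the singleton-type hole mentioned in the caveat about other subtypes of Base: the code compiles, yet the value typed as `original.type` is not `original`.

```scala
abstract class Base
case class A(i: Int)    extends Base
case class B(s: String) extends Base

class Transform {
  // One cast for the whole match instead of one per branch.
  // The match expression is typed as a common supertype of A and B, not as T.
  def func[T <: Base](arg: T): T = (arg match {
    case A(i) => A(i)
    case B(s) => B(s)
  }).asInstanceOf[T]
}

class ExtTransform extends Transform {
  // The override needs the same single cast, because A(i + 1) is an A, not a T.
  override def func[T <: Base](arg: T): T = (arg match {
    case A(i) => A(i + 1)
    case _    => super.func(arg)
  }).asInstanceOf[T]
}

val t = new ExtTransform
val incremented: A = t.func(A(1))   // handled by the override
val unchanged: B   = t.func(B("x")) // falls through to super.func

// T = original.type promises the very same object back, but func builds a
// fresh A; the erased cast cannot detect this.
val original = A(1)
val result: original.type = new Transform().func[original.type](original)
```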

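    But please note that your design is inherently unsafe, because Base has subtypes other than Base, A, and B: singleton types (a.type), compound types (A with SomeTrait), Null, and so on, and any of them can be used as T. For example, with T = a.type the signature promises to return a itself, yet the match builds a new A; asInstanceOf[T] cannot catch this, because T is erased to Base, so the cast only checks that the value is a Base. It may be better just to have overloads: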

    class Transform {
      def func(arg: Base): Base = arg match {
        case arg: A => func(arg)
        case arg: B => func(arg)
      }
      def func(arg: A): A = arg
      def func(arg: B): B = arg
    }
    class ExtTransform extends Transform {
      override def func(arg: A): A = A(arg.i + 1)
    }
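A self-contained sketch of the overload design with a few calls (assuming Scala 2.13 or Scala 3, REPL or scala-cli script; the values `t`, `viaA`, `viaBase`, and `viaB` are illustrative). No casts are needed: an argument statically known to be an A selects the A overload and gets an A back, while a plain Base goes through the dispatcher, which still reaches the overridden A overload through virtual dispatch.

```scala
abstract class Base
case class A(i: Int)    extends Base
case class B(s: String) extends Base

class Transform {
  // Entry point for a statically unknown Base: dispatch to the specific overload.
  // Base is not sealed, so any other subclass would cause a MatchError here.
  def func(arg: Base): Base = arg match {
    case arg: A => func(arg) // resolves to func(arg: A)
    case arg: B => func(arg) // resolves to func(arg: B)
  }
  def func(arg: A): A = arg
  def func(arg: B): B = arg
}

class ExtTransform extends Transform {
  // Only the A case changes; func(B) and the Base dispatcher are inherited.
  override def func(arg: A): A = A(arg.i + 1)
}

val t = new ExtTransform
val viaA: A       = t.func(A(1))         // static type A: the A overload
val viaBase: Base = t.func(A(1): Base)   // dispatcher, then the overridden A overload
val viaB: B       = t.func(B("x"))       // inherited B overload
```

The trade-off is that the return type is only as precise as the static argument type: passing a Base gives back a Base, which is honest, unlike the generic signature.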