
Scala covariance assignment to super type


I am trying to understand Scala covariance and contravariance. Maybe I am confusing two different concepts. I am working on the code below:

sealed trait Algorithm[+T <: Model, P <: Model, R <: AnyVal] {
  def name: String
  def train(trainingData: DenseMatrix[Double]): T
  def predict(row: DenseVector[R], mlModel: P): R
}

Then I have two algorithm types declared as:

case class LibLinear() extends Algorithm[Active_Linear, Active_Linear, Double] {
  override val name = "libLinear"
  override def train(trainingData: DenseMatrix[Double]): Active_Linear = {
    ........
  }
  override def predict(row: DenseVector[Double], model: Active_Linear): Double = {
    ..........
  }
}


case class SVM() extends Algorithm[Volume_SVM, Volume_SVM, Double] {
  override val name = "libSVM"
  override def train(trainingData: DenseMatrix[Double]): Volume_SVM = {
    ..........
  }
  override def predict(row: DenseVector[Double], model: Volume_SVM): Double = {
    ...........
  }
}

Here, both Active_Linear and Volume_SVM are subtypes of Model.

Now I cannot do this:

val algorithm: Algorithm[Model, Model, Double] =  SVM()

SVM is a subtype of Algorithm, Volume_SVM is a subtype of Model, and I declared Algorithm with covariant and contravariant annotations, so I expected this assignment to compile.
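For reference, a minimal sketch of how Scala checks variance may help separate the two concepts. Model and Volume_SVM below are hypothetical stand-ins for the question's classes, and Producer, Consumer, Box and Pair are illustrative names that are not part of the original code. The commented-out lines are the ones the compiler rejects; note that the check is made separately for each type parameter.

```scala
// Illustration of the three variance modes, one type parameter at a time.
// Model and Volume_SVM are hypothetical stand-ins, not the question's real classes.
object VarianceBasics {
  trait Model
  final case class Volume_SVM() extends Model

  final case class Producer[+A](value: A)           // covariant: A only comes out
  final class Consumer[-A](f: A => String) {        // contravariant: A only goes in
    def consume(a: A): String = f(a)
  }
  final case class Box[A](value: A)                 // invariant: no annotation

  val svmProducer: Producer[Volume_SVM] = Producer(Volume_SVM())
  val modelProducer: Producer[Model] = svmProducer  // OK: Producer[Volume_SVM] <: Producer[Model]

  val modelConsumer: Consumer[Model] = new Consumer[Model](m => s"consumed $m")
  val svmConsumer: Consumer[Volume_SVM] = modelConsumer // OK: Consumer[Model] <: Consumer[Volume_SVM]

  val svmBox: Box[Volume_SVM] = Box(Volume_SVM())
  // val modelBox: Box[Model] = svmBox              // does not compile: Box is invariant

  // Pair is covariant in A but invariant in B, so only the first argument may widen.
  final case class Pair[+A, B](a: A, b: B)
  val svmPair: Pair[Volume_SVM, Volume_SVM] = Pair(Volume_SVM(), Volume_SVM())
  val mixedPair: Pair[Model, Volume_SVM] = svmPair  // OK: only A is widened
  // val modelPair: Pair[Model, Model] = svmPair    // does not compile: B is invariant
}
```

The last pair of lines has the same shape as the failing assignment: a single covariant parameter does not make the whole type covariant.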


Solution

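  • That's because Algorithm is covariant only in T. P and R are declared invariant, so an Algorithm[Volume_SVM, Volume_SVM, Double] can only widen its first type argument. Strictly, this assignment only requires a change to P, since R is Double on both sides and invariance only demands equal type arguments. You have to precede each parameter you want to vary with + or - as needed, and then modify the code accordingly. Based on your final assignment, I've made some assumptions, and this is what I've come up with: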

    sealed trait Algorithm[+T <: Model, +P <: Model, +R <: AnyVal] {
      def name: String
      def train(trainingData: DenseMatrix[Double]): T
      def predict[U >: R, V >: P](row: DenseVector[U], mlModel: V): U
    }
    
    case class LibLinear() extends Algorithm[Active_Linear, Active_Linear, Double] {
      override val name = "libLinear"
      override def train(trainingData: DenseMatrix[Double]): Active_Linear = {
        ...
      }
    
      override def predict[U >: Double, V >: Active_Linear](row: DenseVector[U], model: V): U = {
        ...
      }
    }
    
    
    case class SVM() extends Algorithm[Volume_SVM, Volume_SVM, Double] {
      override val name = "libSVM"
      override def train(trainingData: DenseMatrix[Double]): Volume_SVM = {
          ...
      }
      override def predict[U >: Double, V >: Volume_SVM](row: DenseVector[U], model: V): U = {
          ...
      }
    }
    

    With these changes, your last assignment compiles.
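To try this without Breeze, here is a self-contained sketch of the answer's trait under some assumptions: Vector stands in for DenseVector, Model, Active_Linear and Volume_SVM are hypothetical stand-in case classes with a weights field, and the train/predict bodies (column means and a dot product) are toy placeholders for the elided implementations. The lower bounds V >: P and U >: R are needed because a covariant type parameter cannot appear directly as a method parameter type; as a lower bound of a method type parameter it is allowed. A side effect is that V is only known to be some supertype of the concrete model, so the compiler accepts any model argument and each implementation has to check the model type at runtime.

```scala
// Sketch of the covariant Algorithm trait. Assumptions: Vector replaces Breeze's
// DenseVector/DenseMatrix; the model classes and training logic are hypothetical toys.
object CovariantAlgorithms {
  sealed trait Model
  final case class Active_Linear(weights: Vector[Double]) extends Model
  final case class Volume_SVM(weights: Vector[Double]) extends Model

  sealed trait Algorithm[+T <: Model, +P <: Model, +R <: AnyVal] {
    def name: String
    def train(trainingData: Vector[Vector[Double]]): T
    // Lower-bounded method type parameters keep covariant P and R out of parameter positions.
    def predict[U >: R, V >: P](row: Vector[U], mlModel: V): U
  }

  // Hypothetical toy "training": the mean of each column becomes a weight.
  private def columnMeans(data: Vector[Vector[Double]]): Vector[Double] =
    if (data.isEmpty) Vector.empty
    else data.transpose.map(col => col.sum / col.size)

  // Dot product over the Double entries of a row (U may be wider than Double).
  private def dot(row: Vector[Any], weights: Vector[Double]): Double =
    row.collect { case d: Double => d }.zip(weights).map { case (x, w) => x * w }.sum

  final case class LibLinear() extends Algorithm[Active_Linear, Active_Linear, Double] {
    override val name = "libLinear"
    override def train(trainingData: Vector[Vector[Double]]): Active_Linear =
      Active_Linear(columnMeans(trainingData))
    override def predict[U >: Double, V >: Active_Linear](row: Vector[U], mlModel: V): U =
      mlModel match {
        // V is only a supertype of Active_Linear, so the concrete type is checked at runtime.
        case m: Active_Linear => dot(row, m.weights) // a Double is a valid U because U >: Double
        case other => throw new IllegalArgumentException(s"$name cannot use model $other")
      }
  }

  final case class SVM() extends Algorithm[Volume_SVM, Volume_SVM, Double] {
    override val name = "libSVM"
    override def train(trainingData: Vector[Vector[Double]]): Volume_SVM =
      Volume_SVM(columnMeans(trainingData))
    override def predict[U >: Double, V >: Volume_SVM](row: Vector[U], mlModel: V): U =
      mlModel match {
        case m: Volume_SVM => dot(row, m.weights)
        case other => throw new IllegalArgumentException(s"$name cannot use model $other")
      }
  }

  // The assignment from the question compiles with all parameters covariant.
  val algorithm: Algorithm[Model, Model, Double] = SVM()
  val model: Volume_SVM = SVM().train(Vector(Vector(1.0, 2.0), Vector(3.0, 4.0)))
  val prediction: Double = algorithm.predict(Vector(1.0, 1.0), model) // weights (2.0, 3.0), so 5.0
}
```

Passing an Active_Linear to the SVM through `algorithm` also compiles but fails at runtime. Keeping P invariant or contravariant avoids that runtime check, but then the assignment to Algorithm[Model, Model, Double] no longer compiles.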