Tags: generics, scala, numerical

Scala: What's the best way to do numeric operations in generic classes?


In Scala, I'd like to write generic classes that use operators like >, / and *, but I don't see how to constrain T so that this works.

I looked into constraining T with Ordered[T], but that doesn't seem to work, since only the RichXXX wrappers (e.g. RichInt) extend it, not Int itself. I also saw Numeric[T]; is this only available in Scala 2.8?

Here is a specific example:

class MaxOfList[T](list: List[T] ) {
  def max = {
    val seed: Option[T] = None

    list
      .map( t => Some(t))
      // Get the max      
      .foldLeft(seed)((i,m) => getMax(i,m) )
  }

  private def getMax(x: Option[T], y: Option[T]) = {
    if ( x.isDefined && y.isDefined )
      if ( x > y ) x else y
    else if ( x.isDefined )
      x
    else
      y
  }
}

This class won't compile, because the compiler can't assume that an arbitrary T supports > (and many types don't).

Thoughts?

For now I've used a mixin trait to get around this:

/** Defines a trait that can get the max of two generic values
 */
trait MaxFunction[T] {
  def getMax(x:T, y:T): T
}

/** An implementation of MaxFunction for Int
 */
trait IntMaxFunction extends MaxFunction[Int] {
  def getMax(x: Int, y: Int) = x.max(y)
} 

/** An implementation of MaxFunction for Double
 */
trait DoubleMaxFunction extends MaxFunction[Double] {
  def getMax(x: Double, y: Double) = x.max(y)
} 

If the original class is altered to rely on MaxFunction[T], the appropriate trait can be mixed in at instantiation time.
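The question doesn't show the altered class, so here is a minimal sketch of what it might look like, assuming MaxOfList is made abstract and extends MaxFunction[T] (the reworked class is hypothetical; the traits are the ones above, with explicit return types). Written as a Scala 2.13 script:

```scala
// Traits from the question, repeated so this sketch compiles on its own.
trait MaxFunction[T] {
  def getMax(x: T, y: T): T
}

trait IntMaxFunction extends MaxFunction[Int] {
  def getMax(x: Int, y: Int): Int = x.max(y)
}

trait DoubleMaxFunction extends MaxFunction[Double] {
  def getMax(x: Double, y: Double): Double = x.max(y)
}

// Hypothetical reworked class: abstract, so a concrete getMax
// has to be supplied by whichever trait is mixed in.
abstract class MaxOfList[T](list: List[T]) extends MaxFunction[T] {
  // None for an empty list, otherwise the pairwise maximum.
  def max: Option[T] = list.reduceLeftOption((a: T, b: T) => getMax(a, b))
}

val intMax = new MaxOfList[Int](List(3, 7, 2)) with IntMaxFunction
val doubleMax = new MaxOfList[Double](List(1.5, -0.5)) with DoubleMaxFunction

println(intMax.max)    // Some(7)
println(doubleMax.max) // Some(1.5)
```

The drawback of this design is that every element type needs its own hand-written trait, and forgetting the mixin is only caught because the class is abstract.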

P.S. Mitch, inspired by your rewrite of getMax, here is another:

  private def getMax(xOption: Option[T], yOption: Option[T]): Option[T] = (xOption,yOption) match {
    case (Some(x),Some(y)) => if ( x > y ) xOption else yOption
    case (Some(x), _) => xOption
    case _ => yOption
  }

Solution

  • You can use View Bounds.

    In short, def foo[T <% U](t: T) is a function that takes any T that either is, or can be implicitly converted to, a U. Since an Int can be implicitly converted to a RichInt (which extends Ordered[Int] and so provides >), this is an excellent example of usage.

    class MaxOfList[T <% Ordered[T]](list: List[T] ) {
      def max = {
        val seed: Option[T] = None
        list.foldLeft(seed)(getMax(_,_))
      }
    
      private def getMax(xOption: Option[T], y: T) = (xOption, y) match {
        case (Some(x), y) if ( x > y ) => xOption
        case (_, y) => Some(y)
      }
    }
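A view bound is shorthand for an extra implicit parameter that holds the conversion. As a sketch, assuming Scala 2.13 (where view bounds still compile but are deprecated in favour of this explicit form), here is the same class with the conversion spelled out; the parameter name toOrdered is arbitrary:

```scala
// [T <% Ordered[T]] written out by hand: an implicit conversion T => Ordered[T]
// is requested at construction time and used wherever > is applied to a T.
class MaxOfList[T](list: List[T])(implicit toOrdered: T => Ordered[T]) {
  def max: Option[T] = {
    val seed: Option[T] = None
    list.foldLeft(seed)(getMax(_, _))
  }

  private def getMax(xOption: Option[T], y: T): Option[T] = (xOption, y) match {
    // x > y compiles because toOrdered converts x to an Ordered[T].
    case (Some(x), y) if x > y => xOption
    case (_, y)                => Some(y)
  }
}

println(new MaxOfList(List(3, 9, 4)).max) // Some(9)
```

For Int, the compiler fills in toOrdered with Predef's intWrapper, whose RichInt result is an Ordered[Int].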
    

    PS - I rewrote your getMax(...) method to compare the values instead of the options themselves, and used pattern matching instead of isDefined.

    PPS - Scala 2.8 will have a Numeric trait that may be of use. http://article.gmane.org/gmane.comp.lang.scala/16608
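Numeric did ship with Scala 2.8 in scala.math. A minimal sketch, assuming Scala 2.13, of how it covers the arithmetic operators from the question: a context bound plus the Numeric.Implicits and Ordering.Implicits imports enables infix +, * and >, while division needs the narrower Fractional type class. The function names are hypothetical:

```scala
import scala.math.Numeric.Implicits._
import scala.math.Ordering.Implicits._

// Works for any T with a Numeric instance (Int, Long, Double, BigInt, ...).
def sumOfSquares[T: Numeric](xs: List[T]): T =
  xs.foldLeft(implicitly[Numeric[T]].zero)((acc, x) => acc + x * x)

// Division is not part of Numeric; Fractional (Double, Float, BigDecimal) adds div.
def mean[T: Fractional](xs: List[T]): Option[T] =
  if (xs.isEmpty) None
  else {
    val frac = implicitly[Fractional[T]]
    Some(frac.div(xs.sum, frac.fromInt(xs.size)))
  }

// Numeric extends Ordering, so > is available through Ordering.Implicits.
def largerOf[T: Numeric](a: T, b: T): T = if (a > b) a else b

println(sumOfSquares(List(1, 2, 3)))   // 14
println(mean(List(1.0, 2.0, 6.0)))     // Some(3.0)
```

Integer-style division (quot, rem) lives in a third type class, Integral, which Int, Long and BigInt provide.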


    Addendum

    Just for giggles, here's the super-compact version that eliminates the getMax method altogether:

    class MaxOfList[T <% Ordered[T]](list: List[T] ) {
      def max = list.foldLeft(None: Option[T]) {
          case (Some(x), y) if ( x > y ) => Some(x)
          case (_, y) => Some(y)
      }
    }
    

    Yet Another Addendum

    This version should be more efficient for large lists, since it avoids creating a Some for each element:

    class MaxOfList[T <% Ordered[T]](list: List[T] ) {
      def max = {
        if (list.isEmpty) None
        else Some(list.reduceLeft((a,b) => if (a > b) a else b))
      }
    }
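The standard library's reduceLeftOption folds the isEmpty check into the reduction and returns None for an empty collection. The sketch below also swaps the view bound for an implicit Ordering[T], which is a substitution rather than the answer's own code (Scala 2.13 assumed):

```scala
// Variant of the reduceLeft version: Ordering[T] instead of a view bound,
// reduceLeftOption instead of an explicit isEmpty test.
class MaxOfList[T](list: List[T])(implicit ord: Ordering[T]) {
  // None for an empty list; otherwise keeps the larger of each pair,
  // preferring the later element on ties (same as `if (a > b) a else b`).
  def max: Option[T] =
    list.reduceLeftOption((a: T, b: T) => if (ord.gt(a, b)) a else b)
}

println(new MaxOfList(List(4, 11, 6)).max) // Some(11)
```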
    

    Last One, I Promise!

    At this point, you can just ditch the class and use a function:

      def max[T <% Ordered[T]](i: Iterable[T]) = {
        if (i.isEmpty) None
        else Some(i.reduceLeft((a,b) => if (a > b) a else b))
      }
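For comparison, a sketch of the same function with a context bound on Ordering[T] instead of a view bound (the names maxOf and maxOfOption are hypothetical). Iterable.max takes an implicit Ordering and throws on an empty collection, hence the guard; on Scala 2.13 and later, maxOption performs the empty check itself:

```scala
// Context-bound version of the answer's function.
def maxOf[T: Ordering](i: Iterable[T]): Option[T] =
  if (i.isEmpty) None else Some(i.max)

// Scala 2.13+: the standard library already returns an Option.
def maxOfOption[T: Ordering](i: Iterable[T]): Option[T] = i.maxOption

println(maxOf(List(2, 9, 5)))       // Some(9)
println(maxOfOption(Vector.empty[Int])) // None
```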