scala, lambda, type-parameter

Specify arithmetic operation before backing type is known?


I'm trying to write an interface through which users should be able to pass arbitrary arithmetic functions independent of numeric type, and then helper code would bind them to appropriate backing Numeric used by the system. (The need for this is complicated, but it's a DSL design thing and the numeric type which will ultimately be used is not known at compile time, though it can be restricted to a small set.)

So the goal is to get something like:

case class NumericOp(
  intVersion : (Int, Int) => Int,
  floatVersion : (Float, Float) => Float
)
def wrap(f : ???) : NumericOp = NumericOp(f[Int], f[Float])

//use site
wrap( (x, y) => x + y)

//use in engine
def applyArithmetic(arg1 : Any, arg2 : Any, op : NumericOp) : Any = {
  (arg1, arg2) match {
    case (a : Int, b : Int) => op.intVersion(a, b)
    case (a : Float, b : Float) => op.floatVersion(a, b)
    case _ => throw new Exception("Invalid types") //shouldn't actually be reachable
  }
}

The above - which obviously doesn't compile - would have wrap take a generic lambda with an un-applied type parameter, and then apply the backing types to it to create real functions.

I understand why this doesn't work - type parameter application would have to happen before runtime, i.e., before the lambda even exists. That said, the core of what I'm trying to do - describe arithmetic logic as data, then create specific implementations later once a type is available - seems like it should be possible. (My reasoning is that I could describe a complete DSL of all operations under Numeric, have users specify their math in that DSL, and then create implementations for Int/Float/Whatever later.)

Is there a way to make this work?

I very much want to keep the use site concise, and more importantly, abstracted from the backing types.

To give some more explicit detail and requirements:

  1. First, the user specifies their arithmetic operation in some type-agnostic way. It's fine for them to know they're dealing in some Numeric type, but they should not have to know the specific type, or even the number of possible types. I don't know what format this input could take, thus the ??? in the code.
  2. Their specification - call this f - is passed to a handler function (wrap, above). This function can know the set of possible backing types, and produces an intermediate structure to store any needed type-specific data (NumericOp, above)
  3. The intermediate structure gets used in an interpreter that operates over dynamic Any values (type checking is done on the input DSL code, and casting is done when needed). This is done in applyArithmetic, above.

To take a more temporal view of it:

  1. The user specifies some arithmetic logic, passing it into the wrap function
  2. The system maintainer re-arranges the set of possible backing numeric types, adding/removing them without telling the user
  3. The code is compiled
  4. During runtime, the specific values are passed to an applyArithmetic function as Anys along with whatever intermediate type wrap output.

Solution

  • If you use something like Numeric[Int], the compiler resolves it at compile time to a concrete instance and uses that, so it would be hardcoded, as you suspect.

    def func(arg1: Int, arg2: Int): Int = {
      Numeric[Int].plus(arg1, arg2)
    }
    

    The same would not be true though if you did:

    def func[T: Numeric](arg1: T, arg2: T): T = {
      Numeric[T].plus(arg1, arg2)
    }
    

    Why?

    def func[T: Numeric](arg1: T, arg2: T): T = ...
    

    is syntactic sugar for

    def func[T](arg1: T, arg2: T)(implicit generatedName: Numeric[T]): T = ...
    

    Therefore

    def numericOp[T: Numeric](arg1: T, arg2: T): T = {
      Numeric[T].plus(arg1, arg2)
    }
    
    numericOp(1, 2)
    numericOp(1L, 2L)
    numericOp(1.0, 2.0)
    

    is, once the compiler has resolved the implicits, equivalent to

    def numericOp[T](arg1: T, arg2: T)(implicit genName: Numeric[T]): T = {
      genName.plus(arg1, arg2)
    }
    
    // Integral and Fractional are subtypes of Numeric; the compiler
    // finds these instances in the Numeric companion object
    numericOp(1, 2)(Numeric.IntIsIntegral)
    numericOp(1L, 2L)(Numeric.LongIsIntegral)
    numericOp(1.0, 2.0)(Numeric.DoubleIsFractional)
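
    Because the instance is an ordinary argument, an interpreter that only sees `Any` values can choose it at runtime and pass it explicitly. A minimal sketch, assuming the engine supports just `Int` and `Double`; `ExplicitInstances` and `dispatch` are hypothetical names:

    ```scala
    object ExplicitInstances {
      // Generic logic: written once, knows only that T has a Numeric instance
      def numericOp[T](arg1: T, arg2: T)(implicit num: Numeric[T]): T =
        num.plus(arg1, arg2)

      // Hypothetical engine step: the runtime types of the Any values
      // decide which instance is passed explicitly
      def dispatch(arg1: Any, arg2: Any): Any = (arg1, arg2) match {
        case (a: Int, b: Int)       => numericOp(a, b)(Numeric.IntIsIntegral)
        case (a: Double, b: Double) => numericOp(a, b)(Numeric.DoubleIsFractional)
        case _ => throw new IllegalArgumentException(s"Unsupported: $arg1, $arg2")
      }
    }
    ```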
    

    If you consistently use implicits with a type class, often together with extension methods to make the code more readable:

    // Scala 2 (an implicit value class must be defined inside an object)
    implicit class NumericOps[T](private val value: T) extends AnyVal {
      def +(another: T)(implicit num: Numeric[T]): T = num.plus(value, another)
    }
    
    // Scala 3
    extension [T](value: T)
      def +(another: T)(using num: Numeric[T]): T = num.plus(value, another)
    

    you get a parameterized interface whose implementation is resolved per type: the choice of T is deferred, and the implicit is resolved to a value once that type is known. This style is called tagless final.
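
    The standard library already provides such an extension: `import scala.math.Numeric.Implicits._` adds infix `+`, `-` and `*` for any `T` that has a `Numeric[T]` instance. A small sketch of user logic written in this style; `TaglessStyle` and `userOp` are hypothetical names:

    ```scala
    import scala.math.Numeric.Implicits._

    object TaglessStyle {
      // The user writes the logic once, against "some Numeric T"
      def userOp[T: Numeric](x: T, y: T): T = x * x + y

      // The concrete type is only chosen at each call site
      val asInt: Int       = userOp(3, 4)
      val asDouble: Double = userOp(1.5, 0.25)
    }
    ```

    The user code mentions only `Numeric`, never a concrete type; the instance is resolved separately at each call site.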

    Alternative is to use free algebra:

    sealed trait NumericExpr[T] {
    
      import NumericExpr._
    
      def + (another: NumericExpr[T]): NumericExpr[T] =
        Plus(this, another)
      def - (another: NumericExpr[T]): NumericExpr[T] =
        Minus(this, another)
    
      // "replay" the recorded operations using the supplied implementations
      def run(plus: (T, T) => T)(minus: (T, T) => T): T = this match {
        case Value(t) => t
        case Plus(a, b) => plus(a.run(plus)(minus), b.run(plus)(minus))
        case Minus(a, b) => minus(a.run(plus)(minus), b.run(plus)(minus))
      }
    }
    object NumericExpr {
      case class Value[T](number: T) extends NumericExpr[T]
      case class Plus[T](a: NumericExpr[T], b: NumericExpr[T]) extends NumericExpr[T]
      case class Minus[T](a: NumericExpr[T], b: NumericExpr[T]) extends NumericExpr[T]
    
      def wrap[T](t: T): NumericExpr[T] = Value(t)
    }
    

    You use it by wrapping all values in this type and letting it "record" operations without knowing how they are implemented; as the last step you "replay" them, supplying the implementation.

    // generic code works with the free algebra
    def numericOp[T](arg1: NumericExpr[T], arg2: NumericExpr[T]): NumericExpr[T] = {
      arg1 + arg2
    }
    
    // specific code lifts values into the free algebra and later provides the implementation
    numericOp(NumericExpr.wrap(1), NumericExpr.wrap(2)).run(_ + _)(_ - _)
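
    In `NumericExpr`, `T` is still fixed when the values are wrapped. For the question, where the user's logic has to exist before any backing type is chosen, the same record-and-replay idea can be applied to a tree over placeholders for the two operands, interpreted later with whichever `Numeric` instance the engine picks. A sketch under that assumption; `Expr`, `Arg1`, `Arg2`, `Lit` and `eval` are hypothetical names, and only `+`, `-` and `*` are modeled:

    ```scala
    // Type-agnostic description of a binary arithmetic operation
    sealed trait Expr {
      def +(that: Expr): Expr = Add(this, that)
      def -(that: Expr): Expr = Sub(this, that)
      def *(that: Expr): Expr = Mul(this, that)
    }
    case object Arg1 extends Expr             // first operand, type not yet known
    case object Arg2 extends Expr             // second operand, type not yet known
    final case class Lit(n: Int) extends Expr // literal, converted with fromInt
    final case class Add(a: Expr, b: Expr) extends Expr
    final case class Sub(a: Expr, b: Expr) extends Expr
    final case class Mul(a: Expr, b: Expr) extends Expr

    object Expr {
      // Replay the recorded operations for a concrete T
      def eval[T](e: Expr, x: T, y: T)(implicit num: Numeric[T]): T = e match {
        case Arg1      => x
        case Arg2      => y
        case Lit(n)    => num.fromInt(n)
        case Add(a, b) => num.plus(eval(a, x, y), eval(b, x, y))
        case Sub(a, b) => num.minus(eval(a, x, y), eval(b, x, y))
        case Mul(a, b) => num.times(eval(a, x, y), eval(b, x, y))
      }
    }

    // User code: no backing type is mentioned anywhere
    object UserLogic {
      val op: Expr = Arg1 * Arg2 + Lit(1)
    }
    ```

    Supporting another backing type then only requires the engine to call `eval` with that type's instance; `UserLogic.op` does not change.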
    

    EDIT:

    If you want a tagless-final-like approach, there are two options, depending on the Scala version:

    Scala 2:

    // Scala 2 has polymorphic methods but no polymorphic function values,
    // and function values cannot take implicit parameters,
    // so the generic method has to be attached to an interface
    trait numericOps {
      def apply[T: Numeric](t1: T, t2: T): T
    }
    
    val plus: numericOps = new numericOps {
      def apply[T: Numeric](t1: T, t2: T): T = Numeric[T].plus(t1, t2)
    }
    
    plus(1, 2)
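
    Applied to the question, the trait can serve as the type of `f` in `wrap`, which instantiates it once per backing type. A sketch assuming the `Int`/`Float` pair from the question; `Wiring` is a hypothetical wrapper, and `IllegalArgumentException` replaces the bare `Exception`:

    ```scala
    // Polymorphic "function" as a trait with a generic apply method
    trait numericOps {
      def apply[T: Numeric](t1: T, t2: T): T
    }

    // Intermediate structure from the question, one field per backing type
    final case class NumericOp(
      intVersion: (Int, Int) => Int,
      floatVersion: (Float, Float) => Float
    )

    object Wiring {
      // wrap knows the backing types; the user does not
      def wrap(f: numericOps): NumericOp =
        NumericOp((a, b) => f(a, b), (a, b) => f(a, b))

      def applyArithmetic(arg1: Any, arg2: Any, op: NumericOp): Any =
        (arg1, arg2) match {
          case (a: Int, b: Int)     => op.intVersion(a, b)
          case (a: Float, b: Float) => op.floatVersion(a, b)
          case _ => throw new IllegalArgumentException("Invalid types")
        }

      // Use site: generic, no backing type mentioned
      val plusOp: NumericOp = wrap(new numericOps {
        def apply[T: Numeric](t1: T, t2: T): T = implicitly[Numeric[T]].plus(t1, t2)
      })
    }
    ```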
    

    Scala 3:

    // Scala 3 has both polymorphic function types
    // and context functions
    type numericOps = [T] => (T, T) => Numeric[T] ?=> T
    
    val plus: numericOps = [T] => (t1: T, t2: T) => (num: Numeric[T]) ?=> num.plus(t1, t2)
    
    plus(1, 2)
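
    The same wiring in Scala 3, where `wrap` can take the polymorphic function value directly and apply it at each backing type. A sketch assuming the `Int`/`Float` pair from the question (Scala 3 only, since it relies on polymorphic function types and context functions):

    ```scala
    type numericOps = [T] => (T, T) => Numeric[T] ?=> T

    // Intermediate structure from the question, one field per backing type
    final case class NumericOp(
      intVersion: (Int, Int) => Int,
      floatVersion: (Float, Float) => Float
    )

    // Instantiate the polymorphic function once per backing type; because an
    // Int (or Float) is expected, the Numeric instance is resolved there
    def wrap(f: numericOps): NumericOp =
      NumericOp((a, b) => f[Int](a, b), (a, b) => f[Float](a, b))

    // Use site: generic, no backing type mentioned
    val plusOp: NumericOp =
      wrap([T] => (t1: T, t2: T) => (num: Numeric[T]) ?=> num.plus(t1, t2))
    ```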
    
