
Scala function composition totalFn(partialFn(totalFn(x)))


I was trying to compose three functions with only the middle one being a PartialFunction. I would expect the resulting type to be PartialFunction as well.

Example:

val mod10: Int => Int = _ % 10
val inverse: PartialFunction[Int, Double] = { case n if n != 0 => 1.0 / n }
val triple: Double => Double = _ * 3

val calc: Int => Double = mod10 andThen inverse andThen triple

However, calc is not actually defined for every Int: it throws a MatchError for every number divisible by 10.
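A minimal sketch reproducing the failure, assuming the three definitions above (the wrapper object Example1 and the use of scala.util.Try are only for illustration):

```scala
import scala.util.{Failure, Try}

object Example1 {
  val mod10: Int => Int = _ % 10
  val inverse: PartialFunction[Int, Double] = { case n if n != 0 => 1.0 / n }
  val triple: Double => Double = _ * 3

  // Statically a total Int => Double, even though inverse is partial
  val calc: Int => Double = mod10 andThen inverse andThen triple

  // 10 % 10 == 0 and inverse has no case for 0, so this application fails
  val atTen: Try[Double] = Try(calc(10))

  def main(args: Array[String]): Unit = {
    println(calc(4)) // 3 * (1.0 / 4)
    atTen match {
      case Failure(_: MatchError) => println("calc(10) threw a MatchError")
      case other                  => println(s"unexpected result: $other")
    }
  }
}
```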

What is the reason for returning a total function when at least one of the functions in the composition is partial?

Another example where composition of partial functions results in another partial function with incorrect domain conditions:

val inverse: PartialFunction[Double, Double] = { case n if n != 0 => 1.0 / n }
val arcSin: PartialFunction[Double, Double] = { 
   case n if math.abs(n) <= 1 => math.asin(n)
}

val calc: PartialFunction[Double, Double] = inverse andThen arcSin

I would expect the domain of calc to be (-Infinity, -1] union [1, Infinity), but calling calc.lift(0.5) throws a MatchError instead of returning None, because 0.5 is within the first function's domain.
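The expected domain follows from requiring |1/n| <= 1, i.e. |n| >= 1 (which already excludes n = 0). A small sketch, assuming the two functions above, that evaluates both guards by hand and compares the result with that closed form on a few sample points; the helper names composedDefinedAt and expectedDomain are hypothetical:

```scala
object DomainCheck {
  val inverse: PartialFunction[Double, Double] = { case n if n != 0 => 1.0 / n }
  val arcSin: PartialFunction[Double, Double] = {
    case n if math.abs(n) <= 1 => math.asin(n)
  }

  // Hypothetical helper: true iff inverse accepts n AND arcSin accepts inverse(n)
  def composedDefinedAt(n: Double): Boolean =
    inverse.isDefinedAt(n) && arcSin.isDefinedAt(inverse(n))

  // Hypothetical helper: the closed form (-Infinity, -1] union [1, Infinity)
  def expectedDomain(n: Double): Boolean = math.abs(n) >= 1

  val samples: Seq[Double] = Seq(-10.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0)

  def main(args: Array[String]): Unit =
    samples.foreach { n =>
      println(s"$n composed=${composedDefinedAt(n)} expected=${expectedDomain(n)}")
    }
}
```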

Thanks, Norbert


Solution

  • Example 1: What is the reason for returning a total function when at least one of the functions in the composition is partial?

    It's because the first function in your first example is a total function (Function1) and its andThen method returns a Function1 regardless of whether the second function is total or partial:

    def andThen[A](g: (R) => A): (T1) => A
    

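    Function1.andThen only sees its argument as an R => A, so without a separate overload it has no static knowledge that g is partial. My guess is that the Scala designers were content with the more general return type: since PartialFunction is a subclass of Function1, a Function1 result is always type-correct, and users can write more specialized composition themselves where they need it.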
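The effect is visible at runtime: the composed value is a plain Function1, not a PartialFunction. A minimal sketch, assuming the first example's definitions; the guarded PartialFunction literal calcPf is a hypothetical hand-written workaround, not a library method:

```scala
object ReturnTypeDemo {
  val mod10: Int => Int = _ % 10
  val inverse: PartialFunction[Int, Double] = { case n if n != 0 => 1.0 / n }
  val triple: Double => Double = _ * 3

  // Function1.andThen builds a plain function; the partiality of inverse is lost
  val calc: Int => Double = mod10 andThen inverse andThen triple
  val isPartial: Boolean = calc.isInstanceOf[PartialFunction[_, _]]

  // Hypothetical workaround: restate inverse's domain check on mod10's output
  val calcPf: PartialFunction[Int, Double] = {
    case x if inverse.isDefinedAt(mod10(x)) => triple(inverse(mod10(x)))
  }

  def main(args: Array[String]): Unit = {
    println(isPartial)       // the composed value is not a PartialFunction
    println(calcPf.lift(20)) // 20 % 10 == 0, outside inverse's domain
    println(calcPf.lift(2))  // 3 * (1.0 / 2)
  }
}
```

Writing the guard by hand works, but it has to be repeated for every composition, which is what the implicit-class approach further down avoids.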

  • Example 2: calling calc.lift(0.5) throws a MatchError instead of returning None

    According to the PartialFunction API doc, PartialFunction's andThen accepts any function k: B => C and returns a partial function with the same domain as the original partial function; a partial function passed as k is treated as if it were total:

    def andThen[C](k: (B) => C): PartialFunction[A, C]
    

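    Thus, the composed function's isDefinedAt checks only inverse, disregarding the fact that inverse(0.5) (i.e. 2.0) is outside the domain of the second partial function arcSin; lift(0.5) then applies arcSin to 2.0, which throws the MatchError. (This describes Scala 2.12 and earlier. Scala 2.13 added an overload andThen[C](k: PartialFunction[B, C]): PartialFunction[A, C] whose domain also takes k into account, so on 2.13 calc.lift(0.5) returns None.)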
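To make those semantics concrete, here is a hypothetical class, FirstDomainAndThen (a sketch, not the standard library's actual implementation), that behaves as the API doc quoted above describes: its isDefinedAt consults only the first partial function, so lift cannot protect you from the second one.

```scala
import scala.util.Try

object AndThenSemantics {
  // Hypothetical class: same domain as pf, maps x to k(pf(x))
  final class FirstDomainAndThen[A, B, C](pf: PartialFunction[A, B], k: B => C)
      extends PartialFunction[A, C] {
    def isDefinedAt(x: A): Boolean = pf.isDefinedAt(x) // k's domain is never checked
    def apply(x: A): C = k(pf(x))
  }

  val inverse: PartialFunction[Double, Double] = { case n if n != 0 => 1.0 / n }
  val arcSin: PartialFunction[Double, Double] = {
    case n if math.abs(n) <= 1 => math.asin(n)
  }

  val calc: PartialFunction[Double, Double] = new FirstDomainAndThen(inverse, arcSin)

  def main(args: Array[String]): Unit = {
    println(calc.isDefinedAt(0.5)) // only inverse is consulted
    println(Try(calc.lift(0.5)))   // arcSin(2.0) has no matching case
  }
}
```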


    So, when composing a function (total or partial) with a partial function, how can we make the result a partial function with the proper domain?

    Similar to what's demonstrated in this SO Q&A, one can add an andThenPartial method via a couple of implicit classes. It restricts the domain of the composed function to those inputs of the first function whose results fall within the partial function's domain:

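    object ComposeFcnOps {
      // Total f followed by partial `that`: defined at x only if `that` is defined at f(x).
      // Option(...) additionally treats a null result of f as undefined.
      implicit class TotalCompose[A, B](f: Function[A, B]) {
        def andThenPartial[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
          Function.unlift(x => Option(f(x)).flatMap(that.lift))
      }
    
      // Partial pf followed by partial `that`: defined at x only if pf is defined at x
      // and `that` is defined at pf(x)
      implicit class PartialCompose[A, B](pf: PartialFunction[A, B]) {
        def andThenPartial[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
          Function.unlift(x => pf.lift(x).flatMap(that.lift))
      }
    }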
    

    Testing with the example functions:

    import ComposeFcnOps._
    
    val mod10: Int => Int = _ % 10
    val inverse1: PartialFunction[Int, Double] = { case n if n != 0 => 1.0 / n }
    val triple: Double => Double = _ * 3
    
    val calc1 = mod10 andThenPartial inverse1 andThen triple
    // calc1: PartialFunction[Int,Double] = <function1>
    
    calc1.isDefinedAt(0)
    // res1: Boolean = false
    
    val inverse2: PartialFunction[Double, Double] = { case n if n != 0 => 1.0 / n }
    val arcSin: PartialFunction[Double, Double] = { 
       case n if math.abs(n) <= 1 => math.asin(n)
    }
    
    val calc2 = inverse2 andThenPartial arcSin
    // calc2: PartialFunction[Double,Double] = <function1>
    
    calc2.isDefinedAt(0.5)
    // res2: Boolean = false
    
    calc2.lift(0.5)
    // res3: Option[Double] = None
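Checking calc1 and calc2 at a few more points, including ones inside the expected domain (-Infinity, -1] union [1, Infinity), is a quick sanity test. The self-contained sketch below repeats the implicit classes with one variation: TotalCompose is moved into a lower-priority parent trait (LowPriorityComposeOps, a hypothetical name following a common Scala idiom), so the compiler's choice between the two classes for a PartialFunction receiver does not rest on specificity rules alone. The object name MoreChecks is only for illustration, and dot notation is used instead of infix.

```scala
trait LowPriorityComposeOps {
  // Fallback for plain functions; lower priority than PartialCompose below
  implicit class TotalCompose[A, B](f: Function[A, B]) {
    def andThenPartial[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
      Function.unlift(x => Option(f(x)).flatMap(that.lift))
  }
}

object ComposeFcnOps extends LowPriorityComposeOps {
  // Preferred whenever the receiver is statically a PartialFunction
  implicit class PartialCompose[A, B](pf: PartialFunction[A, B]) {
    def andThenPartial[C](that: PartialFunction[B, C]): PartialFunction[A, C] =
      Function.unlift(x => pf.lift(x).flatMap(that.lift))
  }
}

object MoreChecks {
  import ComposeFcnOps._

  val mod10: Int => Int = _ % 10
  val inverse1: PartialFunction[Int, Double] = { case n if n != 0 => 1.0 / n }
  val triple: Double => Double = _ * 3
  val calc1: PartialFunction[Int, Double] = mod10.andThenPartial(inverse1).andThen(triple)

  val inverse2: PartialFunction[Double, Double] = { case n if n != 0 => 1.0 / n }
  val arcSin: PartialFunction[Double, Double] = {
    case n if math.abs(n) <= 1 => math.asin(n)
  }
  val calc2: PartialFunction[Double, Double] = inverse2.andThenPartial(arcSin)

  def main(args: Array[String]): Unit = {
    println(calc1.lift(12)) // 12 % 10 == 2, then 3 * (1.0 / 2)
    println(calc1.lift(20)) // 20 % 10 == 0, outside inverse1's domain
    Seq(-2.0, -1.0, 0.0, 0.5, 1.0, 4.0).foreach { x =>
      println(s"calc2.lift($x) = ${calc2.lift(x)}")
    }
  }
}
```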