
Why does type inference fail for a polymorphic function applied to different inputs within the same function?


I'm making an interpreter for a subset of C++. The interpreter is written in Haskell.

My eval function for expressions returns a new environment, and a value. I encode the values as a new type called Val. Minimal example:

data Val = I Integer | D Double

To evaluate arithmetic expressions, I want to create a general function which applies a polymorphic function such as (+) or (*) to the numbers wrapped inside the Val constructors.

I want to have a function like this:

-- calculate :: Num a => (a -> a -> a) -> Val -> Val -> Val
calculate f (I i1) (I i2) = I (f i1 i2)
calculate f (D d1) (D d2) = D (f d1 d2)

This gives the following error:

tmp/example.hs:4:32: error:
    • Couldn't match expected type ‘Double’ with actual type ‘Integer’
    • In the first argument of ‘D’, namely ‘(f d1 d2)’
      In the expression: D (f d1 d2)
      In an equation for ‘calculate’:
          calculate f (D d1) (D d2) = D (f d1 d2)
  |
4 | calculate f (D d1) (D d2) = D (f d1 d2)
  |                                ^^^^^^^

tmp/example.hs:4:34: error:
    • Couldn't match expected type ‘Integer’ with actual type ‘Double’
    • In the first argument of ‘f’, namely ‘d1’
      In the first argument of ‘D’, namely ‘(f d1 d2)’
      In the expression: D (f d1 d2)
  |
4 | calculate f (D d1) (D d2) = D (f d1 d2)
  |                                  ^^

tmp/example.hs:4:37: error:
    • Couldn't match expected type ‘Integer’ with actual type ‘Double’
    • In the second argument of ‘f’, namely ‘d2’
      In the first argument of ‘D’, namely ‘(f d1 d2)’
      In the expression: D (f d1 d2)
  |
4 | calculate f (D d1) (D d2) = D (f d1 d2)
  |                                     ^^

I can't wrap my head around this. I have two questions:

  1. Why does this program fail to type check?
  2. How could I implement calculate correctly?

I'm only vaguely familiar with universally quantified types, so if that is part of the problem, please explain gently.


Solution

  • You've correctly identified that you need universal quantification. In fact, you already have universal quantification – your signature, like any polymorphic signature, is basically shorthand for

    {-# LANGUAGE ExplicitForAll, UnicodeSyntax #-}
    calculate :: ∀ a . Num a => (a -> a -> a) -> Val -> Val -> Val
    

    meaning: whoever calls this function gets to choose, up front, which type to substitute for a. For example, they could pick Int, and the function would then be specialised to

    calculate :: (Int -> Int -> Int) -> Val -> Val -> Val
    

    and that's then used at runtime.

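    But that's no use for you, because you need to apply the function at different number types inside calculate: at Integer in the first equation and at Double in the second. No single specialisation covers both. (With the signature commented out, as in your example, the same thing happens implicitly: GHC infers a monomorphic type for the argument f from the first equation, Integer -> Integer -> Integer, and then rejects the D equation, which is what the error messages report.)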
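To make the limitation concrete, here is a minimal sketch of what a single specialisation can do. The name calculateI and the Maybe result are hypothetical, chosen only for illustration, and Integer is picked as the specialised type because that is what the I constructor holds.

```haskell
-- Val mirrors the question's type; Show and Eq are derived only so that
-- results can be printed and compared (an addition for illustration).
data Val = I Integer | D Double
  deriving (Show, Eq)

-- Hypothetical name: calculate with the type variable fixed to Integer.
-- Because f is monomorphic here, it can only be applied to the I case.
calculateI :: (Integer -> Integer -> Integer) -> Val -> Val -> Maybe Val
calculateI f (I i1) (I i2) = Just (I (f i1 i2))
-- f cannot be applied to Doubles, so every other combination is
-- rejected (returning Nothing is an assumption of this sketch).
calculateI _ _      _      = Nothing
```

In GHCi, calculateI (+) (I 1) (I 2) gives Just (I 3), while two D values give Nothing, since this f has no way to handle them.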

    The solution: delay the choice of type. That's accomplished by moving the universal quantifier ∀ (which you can also write as forall) inside the type of the function argument:

    {-# LANGUAGE Rank2Types #-}
    calculate :: (∀ a . Num a => a -> a -> a) -> Val -> Val -> Val
    

    which will typecheck. It does require the -XRank2Types extension (in current GHC versions this is a deprecated synonym for RankNTypes), because this is a rather more complex beast: you can no longer picture the polymorphic function as a family of specialisations with concrete monomorphic types. Instead, calculate needs to be ready to instantiate, at runtime, the supplied function at whatever type happens to occur in the data structure.
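Putting the rank-2 signature together with the equations from the question gives the following sketch. It uses RankNTypes and the ASCII forall spelling. The two mixed-operand equations are not part of the question; promoting the Integer to Double is just one possible choice, assumed here so that the function is total.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Val mirrors the question's type; Show and Eq are derived only so that
-- results can be printed and compared (an addition for illustration).
data Val = I Integer | D Double
  deriving (Show, Eq)

-- The forall sits inside the parentheses, so f stays polymorphic within
-- calculate and each equation may instantiate it at a different type.
calculate :: (forall a. Num a => a -> a -> a) -> Val -> Val -> Val
calculate f (I i1) (I i2) = I (f i1 i2)   -- f used at Integer
calculate f (D d1) (D d2) = D (f d1 d2)   -- f used at Double
-- Assumption of this sketch: mixed operands are promoted to Double.
calculate f (I i1) (D d2) = D (f (fromInteger i1) d2)
calculate f (D d1) (I i2) = D (f d1 (fromInteger i2))
```

In GHCi, calculate (+) (I 1) (I 2) evaluates to I 3 and calculate (*) (D 1.5) (D 2) to D 3.0. Only functions that work for every Num instance can be passed, so (+), (-) and (*) are accepted, but something like div, which needs Integral, is not.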

    Concretely, calculate has to pass an additional argument to the supplied function: a “dictionary” containing the Num class methods. The underlying implementation GHC generates is something like this:

    data NumDict a = NumDict {
            addition :: a -> a -> a
          , subtraction :: a -> a -> a
          , multiplication :: a -> a -> a
          , abs :: a -> a
          ...
          }
    
    calculate' :: (∀ a . NumDict a -> a -> a -> a) -> Val -> Val -> Val
    calculate' f (I i1) (I i2) = I (f ndict i1 i2)
     where ndict = NumDict ((+) :: Integer -> Integer -> Integer)
                           ((-) :: Integer -> Integer -> Integer)
                           ...
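The dictionary-passing idea above can be written out as ordinary Haskell. What follows is a hand-written model, not what GHC actually emits: the record is cut down to three fields, the names integerDict and doubleDict are hypothetical, and the fallback equation for mixed operands is an assumption of this sketch.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Val mirrors the question's type; Show and Eq are derived only so that
-- results can be printed and compared (an addition for illustration).
data Val = I Integer | D Double
  deriving (Show, Eq)

-- A cut-down stand-in for the Num dictionary. Only three methods are
-- modelled; the real class has more.
data NumDict a = NumDict
  { addition       :: a -> a -> a
  , subtraction    :: a -> a -> a
  , multiplication :: a -> a -> a
  }

-- Hypothetical dictionaries for the two types stored in Val, built from
-- the ordinary Num methods at each type.
integerDict :: NumDict Integer
integerDict = NumDict (+) (-) (*)

doubleDict :: NumDict Double
doubleDict = NumDict (+) (-) (*)

-- The supplied function receives the dictionary explicitly; calculate'
-- picks the dictionary that matches the constructor it finds.
calculate' :: (forall a. NumDict a -> a -> a -> a) -> Val -> Val -> Val
calculate' f (I i1) (I i2) = I (f integerDict i1 i2)
calculate' f (D d1) (D d2) = D (f doubleDict d1 d2)
-- Assumption of this sketch: mixed operands are simply not supported.
calculate' _ _      _      = error "calculate': mixed operands not handled"
```

The record selectors themselves have the required type, so calculate' addition (I 1) (I 2) evaluates to I 3 and calculate' multiplication (D 1.5) (D 2) to D 3.0.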