haskell  type-mismatch  type-systems  automatic-differentiation

Why is a function type required to be "wrapped" for the type checker to be satisfied?


The following program type-checks:

{-# LANGUAGE RankNTypes #-}

import Numeric.AD (grad)

newtype Fun = Fun (forall a. Num a => [a] -> a)

test1 [u, v] = (v - (u * u * u))
test2 [u, v] = ((u * u) + (v * v) - 1)

main = print $ fmap (\(Fun f) -> grad f [1,1]) [Fun test1, Fun test2]

But this program fails:

main = print $ fmap (\f -> grad f [1,1]) [test1, test2]

With the type error:

Grad.hs:13:33: error:
    • Couldn't match type ‘Integer’
                     with ‘Numeric.AD.Internal.Reverse.Reverse s Integer’
      Expected type: [Numeric.AD.Internal.Reverse.Reverse s Integer]
                     -> Numeric.AD.Internal.Reverse.Reverse s Integer
        Actual type: [Integer] -> Integer
    • In the first argument of ‘grad’, namely ‘f’
      In the expression: grad f [1, 1]
      In the first argument of ‘fmap’, namely ‘(\ f -> grad f [1, 1])’

Intuitively, the latter program looks correct. After all, the following, seemingly equivalent program does work:

main = print $ [grad test1 [1,1], grad test2 [1,1]]

It looks like a limitation in GHC's type system. I would like to know what causes the failure, why this limitation exists, and any possible workarounds besides wrapping the function (per Fun above).

(Note: this is not caused by the monomorphism restriction; compiling with NoMonomorphismRestriction does not help.)


Solution

  • This is a limitation of GHC's type system. (It really is GHC's type system, by the way: the original type systems of Haskell and ML-like languages do not support higher-rank polymorphism, let alone impredicative polymorphism, which is what this program needs.)

    The issue is that, to type-check this program, we need to allow foralls at any position in a type, not only bunched together at the very front (the usual restriction, which is what makes complete type inference possible). Once you leave that fragment, type inference becomes undecidable in general (already at rank 3 and above). In our case, the lambda-bound f would have to be polymorphic, and the type of [test1, test2] would need to be [forall a. Num a => [a] -> a], which does not fit the scheme described above. Instead, GHC instantiates both functions at a single element type, which defaults to Integer; that is the [Integer] -> Integer in the error message. Typing the list as intended would require impredicative polymorphism, so called because a type variable (here the element type of the list) ranges over types that themselves contain foralls, and so could be instantiated with the very type in which it appears.
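
To make the point about a single element type concrete, here is a minimal sketch that uses only the Prelude. The names atInteger and atDouble are made up for illustration. Each list fixes one instantiation of a for all of its elements, which is why the functions taken out of the list are no longer polymorphic enough for grad.

```haskell
-- Both functions are polymorphic in the numeric type (ordinary rank-1 types).
test1, test2 :: Num a => [a] -> a
test1 [u, v] = v - u * u * u
test1 _      = error "test1: expected exactly two arguments"
test2 [u, v] = u * u + v * v - 1
test2 _      = error "test2: expected exactly two arguments"

-- A list has one element type, so every function in it is
-- instantiated at the same `a` before it is stored.
atInteger :: [[Integer] -> Integer]
atInteger = [test1, test2]

atDouble :: [[Double] -> Double]
atDouble = [test1, test2]

main :: IO ()
main = do
  print (map ($ [2, 3]) atInteger)     -- prints [-5,12]
  print (map ($ [0.5, 0.5]) atDouble)  -- prints [0.375,-0.5]
```

Without the annotations, GHC would still have to pick one type for the whole list, and numeric defaulting would choose Integer.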

    Consequently, some cases will always misbehave, simply because the problem is not fully solvable. GHC has some support for rank-n polymorphism (RankNTypes) and limited support for impredicative polymorphism (ImpredicativeTypes), but it is generally better to use newtype wrappers to get reliable behavior. To the best of my knowledge, GHC also discourages relying on impredicative polymorphism, precisely because it is so hard to predict what the type inference algorithm will handle.
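
For completeness, here is a sketch of what the impredicative route might look like. It assumes a recent GHC (9.2 or later, where ImpredicativeTypes was reworked) and has not been checked against older compilers. To stay free of external packages it uses a hypothetical stand-in, evalBoth, in place of grad; like grad, it demands a polymorphic argument.

```haskell
{-# LANGUAGE ImpredicativeTypes #-}  -- assumption: GHC 9.2 or later

test1, test2 :: Num a => [a] -> a
test1 [u, v] = v - u * u * u
test1 _      = error "test1: expected exactly two arguments"
test2 [u, v] = u * u + v * v - 1
test2 _      = error "test2: expected exactly two arguments"

-- The explicit signature asks for a list whose elements stay polymorphic.
funs :: [forall a. Num a => [a] -> a]
funs = [test1, test2]

-- Hypothetical stand-in for `grad`: it uses its argument at two
-- different numeric types, so the argument must be polymorphic.
evalBoth :: (forall a. Num a => [a] -> a) -> (Integer, Double)
evalBoth f = (f [2, 3], f [0.5, 0.5])

main :: IO ()
main = print (map evalBoth funs)
```

The type annotation on funs is essential here: without it, GHC falls back to choosing a single monomorphic element type.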

    In summary, the theory guarantees that there will be flaky cases, and newtype wrappers are the best, if somewhat unsatisfying, way to cope with them.
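
If the wrapper is undesirable, one alternative that should work is to keep the polymorphic functions out of the list and give the consumer an explicit rank-2 signature. The sketch below assumes the ad package from the question is installed; gradAt11 is a made-up helper name, not part of that library.

```haskell
{-# LANGUAGE RankNTypes #-}

import Numeric.AD (grad)  -- from the `ad` package, as in the question

test1, test2 :: Num a => [a] -> a
test1 [u, v] = v - u * u * u
test1 _      = error "test1: expected exactly two arguments"
test2 [u, v] = u * u + v * v - 1
test2 _      = error "test2: expected exactly two arguments"

-- Hypothetical helper. The explicit rank-2 signature tells GHC that
-- the argument stays polymorphic, so `grad` can instantiate it at its
-- own internal number type.
gradAt11 :: (forall a. Num a => [a] -> a) -> [Integer]
gradAt11 f = grad f [1, 1]

main :: IO ()
main = print [gradAt11 test1, gradAt11 test2]
```

The list now holds plain [Integer] results rather than polymorphic functions, so no impredicative type is needed. The trade-off is that grad is applied before the values go into the list, much like the working variant in the question.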