Tags: haskell, automatic-differentiation

Haskell can't deduce type equality


I have the following code, which does not compile:

  import Numeric.AD

  data Trainable a b = forall n . Floating n =>  Trainable ([n] -> a -> b) (a -> b -> [n] -> n) 

  trainSgdFull :: (Floating n, Ord n) => Trainable a b -> [n] -> a -> b -> [[n]]
  trainSgdFull (Trainable _ cost) init input target =  gradientDescent (cost input target) init

I want to use the Trainable type to represent machine learning systems trainable by gradient descent. The first argument would be the transfer function and the second would be the cost function; a is the input type, b is the output/target type, and the list contains the learnable parameters. The compiler complains:

 src/MachineLearning/Training.hs:12:73:
Could not deduce (n1 ~ ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n)
from the context (Floating n, Ord n)
  bound by the type signature for
             trainSgdFull :: (Floating n, Ord n) =>
                             Trainable a b -> [n] -> a -> b -> [[n]]
  at src/MachineLearning/Training.hs:12:3-95
or from (Floating n1)
  bound by a pattern with constructor
             Trainable :: forall a b n.
                          Floating n =>
                          ([n] -> a -> b) -> (a -> b -> [n] -> n) -> Trainable a b,
           in an equation for `trainSgdFull'
  at src/MachineLearning/Training.hs:12:17-32
or from (Numeric.AD.Internal.Classes.Mode s)
  bound by a type expected by the context:
             Numeric.AD.Internal.Classes.Mode s =>
             [ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n]
             -> ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n
  at src/MachineLearning/Training.hs:12:56-95
  `n1' is a rigid type variable bound by
       a pattern with constructor
         Trainable :: forall a b n.
                      Floating n =>
                      ([n] -> a -> b) -> (a -> b -> [n] -> n) -> Trainable a b,
       in an equation for `trainSgdFull'
       at src/MachineLearning/Training.hs:12:17
Expected type: [ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n1]
               -> ad-3.3.1.1:Numeric.AD.Internal.Types.AD s n1
  Actual type: [n] -> n
In the return type of a call of `cost'
In the first argument of `gradientDescent', namely
  `(cost input target)'

Is the basic concept right? If it is, how could I make the code compile?


Solution

  • The problem is that

    data Trainable a b = forall n . Floating n =>  Trainable ([n] -> a -> b) (a -> b -> [n] -> n)
    

    means that in

    Trainable transfer cost
    

    the type n used is lost. All that is known is that there is some type Guessme with a Floating instance such that

    transfer :: [Guessme] -> a -> b
    cost :: a -> b -> [Guessme] -> Guessme
    

    You can build Trainables with functions that only work for Complex Float, or only for Double, or ...
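    For instance, the following compiles: the existential wrapper happily accepts functions fixed at Double, and after construction the fact that n was Double is forgotten. (transferD and costD are hypothetical toy functions, not from the question.)

```haskell
{-# LANGUAGE ExistentialQuantification #-}

data Trainable a b = forall n. Floating n =>
    Trainable ([n] -> a -> b) (a -> b -> [n] -> n)

-- Toy model fixed at Double: a linear transfer and a squared-error cost.
transferD :: [Double] -> Double -> Double
transferD ws x = sum ws * x

costD :: Double -> Double -> [Double] -> Double
costD x y ws = (transferD ws x - y) ^ 2

-- The wrapper accepts them, but afterwards the choice n ~ Double is hidden;
-- a consumer of doubleOnly only knows "some Floating type".
doubleOnly :: Trainable Double Double
doubleOnly = Trainable transferD costD

main :: IO ()
main = print (costD 2 3 [1, 0.5])  -- still usable at Double before wrapping
```

    Once costD is stored inside doubleOnly, there is no way to apply it at a caller-chosen type such as AD s n, which is exactly what gradientDescent needs to do.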

    But in

    trainSgdFull :: (Floating n, Ord n) => Trainable a b -> [n] -> a -> b -> [[n]]
    trainSgdFull (Trainable _ cost) init input target =  gradientDescent (cost input target) init
    

    you are trying to use cost with whatever Floating type is supplied as an argument.

    The Trainable was built to work with type n0, the user supplies type n1, and those may or may not be the same. Thus the compiler can't deduce they are the same.

    If you don't want to make n a type parameter of Trainable, you need to make it wrap polymorphic functions that work with every Floating type the caller supplies:

    data Trainable a b
        = Trainable (forall n. Floating n => [n] -> a -> b)
                    (forall n. Floating n => a -> b -> [n] -> n)
    

    (This needs the Rank2Types extension or, since that one is being deprecated in its favour, RankNTypes.)
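    A minimal compiling sketch of this rank-2 version. Note that with this signature the transfer function must produce a plain b from parameters of any Floating type, so the transfer below is a toy that ignores its parameters; the cost lifts its Double arguments into n with realToFrac so it works at every instantiation. (example and its lambdas are hypothetical, for illustration only.) With Trainable defined this way, the original trainSgdFull should typecheck unchanged, because cost input target can be instantiated at the AD type that gradientDescent requires.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Rank-2 fields: the *wrapped* functions are polymorphic in n, so each
-- caller (such as gradientDescent) may pick its own Floating type.
data Trainable a b = Trainable
    (forall n. Floating n => [n] -> a -> b)
    (forall n. Floating n => a -> b -> [n] -> n)

-- Hypothetical example at a ~ b ~ Double. The cost converts its Double
-- arguments into n, so it can be used at Double, Float, AD s Double, ...
example :: Trainable Double Double
example = Trainable
    (\_ x -> x)  -- toy transfer: ignores the parameters
    (\x y ws -> (sum ws * realToFrac x - realToFrac y) ^ 2)

main :: IO ()
main = case example of
    Trainable _ cost -> do
        print (cost 2 3 [1, 0.5] :: Double)  -- instantiated at Double
        print (cost 2 3 [1, 0.5] :: Float)   -- the same cost at Float
```

    The pattern match now yields genuinely polymorphic functions rather than functions at one unknown type, which is why the same cost can be applied at two different Floating types above.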