Tags: haskell, reflection, typeclass, instances, automatic-differentiation

AD Reflection - How does it work?


I have seen the ad package, and I understand how it does automatic differentiation: it provides a different instance of the Floating class that implements the rules of derivatives.
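To make the "different instance of Floating" idea concrete, here is a minimal forward-mode sketch using dual numbers. It is not the ad package's actual implementation, which is far more general; `Dual` and `diffD` are hypothetical names, and the derivative rules are the standard textbook ones.

```haskell
-- Minimal forward-mode AD with dual numbers (hypothetical sketch, not the ad package).
-- A Dual carries a value together with its derivative.
data Dual = Dual Double Double
  deriving Show

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)  -- product rule
  negate (Dual a a')    = Dual (negate a) (negate a')
  abs (Dual a a')       = Dual (abs a) (a' * signum a)
  signum (Dual a _)     = Dual (signum a) 0
  fromInteger n         = Dual (fromInteger n) 0          -- constants have derivative 0

instance Fractional Dual where
  Dual a a' / Dual b b' = Dual (a / b) ((a' * b - a * b') / (b * b))  -- quotient rule
  fromRational r        = Dual (fromRational r) 0

-- Chain rule for each primitive: d(g u) = g'(u) * u'
instance Floating Dual where
  pi                = Dual pi 0
  exp (Dual a a')   = Dual (exp a) (a' * exp a)
  log (Dual a a')   = Dual (log a) (a' / a)
  sin (Dual a a')   = Dual (sin a) (a' * cos a)
  cos (Dual a a')   = Dual (cos a) (negate a' * sin a)
  asin (Dual a a')  = Dual (asin a) (a' / sqrt (1 - a * a))
  acos (Dual a a')  = Dual (acos a) (negate a' / sqrt (1 - a * a))
  atan (Dual a a')  = Dual (atan a) (a' / (1 + a * a))
  sinh (Dual a a')  = Dual (sinh a) (a' * cosh a)
  cosh (Dual a a')  = Dual (cosh a) (a' * sinh a)
  asinh (Dual a a') = Dual (asinh a) (a' / sqrt (a * a + 1))
  acosh (Dual a a') = Dual (acosh a) (a' / sqrt (a * a - 1))
  atanh (Dual a a') = Dual (atanh a) (a' / (1 - a * a))

-- Seed the input's derivative with 1 and read off the derivative of the result.
diffD :: (Dual -> Dual) -> Double -> Double
diffD g v = let Dual _ d = g (Dual v 1) in d
```

For example, `diffD (\y -> y * y) 3` evaluates to `6.0`. The caller's function never changes; only the instance it is run at does.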

But in the example

Prelude Debug.SimpleReflect Numeric.AD> diff atanh x
recip (1 - x * x) * 1

We can see that it represents the result as an AST and shows it as a string with variable names.

I wonder how they did that, because when I write:

f :: Floating a => a -> a
f x = x^2

No matter what instance I provide, I will get a function f :: Something -> Something, not a representation like f :: AST or f :: String.

The instance cannot "know" what the parameters are.

How are they able to do it?


Solution

  • It has nothing to do with the AD package, actually, and everything to do with the x in diff atanh x. That x comes from Debug.SimpleReflect, not from Numeric.AD: it is a symbolic value, not a number.

    To see this, let's define our own AST type

    data AST = AST :+ AST
             | AST :* AST
             | AST :- AST
             | Negate AST
             | Abs AST
             | Signum AST
             | FromInteger Integer
             | Variable String
    

    We can define a Num instance for this type

    instance Num AST where
      (+) = (:+)
      (*) = (:*)
      (-) = (:-)
      negate = Negate
      abs = Abs
      signum = Signum
      fromInteger = FromInteger
    

    And a Show instance

    instance Show AST where
      showsPrec p (a :+ b) = showParen (p > 6) (showsPrec 6 a . showString " + " . showsPrec 6 b)
      showsPrec p (a :* b) = showParen (p > 7) (showsPrec 7 a . showString " * " . showsPrec 7 b)
      showsPrec p (a :- b) = showParen (p > 6) (showsPrec 6 a . showString " - " . showsPrec 7 b)
      showsPrec p (Negate a) = showParen (p >= 10) (showString "negate " . showsPrec 10 a)
      showsPrec p (Abs a) = showParen (p >= 10) (showString "abs " . showsPrec 10 a)
      showsPrec p (Signum a) = showParen (p >= 10) (showString "signum " . showsPrec 10 a)
      showsPrec p (FromInteger n) = showsPrec p n
      showsPrec _ (Variable v) = showString v
    

    So now if we define a function:

    f :: Num a => a -> a
    f a = a ^ 2
    

    and an AST variable:

    x :: AST
    x = Variable "x"
    

    We can run the function to produce either integer values or AST values:

    λ f 5
    25
    λ f x
    x * x
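The same trick explains the diff atanh x output. As a sketch, assuming a simple dual-number representation rather than the ad package's real types, make the dual number polymorphic in its components: differentiation feeds the function a dual number whose parts are ASTs, so the derivative it reads back is an AST as well. `Dual` and `diffSym` are hypothetical names; the AST code is copied from above so that the block compiles on its own.

```haskell
-- AST, Num and Show instances copied from the answer above.
data AST = AST :+ AST
         | AST :* AST
         | AST :- AST
         | Negate AST
         | Abs AST
         | Signum AST
         | FromInteger Integer
         | Variable String

instance Num AST where
  (+) = (:+)
  (*) = (:*)
  (-) = (:-)
  negate = Negate
  abs = Abs
  signum = Signum
  fromInteger = FromInteger

instance Show AST where
  showsPrec p (a :+ b) = showParen (p > 6) (showsPrec 6 a . showString " + " . showsPrec 6 b)
  showsPrec p (a :* b) = showParen (p > 7) (showsPrec 7 a . showString " * " . showsPrec 7 b)
  showsPrec p (a :- b) = showParen (p > 6) (showsPrec 6 a . showString " - " . showsPrec 7 b)
  showsPrec p (Negate a) = showParen (p >= 10) (showString "negate " . showsPrec 10 a)
  showsPrec p (Abs a) = showParen (p >= 10) (showString "abs " . showsPrec 10 a)
  showsPrec p (Signum a) = showParen (p >= 10) (showString "signum " . showsPrec 10 a)
  showsPrec p (FromInteger n) = showsPrec p n
  showsPrec _ (Variable v) = showString v

-- Hypothetical dual number whose value and derivative have any Num type a.
data Dual a = Dual a a

instance Num a => Num (Dual a) where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' - Dual b b' = Dual (a - b) (a' - b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)  -- product rule, computed in type a
  negate (Dual a a') = Dual (negate a) (negate a')
  abs (Dual a a') = Dual (abs a) (a' * signum a)
  signum (Dual a _) = Dual (signum a) 0
  fromInteger n = Dual (fromInteger n) 0

-- Seed the derivative with 1; when a is AST, the result is an AST.
diffSym :: Num a => (Dual a -> Dual a) -> a -> a
diffSym g v = let Dual _ d = g (Dual v 1) in d

f :: Num a => a -> a
f a = a ^ 2

x :: AST
x = Variable "x"
```

In GHCi, diffSym f x should show x * 1 + 1 * x, and diffSym f 5 gives 10. The leftover * 1 factors come from seeding the input's derivative with 1, which looks analogous to the trailing * 1 in recip (1 - x * x) * 1.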
    

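    If we wanted to use our AST type with your function f :: Floating a => a -> a; f x = x^2, we'd need to extend the AST type with constructors for the extra operations (division, exp, log, sin and so on) so that we can implement Fractional AST and Floating AST; Fractional is a superclass of Floating, so it needs an instance too.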
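One way to do that, sketched under assumptions: the constructors :/, FromRational, Pi and Apply below are hypothetical additions, and recording every Floating method as a named Apply node is a shortcut chosen for brevity, not something the answer prescribes.

```haskell
-- Hypothetical extension of the AST so it can be a Fractional and Floating instance.
data AST = AST :+ AST
         | AST :* AST
         | AST :- AST
         | AST :/ AST
         | Negate AST
         | Abs AST
         | Signum AST
         | FromInteger Integer
         | FromRational Rational
         | Pi
         | Apply String AST   -- a named unary function such as "exp" or "atanh"
         | Variable String

instance Num AST where
  (+) = (:+)
  (*) = (:*)
  (-) = (:-)
  negate = Negate
  abs = Abs
  signum = Signum
  fromInteger = FromInteger

-- Fractional is a superclass of Floating, so it needs an instance too.
instance Fractional AST where
  (/) = (:/)
  fromRational = FromRational

-- Each Floating method just records which function was applied.
instance Floating AST where
  pi = Pi
  exp = Apply "exp"
  log = Apply "log"
  sin = Apply "sin"
  cos = Apply "cos"
  asin = Apply "asin"
  acos = Apply "acos"
  atan = Apply "atan"
  sinh = Apply "sinh"
  cosh = Apply "cosh"
  asinh = Apply "asinh"
  acosh = Apply "acosh"
  atanh = Apply "atanh"

instance Show AST where
  showsPrec p (a :+ b) = showParen (p > 6) (showsPrec 6 a . showString " + " . showsPrec 6 b)
  showsPrec p (a :* b) = showParen (p > 7) (showsPrec 7 a . showString " * " . showsPrec 7 b)
  showsPrec p (a :- b) = showParen (p > 6) (showsPrec 6 a . showString " - " . showsPrec 7 b)
  showsPrec p (a :/ b) = showParen (p > 7) (showsPrec 7 a . showString " / " . showsPrec 8 b)
  showsPrec p (Negate a) = showParen (p >= 10) (showString "negate " . showsPrec 10 a)
  showsPrec p (Abs a) = showParen (p >= 10) (showString "abs " . showsPrec 10 a)
  showsPrec p (Signum a) = showParen (p >= 10) (showString "signum " . showsPrec 10 a)
  showsPrec p (FromInteger n) = showsPrec p n
  showsPrec p (FromRational r) = showsPrec p (fromRational r :: Double)  -- display as a decimal
  showsPrec _ Pi = showString "pi"
  showsPrec p (Apply fn a) = showParen (p >= 10) (showString fn . showChar ' ' . showsPrec 10 a)
  showsPrec _ (Variable v) = showString v

-- The question's function, now usable at type AST.
f :: Floating a => a -> a
f x = x ^ 2

x :: AST
x = Variable "x"
```

With these instances f x still shows x * x, atanh x shows atanh x, and recip (1 - x * x) shows 1 / (1 - x * x). Methods that are not defined here, such as sqrt, tan and recip, fall back to the class defaults, which is why recip appears as a division.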