Tags: algorithm, haskell, data-structures, functional-programming, automatic-differentiation

Representing a computational graph in Haskell


I'm trying to write a simple automatic differentiation package in Haskell.

What are efficient ways to represent a type-safe (directed) computational graph in Haskell? I know that the ad package uses the "data-reify" method for this, but I'm not very familiar with it. Can anyone provide me with some insights? Thanks!


Solution

  • As Will Ness's comment indicates, the right abstraction for AD is a category, not a graph. Unfortunately, the standard Category class doesn't quite do the trick, because it requires arrows between arbitrary Haskell types, whereas differentiation only makes sense between smooth manifolds. Most libraries don't know about manifolds and restrict things further to Euclidean vector spaces (represented as “vectors” or “tensors”, which are really just arrays). There is no compelling reason to be that restrictive: any affine space will do for forward-mode AD; for reverse mode you additionally need a notion of the dual space of the difference vectors.

    data FwAD x y = FwAD (x -> (y, Diff x -> Diff y))
    data RvAD x y = RvAD (x -> (y, DualVector (Diff y) -> DualVector (Diff x)))
    

    where the Diff x -> Diff y function must be linear. (You can use a dedicated arrow type for such functions, or just use ordinary (->) functions that happen to be linear.) The only difference in reverse mode is that the adjoint of this linear map is stored instead of the map itself. (In a real-valued matrix implementation, the linear map is the Jacobian matrix and its adjoint is the transpose; but don't use matrices, they're terribly inefficient for this.)
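    To see that composition alone already gives the chain rule, here is a minimal forward-mode sketch using only base. It assumes the simplification Diff x ~ x (true for Double and tuples of Doubles), which is what lets it use the ordinary Control.Category class instead of constrained-categories; the primitives sinA and sqrA and this evalWithGrad are hypothetical illustrations, not library API.

```haskell
import Prelude hiding (id, (.))
import Control.Category (Category (..))

-- Simplified forward-mode arrow. Assumption: Diff x ~ x, so the derivative
-- at a point is an ordinary function a -> b (which should be linear).
newtype FwAD a b = FwAD (a -> (b, a -> b))

instance Category FwAD where
  id = FwAD (\x -> (x, id))
  -- Chain rule: derivative of (f . g) at x is (df at g x) . (dg at x).
  FwAD f . FwAD g = FwAD $ \x ->
    let (gx, dg) = g x
        (fgx, df) = f gx
    in (fgx, df . dg)

-- Hypothetical primitives with hand-written derivatives.
sinA :: FwAD Double Double
sinA = FwAD (\x -> (sin x, \dx -> cos x * dx))

sqrA :: FwAD Double Double
sqrA = FwAD (\x -> (x * x, \dx -> 2 * x * dx))

-- Evaluate and push the unit tangent 1 through the derivative.
evalWithGrad :: FwAD Double Double -> Double -> (Double, Double)
evalWithGrad (FwAD f) x = let (fx, df) = f x in (fx, df 1)

main :: IO ()
main = print (evalWithGrad (sqrA . sqrA) 2) -- x^4 at 2: value 16, slope 32
```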
    Neat, right? All that graph/traversal/mutation/backwards-pass nonsense that people keep talking about isn't really needed. (See Conal Elliott's paper for elaboration.)
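    As a small illustration of the reverse-mode claim, here is a base-only sketch of RvAD. It assumes the dual of each tangent space is identified with the space itself (fine for Double and tuples of Doubles), so the adjoint is just a function b -> a. Composition is the same chain rule with the adjoints applied in the opposite order, and a gradient comes out of a single call with no tape or explicit graph. The names mulR, sinR and grad are hypothetical.

```haskell
import Prelude hiding (id, (.))
import Control.Category (Category (..))

-- Simplified reverse-mode arrow. Assumption: DualVector (Diff x) ~ x, so the
-- adjoint (transpose) of the derivative is a function b -> a.
newtype RvAD a b = RvAD (a -> (b, b -> a))

instance Category RvAD where
  id = RvAD (\x -> (x, id))
  -- Adjoints compose in reverse order: (df . dg)^T = dg^T . df^T.
  RvAD f . RvAD g = RvAD $ \x ->
    let (gx, dg') = g x
        (fgx, df') = f gx
    in (fgx, dg' . df')

-- Multiplication R^2 -> R: its adjoint maps an output sensitivity ct
-- to the input sensitivities (ct * y, ct * x).
mulR :: RvAD (Double, Double) Double
mulR = RvAD (\(x, y) -> (x * y, \ct -> (ct * y, ct * x)))

sinR :: RvAD Double Double
sinR = RvAD (\x -> (sin x, \ct -> cos x * ct))

-- Gradient of a scalar-valued arrow: feed sensitivity 1 into the adjoint.
grad :: RvAD a Double -> a -> (Double, a)
grad (RvAD f) x = let (fx, f') = f x in (fx, f' 1)

main :: IO ()
main = print (grad mulR (3, 4)) -- value 12, gradient (4, 3)
```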

    So, to make this useful in Haskell, you need to implement the category combinators. This is pretty much exactly what I wrote the constrained-categories package for. Here's an outline of the instances you need:

    import qualified Prelude as Hask
    import Control.Category.Constrained.Prelude
    import Control.Arrow.Constrained
    import Data.AffineSpace
    import Data.AdditiveGroup
    import Data.VectorSpace
    
    instance Category FwAD where
      type Object FwAD a
         = (AffineSpace a, VectorSpace (Diff a), Scalar (Diff a) ~ Double)
      id = FwAD $ \x -> (x, id)
      FwAD f . FwAD g = FwAD $ \x -> case g x of
         (gx, dg) -> case f gx of
           (fgx, df) -> (fgx, df . dg)
    
    instance Cartesian FwAD where
      ...
    instance Morphism FwAD where
      ...
    instance PreArrow FwAD where
      ...
    instance WellPointed FwAD where
      ...
    

    The instances are all easy and almost unambiguous; let the compiler messages guide you (typed holes _ are enormously useful). Basically, whenever a value of some type is required and a variable of that type is in scope, use it; when a value of a vector-space type is required and none is in scope, use zeroV.
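    To give an idea of what the elided instance bodies amount to, here are the corresponding combinators written as plain functions for a simplified, base-only FwAD (assuming Diff x ~ x). The names exl, exr, fork, par and constA are hypothetical stand-ins for fst, snd, &&&, *** and const; they are not the constrained-categories methods themselves, and Num's 0 plays the role of zeroV.

```haskell
-- Simplified forward-mode arrow (assumption: Diff x ~ x).
newtype FwAD a b = FwAD (a -> (b, a -> b))

-- Projections are linear, so their derivative is the projection itself.
exl :: FwAD (a, b) a
exl = FwAD (\(x, _) -> (x, \(dx, _) -> dx))

exr :: FwAD (a, b) b
exr = FwAD (\(_, y) -> (y, \(_, dy) -> dy))

-- Fan-out (role of &&&): both branches see the same input and tangent.
fork :: FwAD a b -> FwAD a c -> FwAD a (b, c)
fork (FwAD f) (FwAD g) = FwAD $ \x ->
  let (fx, df) = f x
      (gx, dg) = g x
  in ((fx, gx), \dx -> (df dx, dg dx))

-- Parallel composition (role of ***): each component handled separately.
par :: FwAD a c -> FwAD b d -> FwAD (a, b) (c, d)
par (FwAD f) (FwAD g) = FwAD $ \(x, y) ->
  let (fx, df) = f x
      (gy, dg) = g y
  in ((fx, gy), \(dx, dy) -> (df dx, dg dy))

-- Constant arrow: no dependence on the input, so the derivative is the
-- zero map (this is where zeroV would appear in the real instances).
constA :: Num b => b -> FwAD a b
constA c = FwAD (\_ -> (c, \_ -> 0))

-- Value and directional derivative at point x in direction dx.
runAt :: FwAD a b -> a -> a -> (b, b)
runAt (FwAD f) x dx = let (fx, df) = f x in (fx, df dx)

main :: IO ()
main = print (runAt (fork exr exl) ((1, 2) :: (Double, Double)) (3, 4))
```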

    At that point you'll have all the fundamental tooling for differentiable functions in place, but to actually define such functions you would need point-free style with lots of ., &&& and *** combinators and hard-coded numerical primitives, which looks unconventional and rather confusing. To avoid that, you can use agent values: values that pretend to be simple number variables but actually contain an entire category arrow from some fixed domain type. (This is basically the “building up a graph” part of the exercise.) You can simply use the provided GenericAgent wrapper.
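    For comparison, this is roughly what the point-free style looks like for f x = x^2 + x, on a simplified base-only FwAD (assuming Diff x ~ x). The names fork, addA, mulA and squarePlusSelf are hypothetical. It works, but even this tiny function is already hard to read, and the comments show how intermediate values travel through tuple types.

```haskell
import Prelude hiding (id, (.))
import Control.Category (Category (..))

-- Simplified forward-mode arrow (assumption: Diff x ~ x).
newtype FwAD a b = FwAD (a -> (b, a -> b))

instance Category FwAD where
  id = FwAD (\x -> (x, id))
  FwAD f . FwAD g = FwAD $ \x ->
    let (gx, dg) = g x
        (fgx, df) = f gx
    in (fgx, df . dg)

-- Fan-out (role of &&&).
fork :: FwAD a b -> FwAD a c -> FwAD a (b, c)
fork (FwAD f) (FwAD g) = FwAD $ \x ->
  let (fx, df) = f x
      (gx, dg) = g x
  in ((fx, gx), \dx -> (df dx, dg dx))

-- Hard-coded numerical primitives on pairs.
addA :: FwAD (Double, Double) Double
addA = FwAD (\(x, y) -> (x + y, \(dx, dy) -> dx + dy))

mulA :: FwAD (Double, Double) Double
mulA = FwAD (\(x, y) -> (x * y, \(dx, dy) -> y * dx + x * dy))

-- f x = x^2 + x, point-free. Reading right to left:
--   fork id id       :: FwAD Double (Double, Double)  -- duplicate x
--   mulA . fork id id :: FwAD Double Double            -- x * x
--   fork (...) id    :: FwAD Double (Double, Double)  -- (x * x, x)
--   addA             :: FwAD (Double, Double) Double  -- sum
squarePlusSelf :: FwAD Double Double
squarePlusSelf = addA . fork (mulA . fork id id) id

evalWithGrad :: FwAD Double Double -> Double -> (Double, Double)
evalWithGrad (FwAD f) x = let (fx, df) = f x in (fx, df 1)

main :: IO ()
main = print (evalWithGrad squarePlusSelf 3) -- prints (12.0,7.0)
```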

    instance HasAgent FwAD where
      type AgentVal FwAD a v = GenericAgent FwAD a v
      alg = genericAlg
      ($~) = genericAgentMap
    
    instance CartesianAgent FwAD where
      alg1to2 = genericAlg1to2
      alg2to1 = genericAlg2to1
      alg2to2 = genericAlg2to2
    
    instance PointAgent (GenericAgent FwAD) FwAD a x where
      point = genericPoint
    
    instance ( Num v, AffineSpace v, Diff v ~ v, VectorSpace v, Scalar v ~ v
             , Scalar a ~ v )
          => Num (GenericAgent FwAD a v) where
      fromInteger = point . fromInteger
      (+) = genericAgentCombine . FwAD $ \(x,y) -> (x+y, \(dx,dy) -> dx+dy)
      (*) = genericAgentCombine . FwAD $ \(x,y) -> (x*y, \(dx,dy) -> y*dx+x*dy)
      abs = genericAgentMap . FwAD $ \x -> (abs x, \dx -> if x<0 then -dx else dx)
      ...
    instance ( Fractional v, AffineSpace v, Diff v ~ v, VectorSpace v, Scalar v ~ v
             , Scalar a ~ v )
          => Fractional (GenericAgent FwAD a v) where
      ...
    instance (...) => Floating (...) where
      ...
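    The agent idea can also be sketched without constrained-categories. This is a simplified stand-in for GenericAgent, not its actual implementation: the hypothetical Agent a wraps a base-only FwAD a Double (assuming Diff x ~ x), and its Num instance mirrors the one above. Agent, combine, mapA and alg are illustrative names only.

```haskell
import Prelude hiding (id, (.))
import Control.Category (Category (..))

-- Simplified forward-mode arrow (assumption: Diff x ~ x).
newtype FwAD a b = FwAD (a -> (b, a -> b))

instance Category FwAD where
  id = FwAD (\x -> (x, id))
  FwAD f . FwAD g = FwAD $ \x ->
    let (gx, dg) = g x
        (fgx, df) = f gx
    in (fgx, df . dg)

-- Fan-out (role of &&&).
fork :: FwAD a b -> FwAD a c -> FwAD a (b, c)
fork (FwAD f) (FwAD g) = FwAD $ \x ->
  let (fx, df) = f x
      (gx, dg) = g x
  in ((fx, gx), \dx -> (df dx, dg dx))

-- Hypothetical agent: looks like a number, but is an arrow from the
-- fixed domain a to Double.
newtype Agent a = Agent (FwAD a Double)

-- Lift a binary primitive: fan out to both operands, then apply it.
combine :: FwAD (Double, Double) Double -> Agent a -> Agent a -> Agent a
combine op (Agent f) (Agent g) = Agent (op . fork f g)

-- Lift a unary primitive by post-composition.
mapA :: FwAD Double Double -> Agent a -> Agent a
mapA op (Agent f) = Agent (op . f)

instance Num (Agent a) where
  -- Constants ignore the input, so their derivative is zero.
  fromInteger n = Agent (FwAD (\_ -> (fromInteger n, \_ -> 0)))
  (+) = combine (FwAD (\(x, y) -> (x + y, \(dx, dy) -> dx + dy)))
  (*) = combine (FwAD (\(x, y) -> (x * y, \(dx, dy) -> y * dx + x * dy)))
  negate = mapA (FwAD (\x -> (negate x, negate)))
  abs = mapA (FwAD (\x -> (abs x, \dx -> if x < 0 then -dx else dx)))
  signum = mapA (FwAD (\x -> (signum x, \_ -> 0)))

-- Turn a function on agents into an arrow by feeding it the identity agent.
alg :: (Agent Double -> Agent Double) -> FwAD Double Double
alg f = let Agent r = f (Agent id) in r

evalWithGrad :: FwAD Double Double -> Double -> (Double, Double)
evalWithGrad (FwAD f) x = let (fx, df) = f x in (fx, df 1)

main :: IO ()
main = print (evalWithGrad (alg (\x -> x ^ 2 + x)) 3) -- prints (12.0,7.0)
```

    Note that the lambda passed to alg is ordinary numeric code; the fork/compose plumbing from the point-free style is hidden inside the Num instance.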
    

    Once you have all those instances complete, plus perhaps a simple helper to extract the results

    evalWithGrad :: FwAD Double Double -> Double -> (Double, Double)
    evalWithGrad (FwAD f) x = case f x of
       (fx, df) -> (fx, df 1)
    

    then you can write code such as

    > evalWithGrad (alg (\x -> x^2 + x)) 3
    (12.0,7.0)
    > evalWithGrad (alg sin) 0
    (0.0,1.0)
    

    Under the hood, these algebraic expressions build up a composition of FwAD arrows, with &&& “splitting” the data flow and *** composing in parallel; so even if the input and final result are plain Doubles, the intermediate results are pulled through suitable tuple types. [That would, I guess, be the answer to your title question: the directed graph is, in a sense, represented as a branched composition chain, in principle the same thing you find in those diagram explanations of Arrows.]