I would like to evaluate a simple computation graph. I was able to write code that does this for a graph where every non-terminal node has exactly two dependencies (and this can be trivially extended to any fixed number of dependencies):
{-# LANGUAGE ExistentialQuantification #-}
module Graph where
-- Have:
data Node a =
    forall u v . CalculationNode { f :: u -> v -> a
                                 , dependencies :: (Node u, Node v) }
  | TerminalNode { value :: a }
eval :: Node a -> a
eval (CalculationNode f (d1, d2)) = f (eval d1) (eval d2)
eval (TerminalNode v) = v
three :: Node Int
three = TerminalNode 3
abcd :: Node String
abcd = TerminalNode "abcd"
seven :: Node Int
seven = CalculationNode (\ s i -> i + length s) (abcd, three)
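For comparison, here is a minimal sketch of the fixed-arity extension mentioned above, assuming a hypothetical second constructor CalculationNode3 for exactly three dependencies (record syntax is dropped for brevity, and fourteen is likewise a made-up example). Every new arity needs its own constructor and its own eval equation, which is exactly what the question wants to avoid.

```haskell
{-# LANGUAGE ExistentialQuantification #-}

-- Sketch only: CalculationNode3 is a hypothetical constructor for
-- exactly three dependencies. Each new arity needs another constructor.
data Node a
  = forall u v . CalculationNode (u -> v -> a) (Node u, Node v)
  | forall u v w . CalculationNode3 (u -> v -> w -> a) (Node u, Node v, Node w)
  | TerminalNode a

-- ... and another eval equation per arity.
eval :: Node a -> a
eval (CalculationNode f (d1, d2))      = f (eval d1) (eval d2)
eval (CalculationNode3 f (d1, d2, d3)) = f (eval d1) (eval d2) (eval d3)
eval (TerminalNode v)                  = v

three :: Node Int
three = TerminalNode 3

abcd :: Node String
abcd = TerminalNode "abcd"

seven :: Node Int
seven = CalculationNode (\s i -> i + length s) (abcd, three)

-- Hypothetical three-dependency node: length "abcd" + 3 + 7 = 14.
fourteen :: Node Int
fourteen = CalculationNode3 (\s i j -> length s + i + j) (abcd, three, seven)
```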
The question is: how do I extend this code so that nodes can take an arbitrary number of dependencies?
Something like:
data Node a =
    forall u_1 u_2 ... u_n . CalculationNode { f :: u_1 -> u_2 -> ... -> u_n -> a
                                             , dependencies :: (Node u_1, Node u_2, ... , Node u_n) }
  | TerminalNode { value :: a }
eval :: Node a -> a
eval = ?
I suspect this requires some type family/HList sorcery, but I don't even know where to begin. Solutions and hints welcome.
Sure, with a bit of 'sorcery' this generalizes quite nicely:
{-# LANGUAGE PolyKinds, ExistentialQuantification, DataKinds, TypeOperators, TypeFamilies, GADTs, RankNTypes #-}
import Data.Functor.Identity
type family (xs :: [*]) :-> (r :: *) :: * where
  '[] :-> r = r
  (x ': xs) :-> r = x -> (xs :-> r)
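This closed type family computes the type of an n-ary function from a list of argument types and a result type: for example, '[Int, Bool] :-> String reduces to Int -> Bool -> String, and '[] :-> r is just r.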
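A small sketch of how :-> reduces, written with Type from Data.Kind instead of * (an equivalent, more current spelling). The names describe, describe' and constant are hypothetical. A value whose type is given through the family is just an ordinary curried function, so no conversion is needed in either direction.

```haskell
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators #-}

import Data.Kind (Type)

-- n-ary function type, as in the answer (Type instead of *)
type family (xs :: [Type]) :-> (r :: Type) :: Type where
  '[] :-> r = r
  (x ': xs) :-> r = x -> (xs :-> r)

-- A plain two-argument function (hypothetical example).
describe :: Int -> Bool -> String
describe n b = show n ++ (if b then "!" else "?")

-- The same function, typed via the family: the signature reduces to
-- Int -> Bool -> String, so it can be defined and applied as usual.
describe' :: '[Int, Bool] :-> String
describe' = describe

-- With no argument types, the "function" is just a value of type r.
constant :: '[] :-> Int
constant = 42
```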
infixr 5 :>
data Prod (f :: k -> *) (xs :: [k]) where
  Nil :: Prod f '[]
  (:>) :: f x -> Prod f xs -> Prod f (x ': xs)
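This datatype is a heterogeneous vector indexed by a list of types: a value of type Prod f '[x1, ..., xn] holds one value of each of the types f x1, ..., f xn. Why it is needed here is less obvious. You need to store a list of type variables in Node somehow, but each of those types must have Node applied to it. This formulation makes that simple: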
data Node a
  = forall vs . CalculationNode (vs :-> a) (Prod Node vs)
  | TerminalNode a
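The Prod type can also be tried out on its own. In this sketch (again using Type instead of *), row, prodLength and headP are hypothetical examples: every element of row has a different type but the same wrapper, Maybe, and because the index records the element types, headP can be total; it simply does not accept a Prod f '[].

```haskell
{-# LANGUAGE DataKinds, GADTs, PolyKinds, TypeOperators #-}

import Data.Kind (Type)

infixr 5 :>

-- A vector of values f x1, f x2, ..., indexed by the type-level list '[x1, x2, ...]
data Prod (f :: k -> Type) (xs :: [k]) where
  Nil  :: Prod f '[]
  (:>) :: f x -> Prod f xs -> Prod f (x ': xs)

-- Three elements of three different types, all wrapped in Maybe.
row :: Prod Maybe '[Int, String, Bool]
row = Just 1 :> Nothing :> Just True :> Nil

-- Hypothetical helper: count the elements, for any index.
prodLength :: Prod f xs -> Int
prodLength Nil       = 0
prodLength (_ :> xs) = 1 + prodLength xs

-- Hypothetical helper: the index guarantees a non-empty vector,
-- so a single clause covers every possible input.
headP :: Prod f (x ': xs) -> f x
headP (y :> _) = y
```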
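Then two helper functions: appFn applies an n-ary function to a Prod of Identity-wrapped arguments, one argument at a time, and mapProd applies a function that is polymorphic in the element type to every element of a Prod: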
appFn :: vs :-> a -> Prod Identity vs -> a
appFn z Nil = z
appFn f (x :> xs) = appFn (f $ runIdentity x) xs
mapProd :: (forall x . f x -> g x) -> Prod f xs -> Prod g xs
mapProd _ Nil = Nil
mapProd f (x :> xs) = f x :> mapProd f xs
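A sketch of the two helpers in isolation, assuming hypothetical values args, result and asMaybes (and Type instead of *). appFn feeds Identity-wrapped arguments to a curried function one at a time, and mapProd swaps the wrapper of every element while leaving the type-level index unchanged.

```haskell
{-# LANGUAGE DataKinds, GADTs, PolyKinds, RankNTypes, TypeFamilies, TypeOperators #-}

import Data.Functor.Identity (Identity (..))
import Data.Kind (Type)

type family (xs :: [Type]) :-> (r :: Type) :: Type where
  '[] :-> r = r
  (x ': xs) :-> r = x -> (xs :-> r)

infixr 5 :>

data Prod (f :: k -> Type) (xs :: [k]) where
  Nil  :: Prod f '[]
  (:>) :: f x -> Prod f xs -> Prod f (x ': xs)

-- Apply the function to the first argument, then recurse on the rest.
appFn :: vs :-> a -> Prod Identity vs -> a
appFn z Nil       = z
appFn f (x :> xs) = appFn (f $ runIdentity x) xs

-- Rank-2 argument: the same function must work for every element type.
mapProd :: (forall x . f x -> g x) -> Prod f xs -> Prod g xs
mapProd _ Nil       = Nil
mapProd f (x :> xs) = f x :> mapProd f xs

-- Hypothetical arguments, already wrapped in Identity.
args :: Prod Identity '[Int, String]
args = Identity 3 :> Identity "abcd" :> Nil

-- (\i s -> i + length s) 3 "abcd" == 7
result :: Int
result = appFn (\i s -> i + length s) args

-- Identity -> Maybe for every element; the index '[Int, String] is kept.
asMaybes :: Prod Maybe '[Int, String]
asMaybes = mapProd (Just . runIdentity) args
```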
And your eval function is almost as simple as before:
eval :: Node a -> a
eval (TerminalNode a) = a
eval (CalculationNode fs as) = appFn fs $ mapProd (Identity . eval) as
The only change to your example is replacing the tuple with Prod constructors:
seven = CalculationNode (\s i -> i + length s) (abcd :> three :> Nil)
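Putting the pieces together, a complete file might look like the sketch below. It reproduces the answer's definitions (with Type instead of * and RankNTypes enabled for mapProd's rank-2 argument) and adds two hypothetical nodes that are not in the answer: fourteen, with three dependencies of different types, and answer, with no dependencies at all, where '[] :-> Int is just Int.

```haskell
{-# LANGUAGE DataKinds, ExistentialQuantification, GADTs, PolyKinds,
             RankNTypes, TypeFamilies, TypeOperators #-}

import Data.Functor.Identity (Identity (..))
import Data.Kind (Type)

type family (xs :: [Type]) :-> (r :: Type) :: Type where
  '[] :-> r = r
  (x ': xs) :-> r = x -> (xs :-> r)

infixr 5 :>

data Prod (f :: k -> Type) (xs :: [k]) where
  Nil  :: Prod f '[]
  (:>) :: f x -> Prod f xs -> Prod f (x ': xs)

-- A node hides the list of its dependency types behind an existential.
data Node a
  = forall vs . CalculationNode (vs :-> a) (Prod Node vs)
  | TerminalNode a

appFn :: vs :-> a -> Prod Identity vs -> a
appFn z Nil       = z
appFn f (x :> xs) = appFn (f $ runIdentity x) xs

mapProd :: (forall x . f x -> g x) -> Prod f xs -> Prod g xs
mapProd _ Nil       = Nil
mapProd f (x :> xs) = f x :> mapProd f xs

-- Evaluate every dependency, then apply the node's function to the results.
eval :: Node a -> a
eval (TerminalNode a)        = a
eval (CalculationNode fs as) = appFn fs $ mapProd (Identity . eval) as

three :: Node Int
three = TerminalNode 3

abcd :: Node String
abcd = TerminalNode "abcd"

seven :: Node Int
seven = CalculationNode (\s i -> i + length s) (abcd :> three :> Nil)

-- Hypothetical: three dependencies of different types.
fourteen :: Node Int
fourteen = CalculationNode (\s i j -> length s + i + j) (abcd :> three :> seven :> Nil)

-- Hypothetical: zero dependencies; the "function" is just a value.
answer :: Node Int
answer = CalculationNode 42 Nil
```

Loaded into GHCi, eval seven gives 7, eval fourteen gives 14 and eval answer gives 42.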