haskell, existential-type, type-families, hlist

Evaluating a strongly typed computation graph with arbitrary number of dependencies per node


I would like to evaluate a simple computation graph. I was able to write code for a graph in which every non-terminal node has exactly two dependencies (this extends trivially to any fixed number of dependencies):

{-# LANGUAGE ExistentialQuantification #-}

module Graph where

-- Have:
data Node a =
    forall u v . CalculationNode { f :: u -> v -> a
                                 , dependencies :: (Node u, Node v) }
  | TerminalNode { value :: a }

eval :: Node a -> a
eval (CalculationNode f (d1, d2)) = f (eval d1) (eval d2)
eval (TerminalNode v) = v

three :: Node Int
three = TerminalNode 3

abcd :: Node String
abcd = TerminalNode "abcd"

seven :: Node Int
seven = CalculationNode (\ s i -> i + length s) (abcd, three)

The question is: how do I extend this code so that nodes can take an arbitrary number of dependencies?

Something like:

data Node a =
    forall u_1 u_2 ... u_n . CalculationNode { f :: u_1 -> u_2 -> ... -> u_n -> a
                                             , dependencies :: (Node u_1, Node u_2, ... , Node u_n) }
  | TerminalNode { value :: a }

eval :: Node a -> a
eval = ?

I suspect this requires some type family/HList sorcery, but I don't even know where to begin. Solutions and hints are welcome.


Solution

  • Sure, with a bit of 'sorcery' this generalizes quite nicely:

    {-# LANGUAGE PolyKinds, ExistentialQuantification, DataKinds, TypeOperators, TypeFamilies, GADTs, RankNTypes #-} 
    
    import Data.Functor.Identity 
    
    type family (xs :: [*]) :-> (r :: *) :: * where 
      '[] :-> r = r 
      (x ': xs) :-> r = x -> (xs :-> r) 
    

    This type family represents n-ary functions: the empty list yields the result type r itself, and each element of the list adds one argument in front of it.
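
To make the reduction concrete, here is a minimal sketch of how the family unfolds for specific lists. The names addLength and constant are hypothetical and used only for illustration; the family is written with Type from Data.Kind instead of *, which should be equivalent on recent GHC versions.

```haskell
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators #-}

import Data.Kind (Type)

-- Same type family as in the answer, written with Type instead of *.
type family (xs :: [Type]) :-> (r :: Type) :: Type where
  '[]       :-> r = r
  (x ': xs) :-> r = x -> (xs :-> r)

-- '[String, Int] :-> Int reduces to String -> Int -> Int,
-- so an ordinary two-argument function can be given this type.
-- (addLength is a hypothetical name.)
addLength :: '[String, Int] :-> Int
addLength s i = i + length s

-- With the empty list there are no arguments left,
-- so '[] :-> Int is just Int. (constant is a hypothetical name.)
constant :: '[] :-> Int
constant = 42
```

Because the family is closed and the lists here are fully known, GHC can reduce both signatures to plain function and value types.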

    infixr 5 :>
    data Prod (f :: k -> *) (xs :: [k]) where 
      Nil :: Prod f '[] 
      (:>) :: f x -> Prod f xs -> Prod f (x ': xs) 
    

    This datatype is a vector indexed by a list of types. This is less obvious. You need to store a list of type variables in Node somehow, but each type variable must have Node applied to it. This formulation makes it simple:

    data Node a 
      = forall vs . CalculationNode (vs :-> a) (Prod Node vs)
      | TerminalNode a  
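
As a rough illustration of what Prod stores, the sketch below builds two values with different wrappers. The names maybes, plain and lengthProd are hypothetical and not part of the answer; Type from Data.Kind stands in for *.

```haskell
{-# LANGUAGE DataKinds, GADTs, KindSignatures, PolyKinds, TypeOperators #-}

import Data.Functor.Identity (Identity (..))
import Data.Kind (Type)

infixr 5 :>
data Prod (f :: k -> Type) (xs :: [k]) where
  Nil  :: Prod f '[]
  (:>) :: f x -> Prod f xs -> Prod f (x ': xs)

-- Every element is wrapped in the same type constructor (Maybe here),
-- while the element types are tracked in the type-level list.
maybes :: Prod Maybe '[Int, String, Bool]
maybes = Just 1 :> Nothing :> Just True :> Nil

-- With Identity as the wrapper, Prod is a plain heterogeneous list.
plain :: Prod Identity '[Int, String]
plain = Identity 3 :> Identity "abcd" :> Nil

-- Hypothetical helper: counts the elements, ignoring their types.
lengthProd :: Prod f xs -> Int
lengthProd Nil       = 0
lengthProd (_ :> xs) = 1 + lengthProd xs
```

Choosing Node as the wrapper, as in Prod Node vs, gives exactly the "list of nodes with differing result types" that CalculationNode needs.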
    

    Then a few helper functions:

    appFn :: vs :-> a -> Prod Identity vs -> a 
    appFn z Nil = z 
    appFn f (x :> xs) = appFn (f $ runIdentity x) xs 
    
    mapProd :: (forall x . f x -> g x) -> Prod f xs -> Prod g xs 
    mapProd _ Nil = Nil 
    mapProd f (x :> xs) = f x :> mapProd f xs 
    

    and your eval function is almost as simple as before:

    eval :: Node a -> a 
    eval (TerminalNode a) = a 
    eval (CalculationNode fs as) = appFn fs $ mapProd (Identity . eval) as
    

    The only thing that changes about your example is replacing tuples with Prod constructors:

    seven = CalculationNode (\s i -> i + length s) (abcd :> three :> Nil)
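
For reference, here is a sketch that gathers the pieces above into one file and adds nodes with one and three dependencies. It assumes a reasonably recent GHC, includes RankNTypes for the rank-2 argument of mapProd, and uses Type from Data.Kind in place of *. The nodes flag, doubled and ternary are hypothetical examples, not part of the original question.

```haskell
{-# LANGUAGE DataKinds, ExistentialQuantification, GADTs, PolyKinds,
             RankNTypes, TypeFamilies, TypeOperators #-}

import Data.Functor.Identity (Identity (..))
import Data.Kind (Type)

-- n-ary function type: '[a, b] :-> r reduces to a -> b -> r.
type family (xs :: [Type]) :-> (r :: Type) :: Type where
  '[]       :-> r = r
  (x ': xs) :-> r = x -> (xs :-> r)

-- Heterogeneous list whose elements all share the wrapper f.
infixr 5 :>
data Prod (f :: k -> Type) (xs :: [k]) where
  Nil  :: Prod f '[]
  (:>) :: f x -> Prod f xs -> Prod f (x ': xs)

data Node a
  = forall vs . CalculationNode (vs :-> a) (Prod Node vs)
  | TerminalNode a

-- Apply an n-ary function to a list of already evaluated arguments.
appFn :: (vs :-> a) -> Prod Identity vs -> a
appFn z Nil       = z
appFn f (x :> xs) = appFn (f (runIdentity x)) xs

-- Change the wrapper of every element, keeping the element types.
mapProd :: (forall x . f x -> g x) -> Prod f xs -> Prod g xs
mapProd _ Nil       = Nil
mapProd f (x :> xs) = f x :> mapProd f xs

eval :: Node a -> a
eval (TerminalNode a)         = a
eval (CalculationNode f deps) = appFn f (mapProd (Identity . eval) deps)

three :: Node Int
three = TerminalNode 3

abcd :: Node String
abcd = TerminalNode "abcd"

-- Two dependencies, as in the question.
seven :: Node Int
seven = CalculationNode (\s i -> i + length s) (abcd :> three :> Nil)

-- Hypothetical extra nodes: one and three dependencies.
flag :: Node Bool
flag = TerminalNode True

doubled :: Node Int
doubled = CalculationNode (* 2) (seven :> Nil)

ternary :: Node String
ternary = CalculationNode
  (\b n s -> if b then replicate n 'x' ++ s else s)
  (flag :> seven :> abcd :> Nil)
```

Note that eval as written re-evaluates a node every time it appears as a dependency (seven is evaluated once for doubled and again for ternary); sharing results would need a separate memoization layer.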