Search code examples
Tags: haskell, gadt

Recursive replace in a GADT


Say I have the following GADT AST:

-- | Binary operations of the expression tree.  The three indices are the
-- left-operand, right-operand and result types (as used by the 'N' node).
data O x y r where
    -- | Both operands and the result share a single type.
    Add :: O t t t
    -- | Operands of arbitrary types; the result is always 'Bool'.
    Eq  :: O l m Bool
    -- ... more operations

-- | An expression tree indexed by the type of value it produces.
data Tree r where
    -- | Operation node: combines a @Tree x@ and a @Tree y@ through an
    -- operation from @x@ and @y@ to @z@.
    N :: O x y z -> Tree x -> Tree y -> Tree z
    -- | Leaf holding a value of type @v@.
    L :: v -> Tree v

Now I want to write a function that replaces every leaf (L) whose value has the same type as a given value x with L x, something like this:

-- Pseudocode from the question; it does not typecheck as written.  Intended
-- behaviour: swap every leaf whose value has the same type as @x@ for @L x@,
-- and rebuild 'N' nodes around the recursively processed subtrees.
-- NOTE(review): even with a 'typeof' in scope, the first equation's @L x@
-- has type @Tree a@ while the signature demands @Tree b@.  A runtime
-- comparison in a guard gives the type checker no evidence that @a ~ b@,
-- which is exactly what the type-witness solution supplies.
f :: a -> Tree b -> Tree b
f x (L a) | typeof x == typeof a = L x
f x (L a) = L a
f x (N o a b) = N o (f x a) (f x b)

Is it possible to write such a function (perhaps using type classes)? Could it be done if I change the GADTs?

I already have a function typeof :: a -> Type, defined as a method of a type class.


Solution

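  • The trick is to use type witnesses (see https://wiki.haskell.org/Type_witness). Each leaf type carries a runtime tag, and pattern-matching on that tag tells the type checker which concrete type it is. The code below needs the GADTs language extension ({-# LANGUAGE GADTs #-}).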

    -- | Binary operations; the indices are the left-operand, right-operand
    -- and result types.
    data O x y r where
        -- | Both operands and the result share a single type.
        Add :: O t t t
        -- | Operands of arbitrary types; the result is always 'Bool'.
        Eq  :: O l m Bool
    
    -- | Render an operation as its constructor name.
    instance Show (O a b c) where
        show op = case op of
            Add -> "Add"
            Eq  -> "Eq"
    
    -- | Expression tree indexed by the type of value it produces.  Operation
    -- nodes capture 'Typeable' evidence for both children and the result,
    -- so functions recursing into a 'T' node obtain the children's
    -- 'Typeable' (and, via the superclass, 'Show') instances from the match.
    data Tree r where
        T :: (Typeable x, Typeable y, Typeable z)
          => O x y z -> Tree x -> Tree y -> Tree z
        L :: v -> Tree v
    
    -- | Fully parenthesised prefix rendering, e.g. @(Add 1 2)@ for an 'Add'
    -- node over two 'Int' leaves; a leaf is shown with its value's 'Show'.
    instance (Show a) => Show (Tree a) where
        show (L leaf)       = show leaf
        show (T op lhs rhs) = "(" ++ unwords [show op, show lhs, show rhs] ++ ")"
    
    -- | Types able to produce a runtime 'Witness' of themselves.  The
    -- argument serves only to fix the type; the instances here never
    -- inspect its value.
    -- NOTE(review): shares its name with base's Data.Typeable.Typeable;
    -- avoid importing that module unqualified next to this class.
    class Show t => Typeable t where
        witness :: t -> Witness t
    
    -- | 'Int' is witnessed by 'IntWitness'; the argument is ignored.
    instance Typeable Int where
        witness = const IntWitness
    
    -- | 'Bool' is witnessed by 'BoolWitness'; the argument is ignored.
    instance Typeable Bool where
        witness = const BoolWitness
    
    -- | Runtime representation of a type.  Pattern-matching on a constructor
    -- refines the index @t@ to the concrete type that constructor names.
    data Witness t where
        IntWitness  :: Witness Int
        BoolWitness :: Witness Bool
    
    -- | Cast a value of type @a@ to type @b@ when both witnesses name the same
    -- type; return 'Nothing' otherwise.  Matching the two witnesses refines
    -- @a@ and @b@ to the same concrete type, which is what allows the
    -- successful branch to return the value unchanged.  The value itself is
    -- never forced.
    --
    -- Every pairing is written out (no trailing wildcard), so adding a
    -- constructor to 'Witness' produces an incomplete-pattern warning here
    -- rather than silently making casts to the new type always fail.
    dynamicCast :: Witness a -> Witness b -> a -> Maybe b
    dynamicCast IntWitness target v = case target of
        IntWitness  -> Just v
        BoolWitness -> Nothing
    dynamicCast BoolWitness target v = case target of
        BoolWitness -> Just v
        IntWitness  -> Nothing
    
    -- | Return the first value in place of the second when both have the same
    -- type (as reported by their witnesses); otherwise return the second
    -- value unchanged.
    replace :: (Typeable a, Typeable b) => a -> b -> b
    replace new old = maybe old id (dynamicCast (witness new) (witness old) new)
    
    -- | Replace the value of every leaf whose type matches @x@'s with @x@.
    -- Leaves of other types keep their value, and operation nodes are
    -- rebuilt around the processed subtrees.  The 'Typeable' evidence for
    -- the subtrees comes from the 'T' constructor itself.
    f :: (Typeable a, Typeable b) => b -> Tree a -> Tree a
    f x tree = case tree of
        L leaf          -> L (replace x leaf)
        T op left right -> T op (f x left) (f x right)