Tags: haskell, gadt

Equality for GADTs which erase type parameter


I cannot implement an Eq instance for the following type-safe expression DSL, implemented with GADTs:

data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

Expressions can be either of type Bool or Int. There are constructors for literals Bool and Num which have the corresponding types. Only Int expressions can be added up (constructor Plus). The condition in the If expression should have type Bool while both branches should have the same type. There is also an equality expression Equal whose operands should have the same type, and the type of the equality expression is Bool.

I have no problems implementing the interpreter eval for this DSL. It compiles and works like a charm:

eval :: Expr a -> a
eval (Num x) = x
eval (Bool x) = x
eval (Plus x y) = eval x + eval y
eval (If c t e) = if eval c then eval t else eval e
eval (Equal x y) = eval x == eval y

However, I struggle to implement an instance of Eq for the DSL. I tried the simple syntactic equality:

instance Eq a => Eq (Expr a) where
  Num x == Num y = x == y
  Bool x == Bool y = x == y
  Plus x y == Plus x' y' = x == x' && y == y'
  If c t e == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = x == x' && y == y'
  _ == _ = False

It does not typecheck with GHC 8.6.5; the error is the following:

[1 of 1] Compiling Main             ( Main.hs, Main.o )

Main.hs:17:35: error:
    • Could not deduce: a2 ~ a1
      from the context: (a ~ Bool, Eq a1)
        bound by a pattern with constructor:
                   Equal :: forall a. Eq a => Expr a -> Expr a -> Expr Bool,
                 in an equation for ‘==’
        at Main.hs:17:3-11
      ‘a2’ is a rigid type variable bound by
        a pattern with constructor:
          Equal :: forall a. Eq a => Expr a -> Expr a -> Expr Bool,
        in an equation for ‘==’
        at Main.hs:17:16-26
      ‘a1’ is a rigid type variable bound by
        a pattern with constructor:
          Equal :: forall a. Eq a => Expr a -> Expr a -> Expr Bool,
        in an equation for ‘==’
        at Main.hs:17:3-11
      Expected type: Expr a1
        Actual type: Expr a2
    • In the second argument of ‘(==)’, namely ‘x'’
      In the first argument of ‘(&&)’, namely ‘x == x'’
      In the expression: x == x' && y == y'
    • Relevant bindings include
        y' :: Expr a2 (bound at Main.hs:17:25)
        x' :: Expr a2 (bound at Main.hs:17:22)
        y :: Expr a1 (bound at Main.hs:17:11)
        x :: Expr a1 (bound at Main.hs:17:9)
   |
17 |   Equal x y == Equal x' y' = x == x' && y == y'
   |  

I believe the reason is that the constructor Equal "forgets" the value of the type parameter a of its subexpressions and there is no way for the typechecker to ensure subexpressions x and y both have the same type Expr a.

I tried adding one more type parameter to Expr a to keep track of the type of subexpressions:

data Expr a b where
  Num :: Int -> Expr Int b
  Bool :: Bool -> Expr Bool b
  Plus :: Expr Int b -> Expr Int b -> Expr Int b
  If :: Expr Bool b -> Expr a b -> Expr a b -> Expr a b
  Equal :: Eq a => Expr a a -> Expr a a -> Expr Bool a

instance Eq a => Eq (Expr a b) where
  -- same implementation

eval :: Expr a b -> a
  -- same implementation

This approach does not seem scalable to me: every subexpression must share the same second parameter b, which becomes a problem once more constructors with subexpressions of different types are added.
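To make the scalability concern concrete, here is a sketch of the two-parameter version with the elided "same implementation" parts filled in, assuming they are copied unchanged from the one-parameter code. It typechecks and handles the simple example, but an If whose condition compares Ints while one branch compares Bools is rejected, because both subexpressions must agree on b. The file is meant to be loaded into GHCi:

```haskell
{-# LANGUAGE GADTs #-}

-- The second index b records the operand type of the Equal nodes
-- inside an expression; every subexpression must agree on it.
data Expr a b where
  Num   :: Int -> Expr Int b
  Bool  :: Bool -> Expr Bool b
  Plus  :: Expr Int b -> Expr Int b -> Expr Int b
  If    :: Expr Bool b -> Expr a b -> Expr a b -> Expr a b
  Equal :: Eq a => Expr a a -> Expr a a -> Expr Bool a

-- The same syntactic equality as before; here x and x' in the Equal
-- clause both have type Expr b b, so the comparison typechecks.
instance Eq a => Eq (Expr a b) where
  Num x     == Num y       = x == y
  Bool x    == Bool y      = x == y
  Plus x y  == Plus x' y'  = x == x' && y == y'
  If c t e  == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = x == x' && y == y'
  _         == _           = False

eval :: Expr a b -> a
eval (Num x)     = x
eval (Bool x)    = x
eval (Plus x y)  = eval x + eval y
eval (If c t e)  = if eval c then eval t else eval e
eval (Equal x y) = eval x == eval y

-- Accepted: the only Equal compares Ints, so b ~ Int throughout.
expr1 :: Expr Bool Int
expr1 = If (Equal (Num 13) (Num 42)) (Bool True) (Bool False)

-- Rejected if uncommented: the condition forces b ~ Int, while the
-- then-branch compares Bools and forces b ~ Bool.
-- bad = If (Equal (Num 1) (Num 2)) (Equal (Bool True) (Bool True)) (Bool False)
```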

All this makes me think that I am using GADTs incorrectly to implement this kind of DSL. Is there a way to implement Eq for this type? If not, what is the idiomatic way to express this kind of type constraint on the expressions?

Complete code:

{-# LANGUAGE GADTs #-}
 
module Main where

data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

instance Eq a => Eq (Expr a) where
  Num x == Num y = x == y
  Bool x == Bool y = x == y
  Plus x y == Plus x' y' = x == x' && y == y'
  If c t e == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = x == x' && y == y'
  _ == _ = False

eval :: Expr a -> a
eval (Num x) = x
eval (Bool x) = x
eval (Plus x y) = eval x + eval y
eval (If c t e) = if eval c then eval t else eval e
eval (Equal x y) = eval x == eval y

main :: IO ()
main = do
  let expr1 = If (Equal (Num 13) (Num 42)) (Bool True) (Bool False)
  let expr2 = If (Equal (Num 13) (Num 42)) (Num 42) (Num 777)
  print (eval expr1)
  print (eval expr2)
  print (expr1 == expr1)

Solution

  • Your issue is that in

    Equal x y == Equal x' y' = ...
    

    it is possible that x and x' have different types. For example, Equal (Bool True) (Bool True) == Equal (Num 42) (Num 42) type checks, but we can't then simply compare Bool True == Num 42 as we might try to do in the Eq instance.
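    A small sketch makes this concrete: both values below have type Expr Bool, so they can even share a list, yet one hides Bool operands and the other Int operands. Any Eq (Expr Bool) instance must therefore be prepared to compare them. The names equalBools, equalInts and mixed are illustrative only:

```haskell
{-# LANGUAGE GADTs #-}

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

eval :: Expr a -> a
eval (Num x)     = x
eval (Bool x)    = x
eval (Plus x y)  = eval x + eval y
eval (If c t e)  = if eval c then eval t else eval e
eval (Equal x y) = eval x == eval y

-- The operands have type Expr Bool here...
equalBools :: Expr Bool
equalBools = Equal (Bool True) (Bool True)

-- ...and Expr Int here, yet the result is Expr Bool in both cases,
-- because Equal's type parameter does not appear in its result type.
equalInts :: Expr Bool
equalInts = Equal (Num 42) (Num 42)

-- Same static type, so both fit in one list. A clause such as
-- Equal x y == Equal x' y' could receive x :: Expr Bool, x' :: Expr Int.
mixed :: [Expr Bool]
mixed = [equalBools, equalInts]
```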

    Here are a few alternative solutions. The last one (generalizing == to eqExpr) seems the simplest to me, but the others are interesting as well.

    Use a singleton and compute types

    We start from your original type

    {-# LANGUAGE GADTs #-}
    module Main where
    
    data Expr a where
      Num :: Int -> Expr Int
      Bool :: Bool -> Expr Bool
      Plus :: Expr Int -> Expr Int -> Expr Int
      If :: Expr Bool -> Expr a -> Expr a -> Expr a
      Equal :: Eq a => Expr a -> Expr a -> Expr Bool
    

    and define a singleton GADT to represent the types you have

    data Ty a where
      TyInt  :: Ty Int
      TyBool :: Ty Bool
    

    We then show that the type index can only be Int or Bool, and how to compute it from an expression:

    tyExpr :: Expr a -> Ty a
    tyExpr (Num _)     = TyInt
    tyExpr (Bool _)    = TyBool
    tyExpr (Plus _ _)  = TyInt
    tyExpr (If _ t _)  = tyExpr t
    tyExpr (Equal _ _) = TyBool
    

    We can now exploit that and define the Eq instance.

    instance Eq (Expr a) where
      Num x     == Num y       = x == y
      Bool x    == Bool y      = x == y
      Plus x y  == Plus x' y'  = x == x' && y == y'
      If c t e  == If c' t' e' = c == c' && t == t' && e == e'
      Equal x y == Equal x' y' = case (tyExpr x, tyExpr x') of
         (TyInt,  TyInt ) -> x == x' && y == y'
         (TyBool, TyBool) -> x == x' && y == y'
         _                -> False
      _ == _ = False
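    One possible refinement, not part of the original answer: instead of listing every matching pair of singletons, compare them once through the TestEquality class from base's Data.Type.Equality module. A successful testEquality returns a Refl proof that the two indices are equal, which is exactly what x == x' needs. A sketch, with Expr, Ty and tyExpr repeated so it loads on its own:

```haskell
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}

import Data.Type.Equality (TestEquality (..), (:~:) (..))

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

data Ty a where
  TyInt  :: Ty Int
  TyBool :: Ty Bool

tyExpr :: Expr a -> Ty a
tyExpr (Num _)     = TyInt
tyExpr (Bool _)    = TyBool
tyExpr (Plus _ _)  = TyInt
tyExpr (If _ t _)  = tyExpr t
tyExpr (Equal _ _) = TyBool

-- Two singletons are equal exactly when they name the same type;
-- the Refl proof then tells the type checker that a ~ b.
instance TestEquality Ty where
  testEquality TyInt  TyInt  = Just Refl
  testEquality TyBool TyBool = Just Refl
  testEquality _      _      = Nothing

instance Eq (Expr a) where
  Num x     == Num y       = x == y
  Bool x    == Bool y      = x == y
  Plus x y  == Plus x' y'  = x == x' && y == y'
  If c t e  == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = case testEquality (tyExpr x) (tyExpr x') of
    Just Refl -> x == x' && y == y'  -- operand types proven equal
    Nothing   -> False               -- e.g. Bool operands vs Int operands
  _ == _ = False
```

    The per-type enumeration moves into the TestEquality instance, so the Eq instance itself no longer mentions individual types.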
    

    Use Typeable

    We slightly modify the original GADT:

    import Data.Typeable
      
    data Expr a where
      Num :: Int -> Expr Int
      Bool :: Bool -> Expr Bool
      Plus :: Expr Int -> Expr Int -> Expr Int
      If :: Expr Bool -> Expr a -> Expr a -> Expr a
      Equal :: (Typeable a, Eq a) => Expr a -> Expr a -> Expr Bool
    

    We can then try to cast the operands to the right types: if the cast fails, the two Equal nodes compare operands of distinct types, so we can return False.

    instance Eq (Expr a) where
      Num x     == Num y       = x == y
      Bool x    == Bool y      = x == y
      Plus x y  == Plus x' y'  = x == x' && y == y'
      If c t e  == If c' t' e' = c == c' && t == t' && e == e'
      Equal x y == Equal x' y' = case cast (x,y) of
         Just (x2, y2) -> x2 == x' && y2 == y'
         Nothing       -> False
      _ == _ = False
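    A variation on the same idea, assuming the Typeable constraint on Equal shown above: Data.Typeable also offers eqT, which returns a proof a :~: b rather than a converted value, so the original x and x' can be compared directly once their types are known to match. The helper sameIndex is a hypothetical name introduced for this sketch:

```haskell
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}

import Data.Type.Equality ((:~:) (..))
import Data.Typeable (Typeable, eqT)

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: (Typeable a, Eq a) => Expr a -> Expr a -> Expr Bool

-- Hypothetical helper: ask Typeable whether two expressions have the
-- same index type. The arguments only serve to fix a and b.
sameIndex :: (Typeable a, Typeable b) => Expr a -> Expr b -> Maybe (a :~: b)
sameIndex _ _ = eqT

instance Eq (Expr a) where
  Num x     == Num y       = x == y
  Bool x    == Bool y      = x == y
  Plus x y  == Plus x' y'  = x == x' && y == y'
  If c t e  == If c' t' e' = c == c' && t == t' && e == e'
  -- Matching on Refl brings the equality of the two operand types
  -- into scope, so x and x' can be compared without a cast.
  Equal x y == Equal x' y' = case sameIndex x x' of
    Just Refl -> x == x' && y == y'
    Nothing   -> False
  _ == _ = False
```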
    

    Generalize to heterogeneous equality

    We can use the original GADT:

    data Expr a where
      Num :: Int -> Expr Int
      Bool :: Bool -> Expr Bool
      Plus :: Expr Int -> Expr Int -> Expr Int
      If :: Expr Bool -> Expr a -> Expr a -> Expr a
      Equal :: Eq a => Expr a -> Expr a -> Expr Bool
    

    and write a heterogeneous equality test that works even if the two expressions do not have the same type:

    eqExpr :: Expr a -> Expr b -> Bool
    eqExpr (Num x)     (Num y)       = x == y
    eqExpr (Bool x)    (Bool y)      = x == y
    eqExpr (Plus x y)  (Plus x' y')  = eqExpr x x' && eqExpr y y'
    eqExpr (If c t e)  (If c' t' e') = eqExpr c c' && eqExpr t t' && eqExpr e e'
    eqExpr (Equal x y) (Equal x' y') = eqExpr x x' && eqExpr y y'
    eqExpr _           _             = False
    

    The Eq instance is then a special case.

    instance Eq (Expr a) where
      (==) = eqExpr
    

    A final note

    As pointed out by Joseph Sible in the comments, in all these approaches we do not need the Eq a context in the instances. We can simply remove it:

    instance {- Eq a => -} Eq (Expr a) where
       ...
    

    Further, in principle we do not even need the Eq a constraint in the definition of Equal, so we could simplify our GADT:

    data Expr a where
      Num :: Int -> Expr Int
      Bool :: Bool -> Expr Bool
      Plus :: Expr Int -> Expr Int -> Expr Int
      If :: Expr Bool -> Expr a -> Expr a -> Expr a
      Equal :: Expr a -> Expr a -> Expr Bool
    

    However, if we do that, the definition of eval :: Expr a -> a becomes more complex in the Equal case, where we would probably need something like tyExpr to recover the operand type so that we can use ==.
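    For illustration, here is one way that eval could look, assuming the simplified Equal without Eq a and reusing the tyExpr singleton from the first approach. The hypothetical helper eqAt pattern matches on the singleton, which refines the operand type to Int or Bool, where the standard == is available:

```haskell
{-# LANGUAGE GADTs #-}

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Expr a -> Expr a -> Expr Bool   -- no Eq a constraint

data Ty a where
  TyInt  :: Ty Int
  TyBool :: Ty Bool

tyExpr :: Expr a -> Ty a
tyExpr (Num _)     = TyInt
tyExpr (Bool _)    = TyBool
tyExpr (Plus _ _)  = TyInt
tyExpr (If _ t _)  = tyExpr t
tyExpr (Equal _ _) = TyBool

-- Hypothetical helper: matching on the singleton refines a to Int or
-- Bool, and at those types the ordinary (==) can be used.
eqAt :: Ty a -> a -> a -> Bool
eqAt TyInt  = (==)
eqAt TyBool = (==)

eval :: Expr a -> a
eval (Num x)     = x
eval (Bool x)    = x
eval (Plus x y)  = eval x + eval y
eval (If c t e)  = if eval c then eval t else eval e
eval (Equal x y) = eqAt (tyExpr x) (eval x) (eval y)
```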