I cannot implement an instance of Eq for the following type-safe DSL for expressions, implemented with GADTs:
data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool
Expressions can be either of type Bool or Int. There are constructors for the literals, Bool and Num, which have the corresponding types. Only Int expressions can be added up (constructor Plus). The condition in the If expression should have type Bool, while both branches should have the same type. There is also an equality expression Equal whose operands should have the same type; the type of the equality expression is Bool.
I have no problems implementing the interpreter eval for this DSL. It compiles and works like a charm:
eval :: Expr a -> a
eval (Num x) = x
eval (Bool x) = x
eval (Plus x y) = eval x + eval y
eval (If c t e) = if eval c then eval t else eval e
eval (Equal x y) = eval x == eval y
However, I struggle to implement an instance of Eq for the DSL. I tried simple syntactic equality:
instance Eq a => Eq (Expr a) where
  Num x == Num y = x == y
  Bool x == Bool y = x == y
  Plus x y == Plus x' y' = x == x' && y == y'
  If c t e == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = x == x' && y == y'
  _ == _ = False
It does not typecheck (with GHC 8.6.5); the error is the following:
[1 of 1] Compiling Main ( Main.hs, Main.o )

Main.hs:17:35: error:
    • Could not deduce: a2 ~ a1
      from the context: (a ~ Bool, Eq a1)
        bound by a pattern with constructor:
                   Equal :: forall a. Eq a => Expr a -> Expr a -> Expr Bool,
                 in an equation for ‘==’
        at Main.hs:17:3-11
      ‘a2’ is a rigid type variable bound by
        a pattern with constructor:
          Equal :: forall a. Eq a => Expr a -> Expr a -> Expr Bool,
        in an equation for ‘==’
        at Main.hs:17:16-26
      ‘a1’ is a rigid type variable bound by
        a pattern with constructor:
          Equal :: forall a. Eq a => Expr a -> Expr a -> Expr Bool,
        in an equation for ‘==’
        at Main.hs:17:3-11
      Expected type: Expr a1
        Actual type: Expr a2
    • In the second argument of ‘(==)’, namely ‘x'’
      In the first argument of ‘(&&)’, namely ‘x == x'’
      In the expression: x == x' && y == y'
    • Relevant bindings include
        y' :: Expr a2 (bound at Main.hs:17:25)
        x' :: Expr a2 (bound at Main.hs:17:22)
        y :: Expr a1 (bound at Main.hs:17:11)
        x :: Expr a1 (bound at Main.hs:17:9)
17 | Equal x y == Equal x' y' = x == x' && y == y'
I believe the reason is that the constructor Equal "forgets" the value of the type parameter a of its subexpressions, and there is no way for the typechecker to ensure that subexpressions x and y both have the same type Expr a.
I tried adding one more type parameter to Expr a to keep track of the type of subexpressions:
data Expr a b where
  Num :: Int -> Expr Int b
  Bool :: Bool -> Expr Bool b
  Plus :: Expr Int b -> Expr Int b -> Expr Int b
  If :: Expr Bool b -> Expr a b -> Expr a b -> Expr a b
  Equal :: Eq a => Expr a a -> Expr a a -> Expr Bool a

instance Eq a => Eq (Expr a b) where
  -- same implementation

eval :: Expr a b -> a
-- same implementation
This approach does not seem scalable to me once more constructors with subexpressions of different types are added.
All this makes me think that I am using GADTs incorrectly to implement this kind of DSL. Is there a way to implement Eq for this type? If not, what is the idiomatic way to express this kind of type constraint on the expressions?
Complete code:
module Main where
data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool
instance Eq a => Eq (Expr a) where
  Num x == Num y = x == y
  Bool x == Bool y = x == y
  Plus x y == Plus x' y' = x == x' && y == y'
  If c t e == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = x == x' && y == y'
  _ == _ = False
eval :: Expr a -> a
eval (Num x) = x
eval (Bool x) = x
eval (Plus x y) = eval x + eval y
eval (If c t e) = if eval c then eval t else eval e
eval (Equal x y) = eval x == eval y
main :: IO ()
main = do
  let expr1 = If (Equal (Num 13) (Num 42)) (Bool True) (Bool False)
  let expr2 = If (Equal (Num 13) (Num 42)) (Num 42) (Num 777)
  print (eval expr1)
  print (eval expr2)
  print (expr1 == expr1)
Your issue is that in
Equal x y == Equal x' y' = ...
it is possible that x and x' have different types. For example, Equal (Bool True) (Bool True) == Equal (Num 42) (Num 42) type checks, but we cannot then simply compare Bool True == Num 42, as we might try to do in the Eq instance.
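To make the point concrete, here is a minimal sketch built on the GADT from the question. The names overBools, overInts and sameOuterType are made up for illustration. Both values have the same outer type, Expr Bool, even though the operand type hidden inside Equal differs.

```haskell
{-# LANGUAGE GADTs #-}

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

-- The operands hidden inside this Equal have type Expr Bool.
overBools :: Expr Bool
overBools = Equal (Bool True) (Bool True)

-- The operands hidden inside this Equal have type Expr Int.
overInts :: Expr Bool
overInts = Equal (Num 42) (Num 42)

-- Both share the outer type, so they can sit in one list, and (==) at
-- type Expr Bool can be asked to compare them.
sameOuterType :: [Expr Bool]
sameOuterType = [overBools, overInts]
```

Once the two Equal patterns are matched, the compiler only knows that each side has some operand type with an Eq instance, not that the two operand types coincide.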
Here are a few alternative solutions. The last one (generalizing == to eqExpr) seems the simplest to me, but the others are interesting as well.
We start from your original type
module Main where
data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool
and define a singleton GADT to represent the types you have
data Ty a where
  TyInt :: Ty Int
  TyBool :: Ty Bool
We then show that the type of an expression can only be Int or Bool, and how to compute it from the expression.
tyExpr :: Expr a -> Ty a
tyExpr (Num _) = TyInt
tyExpr (Bool _) = TyBool
tyExpr (Plus _ _) = TyInt
tyExpr (If _ t _) = tyExpr t
tyExpr (Equal _ _) = TyBool
We can now exploit that and define the Eq instance:
instance Eq (Expr a) where
  Num x == Num y = x == y
  Bool x == Bool y = x == y
  Plus x y == Plus x' y' = x == x' && y == y'
  If c t e == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = case (tyExpr x, tyExpr x') of
    (TyInt, TyInt ) -> x == x' && y == y'
    (TyBool, TyBool) -> x == x' && y == y'
    _ -> False
  _ == _ = False
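As a possible variation on the singleton approach, the pairwise case analysis can be moved into a helper that returns a type-equality proof, so the Eq instance keeps a single branch as more types are added. This is only a sketch: eqTy is a hypothetical name, not something from the question, and the proof type (:~:) comes from Data.Type.Equality in base.

```haskell
{-# LANGUAGE GADTs, TypeOperators #-}

import Data.Type.Equality ((:~:)(Refl))

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

data Ty a where
  TyInt  :: Ty Int
  TyBool :: Ty Bool

tyExpr :: Expr a -> Ty a
tyExpr (Num _)     = TyInt
tyExpr (Bool _)    = TyBool
tyExpr (Plus _ _)  = TyInt
tyExpr (If _ t _)  = tyExpr t
tyExpr (Equal _ _) = TyBool

-- Hypothetical helper: compare two type representations and, when they
-- match, return a proof that the represented types are equal.
eqTy :: Ty a -> Ty b -> Maybe (a :~: b)
eqTy TyInt  TyInt  = Just Refl
eqTy TyBool TyBool = Just Refl
eqTy _      _      = Nothing

instance Eq (Expr a) where
  Num x    == Num y       = x == y
  Bool x   == Bool y      = x == y
  Plus x y == Plus x' y'  = x == x' && y == y'
  If c t e == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = case eqTy (tyExpr x) (tyExpr x') of
    Just Refl -> x == x' && y == y'  -- Refl brings the type equality into scope
    Nothing   -> False
  _ == _ = False
```

With this shape, supporting a new base type means adding one constructor to Ty and one equation to eqTy, while the instance itself stays unchanged.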
We slightly modify the original GADT:
import Data.Typeable
data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: (Typeable a, Eq a) => Expr a -> Expr a -> Expr Bool
We can then try to cast the values to the right types: if the cast fails, we had two Equals among distinct types, so we can return False.
instance Eq (Expr a) where
  Num x == Num y = x == y
  Bool x == Bool y = x == y
  Plus x y == Plus x' y' = x == x' && y == y'
  If c t e == If c' t' e' = c == c' && t == t' && e == e'
  Equal x y == Equal x' y' = case cast (x,y) of
    Just (x2, y2) -> x2 == x' && y2 == y'
    Nothing -> False
  _ == _ = False
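For readers unfamiliar with cast, the following small sketch shows its behaviour in isolation. The value names are made up for illustration. cast comes from Data.Typeable and returns Just the same value when the source and target types coincide at run time, and Nothing otherwise.

```haskell
import Data.Typeable (cast)

-- Source and target types agree, so the value comes back unchanged.
intAsInt :: Maybe Int
intAsInt = cast (42 :: Int)

-- Bool is not Int, so the cast fails.
boolAsInt :: Maybe Int
boolAsInt = cast True

-- cast also works on compound types such as pairs, which is how the
-- instance above converts both operands of Equal in one step.
pairAsPair :: Maybe (Int, Int)
pairAsPair = cast (1 :: Int, 2 :: Int)

mismatchedPair :: Maybe (Int, Int)
mismatchedPair = cast (True, False)
```

Here intAsInt and pairAsPair are Just values, while boolAsInt and mismatchedPair are Nothing.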
We can use the original GADT:
data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool
and write a heterogeneous equality test that works even if the two expressions do not have the same type:
eqExpr :: Expr a -> Expr b -> Bool
eqExpr (Num x) (Num y) = x == y
eqExpr (Bool x) (Bool y) = x == y
eqExpr (Plus x y) (Plus x' y') = eqExpr x x' && eqExpr y y'
eqExpr (If c t e) (If c' t' e') = eqExpr c c' && eqExpr t t' && eqExpr e e'
eqExpr (Equal x y) (Equal x' y') = eqExpr x x' && eqExpr y y'
eqExpr _ _ = False
The Eq instance is then a special case.
instance Eq (Expr a) where
  (==) = eqExpr
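A short usage sketch of the heterogeneous test, reusing the two sample expressions from the question. The names intBranches, boolBranches, sameExpr and crossType are made up for illustration.

```haskell
{-# LANGUAGE GADTs #-}

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Eq a => Expr a -> Expr a -> Expr Bool

-- The explicit signature is required: the recursive calls are made at
-- types other than the outer a and b (polymorphic recursion).
eqExpr :: Expr a -> Expr b -> Bool
eqExpr (Num x)     (Num y)        = x == y
eqExpr (Bool x)    (Bool y)       = x == y
eqExpr (Plus x y)  (Plus x' y')   = eqExpr x x' && eqExpr y y'
eqExpr (If c t e)  (If c' t' e')  = eqExpr c c' && eqExpr t t' && eqExpr e e'
eqExpr (Equal x y) (Equal x' y')  = eqExpr x x' && eqExpr y y'
eqExpr _           _              = False

instance Eq (Expr a) where
  (==) = eqExpr

intBranches :: Expr Int
intBranches = If (Equal (Num 13) (Num 42)) (Num 42) (Num 777)

boolBranches :: Expr Bool
boolBranches = If (Equal (Num 13) (Num 42)) (Bool True) (Bool False)

-- Same type on both sides, so (==) applies directly.
sameExpr :: Bool
sameExpr = intBranches == intBranches

-- Different types: (==) would be rejected by the type checker, but
-- eqExpr accepts the pair and finds that the branches differ.
crossType :: Bool
crossType = eqExpr intBranches boolBranches
```

Here sameExpr is True and crossType is False, because the conditions match but Num 42 and Bool True are different constructors.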
As pointed out by Joseph Sible in the comments, in all these approaches we do not need the Eq a context in the instances. We can simply remove it:
instance {- Eq a => -} Eq (Expr a) where
Further, in principle we do not even really need the Eq a in the definition of Equal, so we could simplify our GADT:
data Expr a where
  Num :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Expr a -> Expr a -> Expr Bool
However, if we do that, the definition of eval :: Expr a -> a becomes more complex in the Equal case, where we probably need to use something like tyExpr to infer the type, so that we can use ==.
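One way this could look, as a sketch only: it assumes the singleton Ty and the function tyExpr from the first approach, adapted to the GADT without the Eq a constraint. Matching on the result of tyExpr tells the type checker which concrete type the operands have, and therefore which Eq instance to use.

```haskell
{-# LANGUAGE GADTs #-}

data Expr a where
  Num   :: Int -> Expr Int
  Bool  :: Bool -> Expr Bool
  Plus  :: Expr Int -> Expr Int -> Expr Int
  If    :: Expr Bool -> Expr a -> Expr a -> Expr a
  Equal :: Expr a -> Expr a -> Expr Bool  -- no Eq a constraint here

data Ty a where
  TyInt  :: Ty Int
  TyBool :: Ty Bool

tyExpr :: Expr a -> Ty a
tyExpr (Num _)     = TyInt
tyExpr (Bool _)    = TyBool
tyExpr (Plus _ _)  = TyInt
tyExpr (If _ t _)  = tyExpr t
tyExpr (Equal _ _) = TyBool

eval :: Expr a -> a
eval (Num x)     = x
eval (Bool x)    = x
eval (Plus x y)  = eval x + eval y
eval (If c t e)  = if eval c then eval t else eval e
eval (Equal x y) = case tyExpr x of
  TyInt  -> eval x == eval y  -- operands are Int here, so Eq Int is used
  TyBool -> eval x == eval y  -- operands are Bool here, so Eq Bool is used
```

The two branches look identical, but each is type checked under a different refinement of the operand type, which is why both are needed.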