Search code examples
Tags: haskell, polymorphism, typeclass, algebraic-data-types, gadt

Polymorphism scenario in Haskell


I have written the following Haskell program to interpret basic math. I would like to add comparison and boolean operators in addition to mathematical operators. My question is how I should go about replacing the occurrences of Int with something that can handle either Int or Bool.

I considered expanding the Token type to have three kinds of operators, which would differ only in the type of the function (`Int -> Int -> Int`, `Int -> Int -> Bool`, and `Bool -> Bool -> Bool`), but this seems like it would result in quite a bit of duplication, both in the type declaration and in the pattern matching. Is there a way to do this with a type class?

-- | Binding strength of an operator; a larger number binds tighter.
type Precedence = Int

-- | How a chain of operators of equal precedence groups.
data Associativity
  = AssocL  -- ^ left-associative: @a - b - c@ groups as @(a - b) - c@
  | AssocR  -- ^ right-associative: @a ^ b ^ c@ groups as @a ^ (b ^ c)@

-- | One lexical token of an integer expression.
data Token
  = Operand Int
    -- ^ an integer literal
  | Operator String (Int -> Int -> Int) Associativity Precedence
    -- ^ a binary operator: its symbol, the function it denotes,
    -- its associativity and its precedence
  | ParenL
    -- ^ an opening parenthesis
  | ParenR
    -- ^ a closing parenthesis

-- | Constructor-wise equality. Operators are equal when their symbols
-- match; the stored functions are ignored because functions cannot be
-- compared.
instance Eq Token where
  lhs == rhs = case (lhs, rhs) of
    (Operator a _ _ _, Operator b _ _ _) -> a == b
    (Operand m,        Operand n)        -> m == n
    (ParenL,           ParenL)           -> True
    (ParenR,           ParenR)           -> True
    _                                    -> False

-- | Evaluate an infix integer expression whose tokens are separated by
-- whitespace: tokenize, reorder to postfix, then evaluate the postfix form.
evalMath :: String -> Int
evalMath input = rpn (shuntingYard (tokenize input))

-- | Split the input on whitespace and classify each word as an operator,
-- a parenthesis, or an integer operand. Any word that is not one of the
-- recognised symbols is parsed with 'read' as an 'Int'.
tokenize :: String -> [Token]
tokenize input = [classify w | w <- words input]
  where
    classify w = case w of
      "+" -> Operator w (+) AssocL 2
      "-" -> Operator w (-) AssocL 2
      "*" -> Operator w (*) AssocL 3
      "/" -> Operator w div AssocL 3
      "^" -> Operator w (^) AssocR 4
      "(" -> ParenL
      ")" -> ParenR
      _   -> Operand (read w)

-- | Reorder an infix token stream into postfix (reverse Polish) order
-- using the shunting-yard algorithm.
--
-- The fold carries @(out, stack)@: @out@ is the output built in reverse
-- (newest token first) and @stack@ holds pending operators and opening
-- parentheses (top first).
shuntingYard :: [Token] -> [Token]
shuntingYard toks = emit (foldl step ([], []) toks)
  where
    -- Restore output order, then drain whatever is left on the stack.
    emit (out, stack) = reverse out ++ stack

    step (out, stack) tok = case tok of
      Operand _   -> (tok : out, stack)
      Operator {} ->
        -- Move every stacked operator that must be applied before the
        -- incoming one to the output, then push the incoming operator.
        let (popped, kept) = span (yieldsTo tok) stack
        in  (reverse popped ++ out, tok : kept)
      ParenL      -> (out, ParenL : stack)
      ParenR      ->
        -- Flush operators down to the matching '(' and discard it.
        let (inside, rest) = break (== ParenL) stack
        in  (reverse inside ++ out, tail rest)

    -- Does the stacked operator (second argument) have to be output before
    -- the incoming operator (first argument) is pushed?
    yieldsTo (Operator _ _ AssocL p) (Operator _ _ _ q) = p <= q
    yieldsTo (Operator _ _ AssocR p) (Operator _ _ _ q) = p < q
    yieldsTo (Operator _ _ _ _)      ParenL             = False

-- | Evaluate a token list given in postfix (reverse Polish) order.
--
-- Values live on a stack whose head is the top. An operator pops the top
-- two values: the top one is its /right/ operand and the one beneath it is
-- its /left/ operand. For example, the postfix sequence @5 3 -@ means
-- @5 - 3@; applying the function as @f top second@ would compute @3 - 5@.
--
-- Fails with 'error' if the tokens are not a well-formed postfix
-- expression (missing or surplus operands, or stray parentheses).
rpn :: [Token] -> Int
rpn tokens = case foldl step [] tokens of
    [result] -> result
    []       -> error "rpn: empty expression"
    _        -> error "rpn: too many operands"
  where step (right:left:rest) (Operator _ f _ _) = f left right : rest
        step _                 (Operator s _ _ _) = error ("rpn: not enough operands for " ++ s)
        step stack             (Operand x)        = x : stack
        step _                 _                  = error "rpn: unexpected parenthesis"

Solution

  • This ended up being far simpler than I expected. Both answers I received helped, but neither pointed me directly to the solution; a GADT-based design is overkill for what I was trying to do.

    All you really need to do in this kind of situation is wrap the operand in a sum type and write small helpers that lift ordinary functions to operate on that type. By parameterizing the Token type by its operand type (Result below), I was able to generalize the algorithm quite cleanly.

    import ShuntingYard
    
    -- | A value produced by the evaluator: either an integer or a boolean.
    data Result
      = I Int   -- ^ an integer value
      | B Bool  -- ^ a boolean value
      deriving (Eq)
    
    -- | Render the wrapped value without its constructor tag, so @I 3@
    -- shows as @3@ and @B True@ as @True@.
    instance Show Result where
      show result = case result of
        I n -> show n
        B b -> show b
    
    -- | Parse and evaluate a whitespace-separated infix expression that may
    -- mix integer arithmetic, comparisons and boolean connectives.
    evalMath :: String -> Result
    evalMath input = rpn (shuntingYard (tokenize input))
    
    -- | Lift a binary function on plain values to 'Result'. The three
    -- letters of each name give the two argument types and the result type:
    -- @I@ for 'Int', @B@ for 'Bool' (so 'liftIIB' takes two integers and
    -- yields a boolean).
    --
    -- Applying a lifted operator to the wrong kind of operand (e.g. @t + 1@)
    -- is a type error in the evaluated expression. It is reported with a
    -- message naming both operands rather than an anonymous pattern-match
    -- failure.
    liftIII :: (Int -> Int -> Int) -> Result -> Result -> Result
    liftIII f (I x) (I y) = I (f x y)
    liftIII _ x     y     = operandError "two integers" x y
    
    liftIIB :: (Int -> Int -> Bool) -> Result -> Result -> Result
    liftIIB f (I x) (I y) = B (f x y)
    liftIIB _ x     y     = operandError "two integers" x y
    
    liftBBB :: (Bool -> Bool -> Bool) -> Result -> Result -> Result
    liftBBB f (B x) (B y) = B (f x y)
    liftBBB _ x     y     = operandError "two booleans" x y
    
    -- | Abort evaluation because an operator received operands of the wrong
    -- kind; @expected@ describes what the operator needed.
    operandError :: String -> Result -> Result -> a
    operandError expected x y =
      error ("type error: expected " ++ expected ++ ", got " ++ show x ++ " and " ++ show y)
    
    -- | Split an expression on whitespace and classify each word.
    --
    -- Each operator's function is stored in its natural argument order,
    -- because 'rpn' applies it as @f left right@; so @"<"@ maps to @(<)@,
    -- @"-"@ to @(-)@, and so on.
    -- Precedence (higher binds tighter): @&&@ and @||@ = 0, comparisons = 1,
    -- @+@ and @-@ = 2, @*@ and @/@ = 3, @^@ = 4 (the only right-associative
    -- operator). The words @t@ and @f@ are boolean literals. Every other
    -- word is parsed with 'read' as an 'Int' and fails at runtime if it is
    -- not one.
    tokenize :: String -> [Token Result]
    tokenize = map token . words
      where token s@"&&" = Operator s (liftBBB (&&)) AssocL 0
            token s@"||" = Operator s (liftBBB (||)) AssocL 0
            token s@"="  = Operator s (liftIIB (==)) AssocL 1
            token s@"!=" = Operator s (liftIIB (/=)) AssocL 1
            token s@">"  = Operator s (liftIIB (>))  AssocL 1
            token s@"<"  = Operator s (liftIIB (<))  AssocL 1
            token s@"<=" = Operator s (liftIIB (<=)) AssocL 1
            token s@">=" = Operator s (liftIIB (>=)) AssocL 1
            token s@"+"  = Operator s (liftIII (+))  AssocL 2
            token s@"-"  = Operator s (liftIII (-))  AssocL 2
            token s@"*"  = Operator s (liftIII (*))  AssocL 3
            token s@"/"  = Operator s (liftIII div)  AssocL 3
            token s@"^"  = Operator s (liftIII (^))  AssocR 4
            token "("    = ParenL
            token ")"    = ParenR
            token "f"    = Operand $ B False
            token "t"    = Operand $ B True
            token x      = Operand $ I $ read x
    

    Where the ShuntingYard module is defined as:

    module ShuntingYard ( Associativity(AssocL, AssocR)
                        , Token(Operand, Operator, ParenL, ParenR)
                        , shuntingYard
                        , rpn) where
    
    import Data.List (foldl')
    
    -- | Binding strength of an operator; a higher number binds tighter.
    type Precedence = Int
    
    -- | How a chain of operators of equal precedence groups.
    data Associativity = AssocL | AssocR
    
    -- | A token of an infix expression over operand type @a@.
    --
    -- @Operator symbol f assoc prec@ carries the binary function @f@, which
    -- 'rpn' applies as @f left right@.
    data Token a
      = Operand a
      | Operator String (a -> a -> a) Associativity Precedence
      | ParenL
      | ParenR
    
    -- | Shows operators by their symbol and operands via their own 'Show'.
    instance (Show a) => Show (Token a) where
      show (Operator s _ _ _) = s
      show (Operand x)        = show x
      show ParenL             = "("
      show ParenR             = ")"
    
    -- | Operators compare equal when their symbols match (their functions
    -- cannot be compared); operands compare by value.
    instance (Eq a) => Eq (Token a) where
      Operator s1 _ _ _ == Operator s2 _ _ _  = s1 == s2
      Operand  x1       == Operand  x2        = x1 == x2
      ParenL            == ParenL             = True
      ParenR            == ParenR             = True
      _                 == _                  = False
    
    -- | Reorder infix tokens into postfix (reverse Polish) order using the
    -- shunting-yard algorithm.
    --
    -- The fold state is @(tokens, ops)@: @tokens@ is the output kept in
    -- reverse (newest first) and @ops@ is the operator stack (top first).
    -- Fails with 'error' on unbalanced parentheses. No 'Eq' constraint is
    -- needed: opening parentheses are recognised by pattern matching.
    shuntingYard :: [Token a] -> [Token a]
    shuntingYard = finish . foldl' shunt ([], [])
      where finish (tokens, ops)
              | any isParenL ops = error "shuntingYard: unmatched '('"
              | otherwise        = reverse tokens ++ ops
            shunt (tokens, ops) token@(Operand _)        = (token:tokens, ops)
            shunt (tokens, ops) token@(Operator _ _ _ _) = (reverse higher ++ tokens, token:lower)
              where (higher, lower) = span (shouldPop token) ops
            shunt (tokens, ops) ParenL = (tokens, ParenL:ops)
            shunt (tokens, ops) ParenR = case break isParenL ops of
                (afterParen, _:beforeParen) -> (reverse afterParen ++ tokens, beforeParen)
                (_, [])                     -> error "shuntingYard: unmatched ')'"
            -- @shouldPop new top@: must the stacked @top@ be output before
            -- @new@ is pushed? Left-associative operators pop anything of
            -- equal or higher precedence; right-associative ones only
            -- strictly higher. A '(' on the stack stops the popping.
            shouldPop (Operator _ _ AssocL prec1) (Operator _ _ _ prec2) = prec1 <= prec2
            shouldPop (Operator _ _ AssocR prec1) (Operator _ _ _ prec2) = prec1 < prec2
            shouldPop _                           _                      = False
            isParenL ParenL = True
            isParenL _      = False
    
    -- | Evaluate a token list given in postfix (reverse Polish) order.
    --
    -- Values are kept on a stack whose head is the top. An operator pops two
    -- values: the top is its /right/ operand and the one beneath it is its
    -- /left/ operand, so the postfix sequence @5 3 -@ means @5 - 3@.
    --
    -- Fails with 'error' if the tokens are not a well-formed postfix
    -- expression (missing or surplus operands, or stray parentheses).
    rpn :: [Token a] -> a
    rpn tokens = case foldl' step [] tokens of
        [result] -> result
        []       -> error "rpn: empty expression"
        _        -> error "rpn: too many operands"
      where step (right:left:rest) (Operator _ f _ _) = f left right : rest
            step _                 (Operator s _ _ _) = error ("rpn: not enough operands for " ++ s)
            step stack             (Operand x)        = x : stack
            step _                 _                  = error "rpn: unexpected parenthesis"