
Is there a more maintainable way to process my datatype?


I have productions for a recursive-descent parser defined with the following datatype:

data CST 
    = Program CST CST
    | Block CST CST CST 
    | StatementList CST CST
    | EmptyStatementList
    | Statement CST
    | PrintStatement CST CST CST CST
    | AssignmentStatement CST CST CST
    | VarDecl CST CST
    | WhileStatement CST CST CST 
    | IfStatement CST CST CST 
    | Expr CST
    | IntExpr1 CST CST CST 
    | IntExpr2 CST
    | StringExpr CST CST CST
    | BooleanExpr1 CST CST CST CST CST
    | BooleanExpr2 CST 
    | Id CST
    | CharList CST CST 
    | EmptyCharList
    | Type CST 
    | Character CST
    | Space CST
    | Digit CST
    | BoolOp CST
    | BoolVal CST
    | IntOp CST
    | TermComponent Token
    | ErrorTermComponent (Token, Int)
    | NoInput

As the name implies, the datatype represents a concrete syntax tree. I'm wondering if there is a more maintainable way to pattern match over this type. For example, to trace the execution of parse calls I have the following:

checkAndPrintParse :: CST -> IO ()
checkAndPrintParse (Program c1 c2) = do
    putStrLn "Parser: parseProgram" 
    checkAndPrintParse c1
    checkAndPrintParse c2
checkAndPrintParse (Block c1 c2 c3) = do
    putStrLn "Parser: parseBlock"
    checkAndPrintParse c1
    checkAndPrintParse c2
    checkAndPrintParse c3
checkAndPrintParse (StatementList c1 c2) = do
    putStrLn "Parser: parseStatementList"
    checkAndPrintParse c1
    checkAndPrintParse c2

and so on. I've looked into the fix function/pattern, but I'm not sure if it is applicable here.


Solution

  • Use generic-deriving to get the name of a constructor:

    • Derive Generic (from GHC.Generics)
    • Call conNameOf, here used at type CSTF a -> String (from Generics.Deriving)

  • Use recursion-schemes to traverse a recursive type:

    • Derive the base functor of a recursive type with makeBaseFunctor. The base functor of CST, called CSTF, is a parameterized type that has the same shape as CST, but where recursive occurrences of CST are replaced with the type parameter.
    • Learn to use cata (it can be a bit mind-bending at first). In this case we want to recursively construct an IO () action from a CST, i.e., a function CST -> IO (). For that, the type of cata specializes to (CSTF (IO ()) -> IO ()) -> CST -> IO () (with t ~ CST and a ~ IO ()). The first argument defines the body of the resulting recursive function, and the results of the recursive calls are placed in the fields of the base functor.
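
To make the two bullets above concrete, here is a minimal sketch that writes a base functor and a catamorphism by hand, using only base. Tree, TreeF, project, cataTree and size are hypothetical names for illustration; with recursion-schemes, makeBaseFunctor and cata would play the roles of the hand-written TreeF and cataTree.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- A tiny stand-in for CST (hypothetical, not the question's full type).
data Tree
  = Leaf
  | Node Tree Tree

-- Base functor written by hand: same shape as Tree, but every recursive
-- Tree field becomes the type parameter r.
data TreeF r
  = LeafF
  | NodeF r r
  deriving Functor

-- Expose the outermost constructor, leaving the children untouched.
project :: Tree -> TreeF Tree
project Leaf       = LeafF
project (Node l r) = NodeF l r

-- A catamorphism specialised to Tree: recurse into the children first,
-- then hand the layer of results to the algebra.
cataTree :: (TreeF a -> a) -> Tree -> a
cataTree alg = alg . fmap (cataTree alg) . project

-- Example algebra: count the constructors in a tree. The fields l and r
-- already hold the counts of the subtrees.
size :: Tree -> Int
size = cataTree alg
  where
    alg LeafF       = 1
    alg (NodeF l r) = 1 + l + r
```

In size, the algebra never calls itself; the recursion lives entirely in cataTree, which is the property the answer relies on.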

    So, if your goal is to write a recursive function checkAndPrintParse with one case like:

    checkAndPrintParse (Program c1 c2) = do
      putStrLn "Parser: parseProgram" 
      checkAndPrintParse c1
      checkAndPrintParse c2
    

    cata will put the results of its recursive calls on c1 and c2 in place of those fields:

    -- goal: find f such that   cata f = checkAndPrintParse
    
    -- By definition of cata
    cata f (Program c1 c2) = f (ProgramF (cata f c1) (cata f c2))
    
    -- By the goal and the definition of checkAndPrintParse
    cata f (Program c1 c2) = checkAndPrintParse (Program c1 c2) = do
      putStrLn "Parser: parseProgram" 
      checkAndPrintParse c1
      checkAndPrintParse c2
    

    Therefore

    f (ProgramF (cata f c1) (cata f c2)) = do
      putStrLn "Parser: parseProgram"
      cata f c1
      cata f c2
    

    Abstract over cata f c1 and cata f c2, naming them x1 and x2:

    f (ProgramF x1 x2) = do
      putStrLn "Parser: parseProgram"
      x1 >> x2
    

    Recognize a fold (in the Foldable sense): sequence_ runs every action stored in the fields of t, in order.

    f t@(ProgramF _ _) = do
      putStrLn "Parser: parseProgram"
      sequence_ t
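
The step from x1 >> x2 to sequence_ t can be checked on a small example. The sketch below is an illustration under assumptions: PairF is a hypothetical two-field functor standing in for ProgramF, and the tuple monad ((,) [String]) replaces IO so that the "printed" lines can be compared as plain values.

```haskell
{-# LANGUAGE DeriveFoldable #-}

-- Hypothetical two-field node standing in for ProgramF.
data PairF r = PairF r r
  deriving Foldable

-- A pure "logging" action: in the tuple monad, (>>) appends the logs.
say :: String -> ([String], ())
say s = ([s], ())

-- Running the fields explicitly, as in  x1 >> x2.
explicit :: PairF ([String], ()) -> ([String], ())
explicit (PairF x1 x2) = x1 >> x2

-- Running them through Foldable, without naming the constructor.
viaFold :: PairF ([String], ()) -> ([String], ())
viaFold = sequence_
```

Both functions produce the same log, in field order, which is why the pattern match on the constructor can be dropped in the next step.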
    

    Generalize again

    f t = do
      putStrLn $ "Parser: " ++ conNameOf t  -- Prints "ProgramF" instead of "parseProgram"... *shrugs*
      sequence_ t
    

    That's the argument we give to cata.
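
For readers curious how a constructor name can be recovered generically, here is a rough sketch built on GHC.Generics from base alone. It is not the generic-deriving source: the class ConName, the function constructorName and the type Shape are hypothetical names, and constructorName only plays the role that conNameOf plays in the answer.

```haskell
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeOperators #-}

import GHC.Generics

-- Walk the generic representation until a constructor node (C1) is found.
class ConName f where
  gconName :: f p -> String

-- Datatype metadata wrapper: look inside.
instance ConName f => ConName (D1 c f) where
  gconName (M1 x) = gconName x

-- Sum of constructors: follow whichever branch is present.
instance (ConName f, ConName g) => ConName (f :+: g) where
  gconName (L1 x) = gconName x
  gconName (R1 x) = gconName x

-- Constructor metadata: conName (from GHC.Generics) returns its name.
instance Constructor c => ConName (C1 c f) where
  gconName = conName

-- Hypothetical helper: convert to the representation, then read the name.
constructorName :: (Generic a, ConName (Rep a)) => a -> String
constructorName = gconName . from

-- Small parameterized type used only to try the helper out.
data Shape r = Circle | Pair r r
  deriving Generic
```

Because the lookup only inspects constructor metadata, it works for any type parameter, which matches how the answer applies it to CSTF (IO ()).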


    {-# LANGUAGE DeriveGeneric #-}
    {-# LANGUAGE TypeFamilies #-}
    {-# LANGUAGE DeriveFunctor #-}
    {-# LANGUAGE DeriveFoldable #-}
    {-# LANGUAGE DeriveTraversable #-}
    {-# LANGUAGE StandaloneDeriving #-}
    {-# LANGUAGE TemplateHaskell #-}
    
    import GHC.Generics
    import Generics.Deriving (conNameOf)
    import Data.Functor.Foldable
    import Data.Functor.Foldable.TH (makeBaseFunctor)
    
    data CST 
        = Program CST CST
        | Block CST CST CST 
        | StatementList CST CST
        | EmptyStatementList
        | Statement CST
        | PrintStatement CST CST CST CST
        | AssignmentStatement CST CST CST
        | VarDecl CST CST
        | WhileStatement CST CST CST 
        | IfStatement CST CST CST 
        | Expr CST
        | IntExpr1 CST CST CST 
        | IntExpr2 CST
        | StringExpr CST CST CST
        | BooleanExpr1 CST CST CST CST CST
        | BooleanExpr2 CST 
        | Id CST
        | CharList CST CST 
        | EmptyCharList
        | Type CST 
        | Character CST
        | Space CST
        | Digit CST
        | BoolOp CST
        | BoolVal CST
        | IntOp CST
        | TermComponent Token
        | ErrorTermComponent (Token, Int)
        | NoInput
        deriving Generic
    
    data Token = Token
    
    makeBaseFunctor ''CST
    
    deriving instance Generic (CSTF a)
    
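    -- cata feeds each node to the lambda with its children already
    -- replaced by the IO actions computed for them.
    checkAndPrintParse :: CST -> IO ()
    checkAndPrintParse = cata $ \t -> do
      putStrLn $ "Parser: " ++ conNameOf t  -- name of the CSTF constructor
      sequence_ t                           -- run the children's actions in order
    
    main :: IO ()
    main = checkAndPrintParse $
      Program (Block NoInput NoInput NoInput) (Id NoInput)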
    

    Output:

    Parser: ProgramF
    Parser: BlockF
    Parser: NoInputF
    Parser: NoInputF
    Parser: NoInputF
    Parser: IdF
    Parser: NoInputF
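
If the original trace format ("Parser: parseProgram") matters, the constructor name can be post-processed before printing. The following is one possible sketch, assuming the only difference is the trailing F added to base-functor constructors and a "parse" prefix; dropSuffixF and parserName are hypothetical helpers, not part of either library.

```haskell
import Data.List (stripPrefix)
import Data.Maybe (fromMaybe)

-- Drop one trailing 'F' if present; leave other names unchanged.
dropSuffixF :: String -> String
dropSuffixF name =
  fromMaybe name (reverse <$> stripPrefix "F" (reverse name))

-- Hypothetical helper: turn a base-functor constructor name into the
-- parser name used in the question's traces.
parserName :: String -> String
parserName ctor = "parse" ++ dropSuffixF ctor
```

With this in place, the algebra could print "Parser: " ++ parserName (conNameOf t) instead of the raw constructor name.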