Search code examples
c#returnexpressionantlrantlr4

Return Statement is not evaluating and Expression is only evaluating rules under it


Good day. I am working on an interpreter based on ANTLR4 and I am adding a function feature to the language. Functions without a return statement work, but functions with a return statement do not: the return statement is not evaluated, and the interpreter reports "Invalid Expression".

Here is my complete grammar and Visitor class implementation.

// ---- Keyword tokens ----
// All keywords are declared before ID, so on an equal-length match the keyword rule wins over ID.
// Block terminators and 'else if' are single tokens whose literal contains exactly one space
// ('end repeat', 'end if', 'end for', 'end function', 'end iterate', 'end choose', 'else if').
DECLARE : 'declare';
SET : 'set';
TO : 'to';
SHOW : 'show';
SHOWLINE : 'showline';
REPEAT            :   'repeat';
WHILE             :   'while';
TIMES             :   'times';
END_REPEAT        :   'end repeat';
IF                :   'if';
THEN              :   'then';
ELSE              :   'else';
ELSE_IF           :   'else if';
END_IF            :   'end if';
INPUT  : 'input';
FOR               :   'for';
END_FOR           :   'end for';
FROM              :   'from';
STEP              :   'step';
BY                :   'by';
GENERATE          :   'generate';
STOP              :   'stop';
FUNCTION          : 'function';
RETURN            : 'return';
END_FUNCTION      : 'end function';
ITERATE           : 'iterate';
// NOTE(review): IN is not referenced by any parser rule in this grammar (iterateStatement uses OVER).
IN                : 'in';
OVER              : 'over';
END_ITERATE       : 'end iterate';
CHOOSE            : 'choose';
WHEN              : 'when';
// The literal includes the trailing colon, which is why defaultStatement has no separate COLON
// (whenStatement, by contrast, uses WHEN ... COLON).
OTHERWISE         : 'otherwise:';
END_CHOOSE        : 'end choose';
TRUE              : 'true';
FALSE             : 'false';
// ---- Punctuation and operator tokens ----
LPAREN : '(';
RPAREN : ')';
LBRACKET : '[';
RBRACKET : ']';
// NOTE(review): LBRACE / RBRACE are not referenced by any parser rule in this grammar.
LBRACE            :   '{';
RBRACE            :   '}';
COMMA  : ',';
// Relational operators; '<=' and '>=' win over '<' and '>' because the lexer prefers the longest match.
LESS_THAN         :   '<';
LESS_THAN_EQUAL   :   '<=';
GREATER_THAN      :   '>';
GREATER_THAN_EQUAL:   '>=';
EQUAL             :   '==';
NOT_EQUAL         :   '!=';
// Arithmetic operators. The modulo operator is spelled as the word 'remind'.
MUL : '*';
DIV : '/';
MODULO : 'remind';
POW    : '^';
PLUS : '+';
MINUS : '-';
SEMICOLON : ';';
COLON     : ':';
// Logical operators are words; they are declared before ID so they are not lexed as identifiers.
AND : 'and';
OR : 'or';
NOT : 'not';
// ---- Line breaks, whitespace, literals and identifiers ----
// A line break is emitted as NEWLINE only when WS (below) cannot produce a longer match,
// i.e. when the break is not adjacent to other whitespace.
NEWLINE : '\r'? '\n';
// Tab runs are skipped. INDENT is declared before WS, so it wins the tie for tab-only input;
// since no parser rule consumes INDENT, leaving it unskipped made tab-indented code a syntax error.
INDENT            :   [\t]+ -> skip;
// Double- or single-quoted string with backslash escapes.
STRING : '"' ( '\\' . | ~[\\"] )* '"' | '\'' ( '\\' . | ~[\\'] )* '\'' ;
// NOTE(review): the optional sign is part of the token, so `a-1` presumably lexes as
// ID NUMBER(-1) rather than ID MINUS NUMBER — confirm before relying on unspaced arithmetic.
NUMBER : ('+' | '-')? [0-9]+ ('.' [0-9]+)?;
ID : [a-zA-Z] [a-zA-Z0-9_]*;
// Line comments run to the end of the line and are discarded.
COMMENT : '//' ~[\r\n]* -> skip;
WS : [ \t\r\n]+ -> skip;
// NOTE(review): appears unreachable — WS above already matches a single space and is declared first.
SPACE : ' ' -> skip;
// Entry rule. The trailing EOF forces the parser to consume the whole input; without it the
// parser may stop at the first token it cannot match and silently ignore the rest of the file.
program : statement_list* EOF;

// One or more statements followed by any number of NEWLINE / ';' terminators.
statement_list : statement+ (NEWLINE|SEMICOLON)*;

// Every construct that may appear in a statement_list.
// returnStatement is intentionally absent here: it is only reachable as the optional last
// child of functionDecl, so it never shows up among a function's statement_list children.
// functionCall appears both here (call used as a statement) and inside expression.
statement : declareStatement
          | setStatement
          | assignStatement
          | ifStatement
          | inputStatement
          | forStatement
          | generateStatement
          | showStatement
          | repeatStatement
          | repeatTimeStatement
          | showlnStatement
          | iterateStatement
          | chooseStatement
          | functionDecl
          | functionCall
          ;
// ---- Individual statement rules ----
declareStatement : DECLARE (ID | arrayAccess) (COMMA (ID | arrayAccess))*;
// NOTE(review): `ID TO expression` is matched by both setStatement and assignStatement, and
// `ID TO input(...)` can also match inputStatement; which rule wins presumably depends on the
// alternative order in `statement` — confirm this is intended.
setStatement     : (SET ID | ID) (TO expression | TO arrayAccess | TO arrayElement);
assignStatement  : (ID | arrayAccess) TO expression;
showStatement    : SHOW expression;
ifStatement      : IF expression THEN statement_list* (elseifStatement)* (elseStatement)? END_IF;
elseifStatement  : ELSE_IF expression THEN statement_list*;
elseStatement    : ELSE statement_list*;
inputStatement   : (SET ID | ID) TO INPUT LPAREN expression RPAREN;
// Two loop forms share the REPEAT keyword: `repeat while (cond)` and `repeat <count> times`.
repeatStatement  : REPEAT WHILE LPAREN expression RPAREN statement_list* END_REPEAT;
repeatTimeStatement : REPEAT expression TIMES statement_list* END_REPEAT;
showlnStatement  : SHOWLINE LPAREN RPAREN;
forStatement : FOR ID FROM expression TO expression (STEP expression)? (NEWLINE)? (statement_list)* END_FOR;
generateStatement : GENERATE ID FROM expression TO expression (BY expression)?  (statement_list)* STOP;
iterateStatement  : ITERATE LPAREN ID OVER iteratable RPAREN (NEWLINE)? statement_list* END_ITERATE;
iteratable       : expression;
chooseStatement  : CHOOSE expression (NEWLINE)? whenStatement+ defaultStatement? END_CHOOSE;
whenStatement    : WHEN expression COLON statement+ (NEWLINE)?;
// No COLON here: the OTHERWISE token literal already ends with ':'.
defaultStatement : OTHERWISE statement+;
// The body is a sequence of statement_list children followed by at most one returnStatement.
// The returnStatement is a separate child of this rule, so a visitor that only iterates
// statement_list() does not reach it.
functionDecl     : FUNCTION ID LPAREN param_list RPAREN (NEWLINE)? statement_list* returnStatement? END_FUNCTION;
returnStatement  : RETURN expression? (NEWLINE|SEMICOLON)?;
// param_list and arg_list may both be empty.
param_list       : (ID (COMMA ID)*)?;
functionCall     : ID LPAREN arg_list RPAREN;
arg_list         : (expression (COMMA expression)*)?;
arrayAccess      : ID LBRACKET index  RBRACKET;
// Array literal; requires at least one element.
arrayElement     : LBRACKET (expression (COMMA expression)*) RBRACKET;
index            : expression;

// Expressions. For the left-recursive (operator) alternatives, earlier alternatives bind tighter.
// POW has its own level above * / remind and is right-associative, so `2 * 3 ^ 2` is
// 2 * (3 ^ 2) and `2 ^ 3 ^ 2` is 2 ^ (3 ^ 2). Previously POW shared one left-associative
// level with MUL / DIV / MODULO.
// NOTE(review): NOT is listed after OR, so it binds loosest of all operators — confirm intended.
expression       :
                  LPAREN expression RPAREN
                 | <assoc=right> expression POW expression
                 | expression (MUL | DIV | MODULO) expression
                 | expression (PLUS | MINUS) expression
                 | expression (LESS_THAN_EQUAL | GREATER_THAN_EQUAL | LESS_THAN |  GREATER_THAN) expression
                 | expression (EQUAL | NOT_EQUAL) expression
                 | expression AND expression
                 | expression OR expression
                 | NOT expression
                 | functionCall
                 | arrayAccess
                 | TRUE
                 | FALSE
                 | STRING
                 | NUMBER
                 | ID
                 ;

The Visitor class implementation

// True when the most recently executed function body exited through a ReturnException.
private bool _hasReturn = false;

        /// <summary>
        /// Registers a user-defined function: captures its parameter names and body and stores a
        /// delegate in <c>functionTable</c> under the function's name. Nothing is executed here.
        /// </summary>
        /// <param name="context">Parse-tree node of the <c>functionDecl</c> rule.</param>
        /// <returns>Always <c>null</c>; a declaration produces no value.</returns>
        public override object VisitFunctionDecl(EasyBiteParser.FunctionDeclContext context)
        {
            var functionName = context.ID().GetText();
            var functionParams = context.param_list() != null ? context.param_list().ID().Select(id => id.GetText()).ToList() : null;

            // Only the statement_list children are captured. functionDecl's optional trailing
            // returnStatement is a separate parse-tree child and is not part of this collection
            // (see the answer below for visiting it).
            var functionBody = context.statement_list();

            Func<List<object>, object> functionImpl = (args) =>
            {
                // Fresh local scope holding the parameter-name -> argument-value bindings.
                var localVariables = new Dictionary<string, object>();
                if (functionParams != null && args != null)
                {
                    for (int i = 0; i < functionParams.Count; i++)
                    {
                        localVariables[functionParams[i]] = args[i];
                    }
                }

                localVariablesStack.Push(localVariables);
                _hasReturn = false;
                try
                {
                    foreach (var statement in functionBody)
                    {
                        Visit(statement);
                    }

                    // The body ran to completion without a return statement.
                    return null;
                }
                catch (ReturnException e)
                {
                    // A return statement unwound the body; hand its value to the caller.
                    _hasReturn = true;
                    return e.Value;
                }
                finally
                {
                    // Always drop the local scope. Previously the early return in the catch
                    // block skipped the Pop, leaking one scope per returning call.
                    localVariablesStack.Pop();
                }
            };

            functionTable[functionName] = functionImpl;

            return null;
        }

        /// <summary>
        /// Calls a previously declared function: looks up its delegate in <c>functionTable</c>,
        /// evaluates the argument expressions and invokes the delegate with the resulting values.
        /// </summary>
        /// <param name="context">Parse-tree node of the <c>functionCall</c> rule.</param>
        /// <returns>The value produced by the function, or <c>null</c> if it produced none.</returns>
        public override object VisitFunctionCall(EasyBiteParser.FunctionCallContext context)
        {
            var functionName = context.ID().GetText();
            var functionImpl = (Func<List<object>, object>)functionTable[functionName];

            // Arguments are evaluated before the call, i.e. in the caller's scope.
            var functionParams = context.arg_list() != null ? context.arg_list().expression().Select(expr => Visit(expr)).ToList() : null;

            // No scope is pushed here: the delegate registered by VisitFunctionDecl creates and
            // pushes its own scope with the parameter bindings. The scope formerly pushed here
            // keyed each entry by the *stringified argument value*, which bound no parameter.
            int scopeDepth = localVariablesStack.Count;
            try
            {
                return functionImpl(functionParams);
            }
            catch (ReturnException ex)
            {
                // Defensive: the value of a return that escaped the function implementation.
                return ex.Value;
            }
            finally
            {
                // Restore the scope stack to its pre-call depth however the call ended, so a
                // failing or early-returning callee cannot leave its scope behind.
                while (localVariablesStack.Count > scopeDepth)
                {
                    localVariablesStack.Pop();
                }
            }
        }
        /// <summary>
        /// Evaluates the optional expression of a <c>return</c> statement and unwinds to the
        /// enclosing handler by throwing a <c>ReturnException</c> carrying the value.
        /// </summary>
        /// <param name="context">Parse-tree node of the <c>returnStatement</c> rule.</param>
        /// <returns>Never returns normally; a <c>ReturnException</c> is always thrown.</returns>
        public override object VisitReturnStatement(EasyBiteParser.ReturnStatementContext context)
        {
            object returnValue = null;

            // A bare `return` has no expression and carries a null value.
            if (context.expression() != null)
            {
                returnValue = Visit(context.expression());
            }

            throw new ReturnException(returnValue);
        }

Thank you. This is only a snippet of the function implementation. I wanted to post the expression code as well, but it is too long, as Stack Overflow only allows 30,000 characters.


Solution

  • Looking at your functionDecl rule:

    functionDecl     : FUNCTION ID LPAREN param_list RPAREN (NEWLINE)? statement_list* returnStatement? END_FUNCTION;
    

    I don't see the returnStatement being used anywhere in your visitor. Try something like this (untested!):

    // Existing code
    _hasReturn = false;
    foreach (var statement in functionBody)
    {
        Visit(statement);
    }
    
    // New code
    // The trailing return statement is its own child of functionDecl, not one of the
    // statement_list children iterated above, so it has to be fetched and visited explicitly.
    var returnStatement = context.returnStatement();
    
    // The rule is optional (`returnStatement?`), hence the null guard.
    if (returnStatement != null)
    {
        Visit(returnStatement);
    }
    

    But I must admit, the logic you got in place is rather hacky: relying on exceptions to get the return value of a method... Much better is to let the Visit(...) return the expression instead:

    // Run the body statements first.
    foreach (var statement in functionBody)
    {
        Visit(statement);
    }
    
    var returnStatement = context.returnStatement();
    
    // Pass the return statement's value straight back to the caller. This relies on
    // VisitReturnStatement returning the value of its expression instead of throwing.
    if (returnStatement != null)
    {
        return Visit(returnStatement);
    }
    
    // TODO return some sort of "void" value
    

    EDIT

    How about something like this:

    using System;
    using System.Collections.Generic;
    using System.Linq;
    using Antlr4.Runtime;
    
    /// <summary>
    /// Console entry point: parses a small sample script that declares and calls a function,
    /// interprets it and prints the value of the last statement.
    /// </summary>
    public class Program
    {
        static void Main(string[] args)
        {
            // Sample script: declare add(a, b), then call it.
            string source = @"function add(a, b)
               return a + b
             end function
    
             add(40, 2);";

            // Source text -> tokens -> parse tree.
            var tokens = new CommonTokenStream(new EasyBiteLexer(new AntlrInputStream(source)));
            var parser = new EasyBiteParser(tokens);

            var interpreter = new Interpreter();
            object result = interpreter.Visit(parser.program());

            Console.WriteLine($"{result}");
        }
    }
    
    /// <summary>
    /// Minimal tree-walking interpreter in which every Visit method returns the value of the
    /// node it visited. Supports function declarations and calls, <c>return</c>, integer
    /// <c>+</c> / <c>-</c>, parenthesised expressions, integer literals and variables.
    /// </summary>
    class Interpreter : EasyBiteBaseVisitor<object>
    {
        // Function table. It is shared by reference with the interpreters that Function.Invoke
        // creates for function bodies, so a body can call other functions (and itself).
        private readonly Dictionary<string, Function> functionDefinitions;

        // Variables visible to this interpreter instance (one instance per function invocation).
        public readonly Dictionary<string, int> variables = new();

        /// <summary>Creates a top-level interpreter with an empty function table.</summary>
        public Interpreter() : this(new Dictionary<string, Function>())
        {
        }

        /// <summary>Creates an interpreter that shares an existing function table.</summary>
        /// <param name="functionDefinitions">Function table to share; must not be null.</param>
        internal Interpreter(Dictionary<string, Function> functionDefinitions)
        {
            this.functionDefinitions = functionDefinitions ?? throw new ArgumentNullException(nameof(functionDefinitions));
        }

        /// <summary>Runs every statement list of the program.</summary>
        /// <returns>The value of the last statement list, or null for an empty program.</returns>
        public override object VisitProgram(EasyBiteParser.ProgramContext context)
        {
            object returnValue = null;

            foreach (var statement in context.statement_list() ?? Array.Empty<EasyBiteParser.Statement_listContext>())
            {
                returnValue = this.Visit(statement);
            }

            return returnValue;
        }

        /// <summary>Runs every statement of a statement list.</summary>
        /// <returns>The value of the last statement, or null if there was none.</returns>
        public override object VisitStatement_list(EasyBiteParser.Statement_listContext context)
        {
            object returnValue = null;

            foreach (var statement in context.statement() ?? Array.Empty<EasyBiteParser.StatementContext>())
            {
                returnValue = this.Visit(statement);
            }

            return returnValue;
        }

        /// <summary>
        /// Stores the declared function (parameter names, body and optional return statement)
        /// in the function table without executing it.
        /// </summary>
        /// <returns>Always null.</returns>
        public override object VisitFunctionDecl(EasyBiteParser.FunctionDeclContext context)
        {
            var parameters = context.param_list().ID().Select(i => i.GetText());
            this.functionDefinitions[context.ID().GetText()] = new Function(
                parameters,
                context.statement_list(),
                context.returnStatement(),
                this.functionDefinitions);

            return null;
        }

        /// <summary>Evaluates the arguments and invokes the named function.</summary>
        /// <returns>The value produced by the function.</returns>
        /// <exception cref="KeyNotFoundException">No function with that name has been declared.</exception>
        public override object VisitFunctionCall(EasyBiteParser.FunctionCallContext context)
        {
            var name = context.ID().GetText();
            if (!this.functionDefinitions.TryGetValue(name, out var function))
            {
                throw new KeyNotFoundException($"Undefined function: {name}");
            }

            // This demo interpreter only handles integer values.
            var values = context.arg_list().expression().Select(e => (int)this.Visit(e)).ToArray();

            return function.Invoke(values);
        }

        /// <summary>Returns the value of the return expression, or null for a bare <c>return</c>.</summary>
        public override object VisitReturnStatement(EasyBiteParser.ReturnStatementContext context)
        {
            return context.expression() == null ? null : base.Visit(context.expression());
        }

        /// <summary>
        /// Evaluates the supported subset of expressions: <c>a + b</c>, <c>a - b</c>,
        /// <c>( expr )</c>, a function call, an integer literal, or a variable.
        /// </summary>
        /// <exception cref="NotImplementedException">The expression form is not supported.</exception>
        /// <exception cref="KeyNotFoundException">The expression names an undefined variable.</exception>
        public override object VisitExpression(EasyBiteParser.ExpressionContext context)
        {
            if (context.PLUS() != null)
            {
                return (int)base.Visit(context.expression()[0]) + (int)base.Visit(context.expression()[1]);
            }

            if (context.MINUS() != null)
            {
                return (int)base.Visit(context.expression()[0]) - (int)base.Visit(context.expression()[1]);
            }

            // `( expr )` evaluates to its inner expression.
            if (context.LPAREN() != null)
            {
                return base.Visit(context.expression()[0]);
            }

            // A call used as an expression, e.g. `return add(1, 2)`.
            if (context.functionCall() != null)
            {
                return base.Visit(context.functionCall());
            }

            if (context.NUMBER() != null)
            {
                // Parsed with the invariant culture so the result does not depend on the host's
                // locale. Fractional NUMBER tokens are not supported by this integer-only demo.
                return int.Parse(context.NUMBER().GetText(), System.Globalization.CultureInfo.InvariantCulture);
            }

            if (context.ID() != null)
            {
                var variableName = context.ID().GetText();
                if (!this.variables.TryGetValue(variableName, out var value))
                {
                    throw new KeyNotFoundException($"Undefined variable: {variableName}");
                }

                return value;
            }

            throw new NotImplementedException($"Not implemented: {context.GetText()}");
        }
    }

    /// <summary>
    /// A declared function: its parameter names, its body and its optional trailing return
    /// statement, plus the function table its body is allowed to call into.
    /// </summary>
    class Function
    {
        private readonly List<string> parameters;
        private readonly EasyBiteParser.Statement_listContext[] statementList;
        private readonly EasyBiteParser.ReturnStatementContext returnStatement;
        private readonly Dictionary<string, Function> functionDefinitions;

        /// <param name="parameters">Parameter names, in declaration order.</param>
        /// <param name="statementList">Body statement lists; may be null or empty.</param>
        /// <param name="returnStatement">Trailing return statement, or null if the function has none.</param>
        /// <param name="functionDefinitions">
        /// Function table visible to the body. Optional; when omitted the body sees an empty table.
        /// </param>
        public Function(
            IEnumerable<string> parameters,
            EasyBiteParser.Statement_listContext[] statementList,
            EasyBiteParser.ReturnStatementContext returnStatement,
            Dictionary<string, Function> functionDefinitions = null)
        {
            this.parameters = parameters.ToList();
            this.statementList = statementList;
            this.returnStatement = returnStatement;
            this.functionDefinitions = functionDefinitions ?? new Dictionary<string, Function>();
        }

        /// <summary>
        /// Runs the function with the given argument values in a fresh variable scope.
        /// </summary>
        /// <param name="values">Argument values, one per declared parameter.</param>
        /// <returns>
        /// The value of the return statement if there is one; otherwise the value of the last
        /// body statement (null for an empty body).
        /// </returns>
        /// <exception cref="ArgumentException">The number of values differs from the number of parameters.</exception>
        public object Invoke(int[] values)
        {
            if (values == null)
            {
                throw new ArgumentNullException(nameof(values));
            }

            if (values.Length != this.parameters.Count)
            {
                throw new ArgumentException(
                    $"Expected {this.parameters.Count} argument(s) but got {values.Length}.", nameof(values));
            }

            // New interpreter = new variable scope; the function table is shared so that the
            // body can call other functions and recurse.
            var interpreter = new Interpreter(this.functionDefinitions);

            for (var i = 0; i < this.parameters.Count; i++)
            {
                interpreter.variables[this.parameters[i]] = values[i];
            }

            object returnValue = null;

            foreach (var statement in this.statementList ?? Array.Empty<EasyBiteParser.Statement_listContext>())
            {
                returnValue = interpreter.Visit(statement);
            }

            return this.returnStatement == null ? returnValue : interpreter.Visit(this.returnStatement);
        }
    }
    

    IMO much better with the return values of the Visit... methods. And a separated Function class.