Search code examples
pythonabstract-syntax-treetruthtable

Python Generate truthtable from expression containing Numeric value and comparators


I am trying to generate test cases from a boolean expression which may contain numerical values and comparators, e.g.:

(10 < a) and (c == 5)

A truth table seems like a good approach, and the generated output might look like this:

>  a,  c, result
> 11,  5,   True,
> 11,  4,  False,
>  9,  5,  False,
>  9,  4,  False

It is pretty straightforward to generate a truth table when the expression only contains boolean operators, using ast.NodeVisitor. I ended up with something like this, which works nicely.

class Expr():
    """Truth table for an expression whose free names are all boolean inputs.

    On construction the expression is parsed, its identifiers are collected
    (sorted alphabetically) into ``self.vars`` and ``self.truthtable`` is
    filled with one row per True/False combination of those identifiers.
    """

    class allVariables(ast.NodeVisitor):
        """Collect the identifiers referenced anywhere in a parsed module."""

        def visit_Name(self, node):
            # The set takes care of names that occur more than once.
            self.names.add(node.id)

        def visit_Module(self, node):
            # Start from an empty set each time a module is visited.
            self.names = set()
            self.generic_visit(node)
            return sorted(self.names)

    def __init__(self, expr):
        self.tree = ast.parse(expr)
        self.expr = expr
        self.vars = self.allVariables().visit(self.tree)
        self.generateTruthTable()

    def generateTruthTable(self):
        """Populate ``self.truthtable``.

        The table maps a row index to a dict with two keys: ``'inputs'``
        (variable name -> boolean) and ``'expectation'`` (the value the
        expression evaluates to for those inputs).
        """
        empty_builtins = {'__builtins__': {}}
        self.truthtable = {}
        combinations = product([True, False], repeat=len(self.vars))
        for row_id, combination in enumerate(combinations):
            row = {}
            self.truthtable[row_id] = row
            row['inputs'] = dict(zip(self.vars, combination))
            row['expectation'] = eval(self.expr, empty_builtins, row['inputs'])

Now I am struggling to perform the same generation regardless of which operators the expression contains. It looks like recursion is the key here. Is there a way to perform such generation using the Python AST? Thank you.


Solution

  • I found a solution to the problem. I just had to constrain the usage so that the numerical value always ends up on the right-hand side of an operation (e.g. `a < 10`), and I also limited the set of supported operators.

    import ast
    import operator
    from itertools import product
    
    class OperatorError(Exception):
        """Raised when an expression uses an operator that is not supported."""
    
    class Expr():
        """Truth table for an expression mixing boolean operators, numeric
        comparisons and a small set of bitwise operators.

        For every variable in the expression a pair of representative values
        is derived: one that makes the sub-condition the variable appears in
        true and one that makes it false.  The expression is then evaluated
        for every combination of those representative values.

        Attributes:
            tree: the parsed ``ast`` module for ``expr``.
            expr: the original expression string.
            vars: ``{name: {True: value, False: value}}``, the representative
                value of each variable for the "condition holds" / "condition
                does not hold" case.
            tt: ``{row_index: {'inputs': {name: value}, 'expectation': result}}``.

        Known limitations: only the first operator of a chained comparison
        (``1 < a < 5``) is considered, and when a variable occurs in several
        sub-conditions the last one visited decides its representative values.
        """

        BOOLOP_SYMBOLS = (
            ast.And,
            ast.Or
        )

        # The trailing comma is required: ``(ast.Not)`` is the bare class,
        # not a one-element tuple.
        UNARYOP_SYMBOLS = (
            ast.Not,
        )

        BINOP_SYMBOLS = (
            ast.BitAnd,
            ast.BitXor,
            ast.Pow
        )

        CMPOP_SYMBOLS = (
            ast.Eq,
            ast.Gt,
            ast.GtE,
            ast.Lt,
            ast.LtE,
            ast.NotEq
        )

        # For ``name OP bound``: (offset added to bound so the comparison is
        # true, offset added to bound so the comparison is false).
        _CMP_OFFSETS = {
            ast.Gt: (1, -1),
            ast.GtE: (0, -1),
            ast.Lt: (-1, 1),
            ast.LtE: (0, 1),
            ast.Eq: (0, 1),
            ast.NotEq: (1, 0),
        }

        # ``bound OP name`` is equivalent to ``name MIRROR[OP] bound``.
        _CMP_MIRROR = {
            ast.Gt: ast.Lt,
            ast.GtE: ast.LtE,
            ast.Lt: ast.Gt,
            ast.LtE: ast.GtE,
            ast.Eq: ast.Eq,
            ast.NotEq: ast.NotEq,
        }

        def __init__(self, expr):
            """Parse ``expr``, derive its variables and build the truth table.

            Raises:
                SyntaxError: if ``expr`` is not valid Python.
                OperatorError: if ``expr`` uses an unsupported operator or
                    compares something other than a variable and a number.
            """
            self.tree = ast.parse(expr)
            self.expr = expr
            self.vars = self.SyntaxChecker().visit(self.tree)
            self.tt = self.generateTruthTable()

        def generateTruthTable(self):
            """Evaluate the expression for every combination of inputs.

            Returns:
                dict mapping a row index to ``{'inputs': ..., 'expectation': ...}``
                where ``inputs`` maps each variable to the representative
                value chosen for that row.
            """
            # NOTE: an empty ``__builtins__`` only hides the built-in names; it
            # is not a security sandbox.  Do not evaluate untrusted expressions.
            NO_GLOBALS = {'__builtins__': {}}
            # Plain iteration works on both Python 2 and 3 (``iterkeys`` is
            # Python 2 only).
            names = list(self.vars)
            truthtable = {}
            combinations = product([True, False], repeat=len(names))
            for i, flags in enumerate(combinations):
                inputs = {
                    name: self.vars[name][flag]
                    for name, flag in zip(names, flags)
                }
                truthtable[i] = {
                    'inputs': inputs,
                    'expectation': eval(self.expr, NO_GLOBALS, inputs),
                }
            return truthtable

        class SyntaxChecker(ast.NodeVisitor):
            """Validate the operators of an expression and collect, for each
            variable, the values that make its sub-condition true or false."""

            @staticmethod
            def _is_number(value):
                """Return True for int/float/complex values, excluding bool."""
                return (isinstance(value, (int, float, complex))
                        and not isinstance(value, bool))

            @staticmethod
            def _unsupported(op):
                """Build the error for an operator node that is not allowed."""
                return OperatorError(type(op).__name__ + ' is not supported.')

            def visit_Module(self, module):
                self.vars = dict()
                self.generic_visit(module)
                return self.vars

            def visit_Expr(self, expr):
                return self.visit(expr.value)

            def visit_BoolOp(self, boolop):
                if not isinstance(boolop.op, Expr.BOOLOP_SYMBOLS):
                    raise self._unsupported(boolop.op)
                # ``a and b and c`` is a single BoolOp with three values, so
                # do not assume exactly two operands.
                return [self.visit(value) for value in boolop.values]

            def visit_BinOp(self, binop):
                if not isinstance(binop.op, Expr.BINOP_SYMBOLS):
                    raise self._unsupported(binop.op)
                left = self.visit(binop.left)
                right = self.visit(binop.right)
                if isinstance(binop.op, ast.BitAnd):
                    # name & mask is truthy for name == mask, falsy for 0.
                    self.vars[left] = {True: right, False: 0}
                elif isinstance(binop.op, ast.BitXor):
                    # name ^ mask is falsy only when name == mask.
                    self.vars[left] = {True: 0, False: right}
                elif isinstance(binop.op, ast.Pow):
                    return left ** right

            def visit_Compare(self, cmpop):
                """Record the values that satisfy / violate a comparison.

                Supports ``name OP number`` and the mirrored ``number OP name``.
                A comparison between two numbers involves no variable and is
                ignored.  Any other operand shape raises ``OperatorError``.
                """
                op = cmpop.ops[0]
                if not isinstance(op, Expr.CMPOP_SYMBOLS):
                    raise self._unsupported(op)
                left = self.visit(cmpop.left)
                right = self.visit(cmpop.comparators[0])
                op_type = type(op)
                if self._is_number(left) and isinstance(right, str):
                    # Normalise ``10 < a`` into ``a > 10``.
                    left, right = right, left
                    op_type = Expr._CMP_MIRROR[op_type]
                if not (isinstance(left, str) and self._is_number(right)):
                    if self._is_number(left) and self._is_number(right):
                        return
                    raise OperatorError(
                        'Compare is only supported between a variable '
                        'and a numeric constant.')
                when_true, when_false = Expr._CMP_OFFSETS[op_type]
                self.vars[left] = {True: right + when_true,
                                   False: right + when_false}

            def visit_UnaryOp(self, unaryop):
                if not isinstance(unaryop.op, Expr.UNARYOP_SYMBOLS):
                    raise self._unsupported(unaryop.op)
                operand = self.visit(unaryop.operand)
                # Only a bare name gets the inverted boolean mapping; for a
                # nested condition the operand's own visitor has already
                # recorded its variables and returned something else.
                if isinstance(operand, str):
                    self.vars[operand] = {True: False, False: True}

            def visit_Constant(self, node):
                # Numeric literals are ``ast.Constant`` nodes on Python 3.8+.
                if self._is_number(node.value):
                    return node.value
                return None

            def visit_Num(self, num):
                # Kept for interpreters that still produce ``ast.Num`` nodes.
                return num.n

            def visit_Name(self, node):
                # Default: a bare name is a boolean input.
                self.vars[node.id] = {True: True, False: False}
                return node.id
    

    with the example:

    (a < 10) and (c == 5)

    I can generate the following output:

    { 0: {'inputs': {'a': 9, 'c': 5}, 'expectation': True}, 1: {'inputs': {'a': 9, 'c': 6}, 'expectation': False}, 2: {'inputs': {'a': 11, 'c': 5}, 'expectation': False}, 3: {'inputs': {'a': 11, 'c': 6}, 'expectation': False} }

    Feel free to suggest any improvements to this solution.