Search code examples
Tags: algorithm, compiler-optimization, boolean-algebra

Is there a known algorithm for simplifying a boolean expression with number comparisons?


For example, if I have the expression (A > 5) && (A == 6), that expression can be simplified to just (A == 6), and still have the same behavior for A ∈ ℤ.

I also need it to work with multiple variables, so for instance ((B > 2) && (C == 2)) || ((B > 2) && (C < 2)) should simplify to (B > 2) && (C < 3).

I won't need to compare two unknowns, only unknowns and numbers, and I only need it to work with the operators <, >, and == for numbers, and && and || for expressions (&& being AND and || being OR, of course). All unknowns are integers.

Is there an algorithm that takes such an expression and returns an equivalent expression with the minimal number of operators?

(in my specific case, || operators are preferred over &&)


Solution

  • Here's a slow dynamic programming algorithm along the lines that you were thinking of.

    from collections import defaultdict, namedtuple
    from heapq import heappop, heappush
    from itertools import product
    from math import inf
    
    # Node types for Boolean expressions. The constants False and True are
    # valid formulas too. Comparison nodes pair a variable name (lhs) with a
    # number (rhs); And/Or nodes pair two sub-formulas.
    Lt = namedtuple("Lt", "lhs rhs")
    Eq = namedtuple("Eq", "lhs rhs")
    Gt = namedtuple("Gt", "lhs rhs")
    And = namedtuple("And", "lhs rhs")
    Or = namedtuple("Or", "lhs rhs")

    # Variable names used by the examples below. Any string may serve as a name.
    A, B, C = "A", "B", "C"

    # Sample formulas, each preceded by its infix form.
    # (A > 5) && (A == 6)
    first_example = And(Gt(A, 5), Eq(A, 6))
    # ((B > 2) && (C == 2)) || ((B > 2) && (C < 2))
    second_example = Or(
        And(Gt(B, 2), Eq(C, 2)),
        And(Gt(B, 2), Lt(C, 2)),
    )
    # ((A > 1) && (B > 1)) || ((A > 0) && (B > 2))
    third_example = Or(
        And(Gt(A, 1), Gt(B, 1)),
        And(Gt(A, 0), Gt(B, 2)),
    )
    # (A < 6) || (A > 5)
    fourth_example = Or(Lt(A, 6), Gt(A, 5))
    # ((A == 2) && (C > 2)) || ((B == 2) && (C < 2))
    fifth_example = Or(
        And(Eq(A, 2), Gt(C, 2)),
        And(Eq(B, 2), Lt(C, 2)),
    )
    
    def get_critical_value_sets(formula, result=None):
        """Collect, per variable, the values at which the formula may flip.

        Returns a map from each variable to the set of values such that the
        formula might evaluate differently for variable = value - 1 versus
        variable = value.

        Args:
            formula: A bool, or a tree built from Lt/Eq/Gt/And/Or nodes.
            result: Accumulator threaded through the recursion. A fresh
                defaultdict(set) is created when it is omitted.

        Returns:
            The accumulator, mapping each variable name to its critical values.

        Raises:
            AssertionError: If a node of an unrecognised type is encountered.
        """
        if result is None:
            result = defaultdict(set)
        if isinstance(formula, bool):
            # Constants never change value, so they contribute nothing.
            pass
        elif isinstance(formula, Lt):
            # x < k changes between x = k - 1 and x = k.
            result[formula.lhs].add(formula.rhs)
        elif isinstance(formula, Eq):
            # x == k changes on reaching k and again on moving past it.
            result[formula.lhs].add(formula.rhs)
            result[formula.lhs].add(formula.rhs + 1)
        elif isinstance(formula, Gt):
            # x > k changes between x = k and x = k + 1.
            result[formula.lhs].add(formula.rhs + 1)
        elif isinstance(formula, (And, Or)):
            get_critical_value_sets(formula.lhs, result)
            get_critical_value_sets(formula.rhs, result)
        else:
            # Raised explicitly because an `assert` statement is removed under
            # -O, which would let an unknown node be silently ignored.
            raise AssertionError(str(formula))
        return result
    
    
    def enumerate_truth_table_inputs(critical_value_sets):
        """List the variable assignments that make up the truth table.

        Returns a list of inputs sufficient to compare Boolean combinations of
        the primitives returned by enumerate_useful_primitives. Each variable
        ranges over its critical values plus -inf, which stands for "below
        every critical value"; the result is the Cartesian product of those
        ranges.

        Args:
            critical_value_sets: Mapping from variable name to its set of
                critical values. It may be empty.

        Returns:
            A list of dicts, each mapping every variable name to one value.
            An empty mapping yields a single empty assignment, [{}].
        """
        # Read the names straight from the mapping. Unpacking
        # zip(*mapping.items()) into two names fails when there are no
        # variables, e.g. when the formula is a bare True or False.
        variables = list(critical_value_sets)
        candidate_values = [
            {-inf} | critical_value_sets[variable] for variable in variables
        ]
        # product() with no arguments yields one empty tuple, which becomes the
        # single empty assignment for a variable-free formula.
        return [
            dict(zip(variables, values)) for values in product(*candidate_values)
        ]
    
    
    def enumerate_useful_primitives(critical_value_sets):
        """Yield the building blocks from which candidate formulas are made.

        Produces the constants False and True first, followed by every single
        comparison whose own critical values are a subset of the given ones.

        Args:
            critical_value_sets: Mapping from variable name to its set of
                critical values.

        Yields:
            False, True, then Lt/Eq/Gt nodes grouped by variable.
        """
        yield from (False, True)
        for name, boundaries in critical_value_sets.items():
            for boundary in boundaries:
                # name < boundary changes exactly at `boundary`.
                yield Lt(name, boundary)
                # name == boundary also changes at boundary + 1, so it is only
                # offered when that value is critical as well.
                if (boundary + 1) in boundaries:
                    yield Eq(name, boundary)
                # name > boundary - 1 changes exactly at `boundary`.
                yield Gt(name, boundary - 1)
    
    
    def evaluate(formula, input):
        """Evaluate the formula recursively on the given input.

        Args:
            formula: A bool, or a tree built from Lt/Eq/Gt/And/Or nodes.
            input: Mapping from variable name to the value to test.

        Returns:
            The truth value of the formula under that assignment. And/Or
            short-circuit, so the right operand is evaluated only when the
            left one does not already decide the result.

        Raises:
            AssertionError: If a node of an unrecognised type is encountered.
        """
        if isinstance(formula, bool):
            return formula
        elif isinstance(formula, Lt):
            return input[formula.lhs] < formula.rhs
        elif isinstance(formula, Eq):
            return input[formula.lhs] == formula.rhs
        elif isinstance(formula, Gt):
            return input[formula.lhs] > formula.rhs
        elif isinstance(formula, And):
            return evaluate(formula.lhs, input) and evaluate(formula.rhs, input)
        elif isinstance(formula, Or):
            return evaluate(formula.lhs, input) or evaluate(formula.rhs, input)
        else:
            # Raised explicitly because an `assert` statement is removed under
            # -O, which would make this function quietly return None.
            raise AssertionError(str(formula))
    
    
    def get_truth_table(formula, inputs):
        """Pack the formula's value on every input into a single integer.

        Args:
            formula: The formula to evaluate.
            inputs: Iterable of variable assignments, one per truth-table row.

        Returns:
            An int holding one bit per input. The first input ends up in the
            most significant position, and the result is 0 when there are no
            inputs.
        """
        bits = 0
        for assignment in inputs:
            # Make room for one more row, then record its outcome.
            bits <<= 1
            bits += evaluate(formula, assignment)
        return bits
    
    
    def get_complexity(formula):
        """Measure a formula's size for ranking candidates.

        Returns (the number of operations in the formula, the number of Ands).
        Constants cost nothing, and each comparison, And and Or counts as one
        operation.

        Args:
            formula: A bool, or a tree built from Lt/Eq/Gt/And/Or nodes.

        Returns:
            A tuple (operation_count, and_count).

        Raises:
            AssertionError: If a node of an unrecognised type is encountered.
        """
        if isinstance(formula, bool):
            return (0, 0)
        elif isinstance(formula, (Lt, Eq, Gt)):
            return (1, 0)
        elif isinstance(formula, (And, Or)):
            ops_lhs, ands_lhs = get_complexity(formula.lhs)
            ops_rhs, ands_rhs = get_complexity(formula.rhs)
            # Both connectives cost one operation; only And adds to the And tally.
            own_ands = 1 if isinstance(formula, And) else 0
            return (ops_lhs + 1 + ops_rhs, ands_lhs + own_ands + ands_rhs)
        else:
            # Raised explicitly because an `assert` statement is removed under
            # -O, which would make this function quietly return None.
            raise AssertionError(str(formula))
    
    
    class HeapItem:
        """A formula wrapped so that it is ordered by its complexity.

        Every comparison, equality included, looks only at the complexity
        tuple computed once at construction time; the formula itself never
        takes part in a comparison.
        """

        __slots__ = ["_complexity", "formula"]

        def __init__(self, formula):
            self.formula = formula
            # Cached so that heap operations never re-walk the formula.
            self._complexity = get_complexity(formula)

        # Equality and inequality.
        def __eq__(self, other):
            return self._complexity == other._complexity

        def __ne__(self, other):
            return not (self._complexity == other._complexity)

        # "Less than" family.
        def __lt__(self, other):
            return self._complexity < other._complexity

        def __le__(self, other):
            return self._complexity <= other._complexity

        # "Greater than" family, expressed as the mirrored comparison.
        def __gt__(self, other):
            return other._complexity < self._complexity

        def __ge__(self, other):
            return other._complexity <= self._complexity
    
    
    class Merge:
        """Merge of sorted iterables that accepts new sources on the fly.

        Like heapq.merge, except that iterables can be added dynamically via
        update(), even while the merge is being consumed. Iterating yields the
        smallest pending head value among all registered sources.
        """

        __slots__ = ["_heap", "_iterable_count"]

        def __init__(self):
            # Number of sources registered so far; doubles as a tie-breaker.
            self._iterable_count = 0
            # Heap entries are (head value, source serial number, source iterator).
            self._heap = []

        def __iter__(self):
            return self

        def update(self, iterable):
            """Register another iterable; an empty one is ignored."""
            source = iter(iterable)
            exhausted = object()
            head = next(source, exhausted)
            if head is exhausted:
                return
            # The serial number settles ties between equal heads, so the
            # iterators themselves are never compared.
            heappush(self._heap, (head, self._iterable_count, source))
            self._iterable_count += 1

        def __next__(self):
            heap = self._heap
            if not heap:
                raise StopIteration
            smallest, serial, source = heappop(heap)
            # Re-queue the source under its next head unless it has run dry.
            exhausted = object()
            upcoming = next(source, exhausted)
            if upcoming is not exhausted:
                heappush(heap, (upcoming, serial, source))
            return smallest
    
    
    class Combinations:
        """Iterator over op(formula, other) for each previously found formula.

        The number of partners is fixed when the iterator is created, so
        formulas appended to best_formulas afterwards are not visited. Each
        combined formula is returned wrapped in a HeapItem.
        """

        __slots__ = ["_op", "_formula", "_best_formulas", "_i", "_n"]

        def __init__(self, op, formula, best_formulas):
            self._op = op
            self._formula = formula
            self._best_formulas = best_formulas
            # Snapshot of the current length: iteration stops there even if
            # the shared list keeps growing.
            self._n = len(best_formulas)
            self._i = 0

        def __iter__(self):
            return self

        def __next__(self):
            position = self._i
            if position >= self._n:
                raise StopIteration
            combined = self._op(self._formula, self._best_formulas[position])
            self._i = position + 1
            return HeapItem(combined)
    
    
    def simplify(target_formula):
        """Return the simplest formula equivalent to the target.

        Ties between equally small formulas are broken in favor of fewer Ands.
        Candidates are visited in order of increasing complexity. The first
        candidate to produce each distinct truth table is kept and combined,
        via And and Or, with every candidate kept before it.

        Args:
            target_formula: A bool, or a tree built from Lt/Eq/Gt/And/Or nodes.

        Returns:
            The first candidate whose truth table matches the target's, or
            None if the candidates run out without a match.
        """
        critical_value_sets = get_critical_value_sets(target_formula)
        inputs = enumerate_truth_table_inputs(critical_value_sets)
        target_truth_table = get_truth_table(target_formula, inputs)
        merge = Merge()
        for formula in enumerate_useful_primitives(critical_value_sets):
            merge.update([HeapItem(formula)])
        # Truth tables that already have a kept candidate.
        seen_truth_tables = set()
        # Kept candidates, in the order they were found.
        best_formulas = []
        for item in merge:
            formula = item.formula
            truth_table = get_truth_table(formula, inputs)
            if truth_table in seen_truth_tables:
                # An earlier candidate already covers this behavior.
                continue
            if truth_table == target_truth_table:
                # Answer as soon as the match appears instead of waiting for
                # another candidate to be drawn from the merge.
                return formula
            for op in (And, Or):
                # Pair the new formula with every formula kept before it.
                merge.update(Combinations(op, formula, best_formulas))
            seen_truth_tables.add(truth_table)
            best_formulas.append(formula)
        return None
    
    
    # Simplify each sample formula in turn and print the result.
    for example in (
        first_example,
        second_example,
        third_example,
        fourth_example,
        fifth_example,
    ):
        print(simplify(example))
    

    Output:

    Eq(lhs='A', rhs=6)
    And(lhs=Lt(lhs='C', rhs=3), rhs=Gt(lhs='B', rhs=2))
    And(lhs=And(lhs=Gt(lhs='B', rhs=1), rhs=Gt(lhs='A', rhs=0)), rhs=Or(lhs=Gt(lhs='B', rhs=2), rhs=Gt(lhs='A', rhs=1)))
    True
    Or(lhs=And(lhs=Eq(lhs='B', rhs=2), rhs=Lt(lhs='C', rhs=2)), rhs=And(lhs=Gt(lhs='C', rhs=2), rhs=Eq(lhs='A', rhs=2)))