Search code examples
Tags: java, binary-tree, symbolic-math

Clean way to simplify a binary expression tree


The goal of my program is to display the symbolic derivative of a mathematical expression. After creating a new tree that represents the derivative, it is likely that I will be left with redundant terms.

For example, the following tree is not simplified.

Example of binary expression tree

The tree `0 + 5 * (x * 5)` can be rewritten as `25 * x`.

My program uses many, many `if`/`else` blocks to reduce the tree by checking for constants multiplied by constants, etc. Then it rearranges the subtree accordingly.

Here is a tiny portion of my recursive function that simplifies the tree:

if(root.getVal().equals("*")) {

        if(root.getLeftChild().getVal().equals("1")) {
            return root.getRightChild();
        }
        else if(root.getRightChild().getVal().equals("1")) {
            return root.getLeftChild();
        }
        else if(root.getLeftChild().getVal().equals("0")) {
            return root.getLeftChild();
        }
        else if(root.getRightChild().getVal().equals("0")) {
            return root.getRightChild();
        }
        else if(root.getLeftChild().getVal().equals("*")) {
            if(root.getRightChild().getType().equals("constant")) {
                if(root.getLeftChild().getLeftChild().getType().equals("constant")) { // Ex: (5*x)*6 ==> 30*x
                    int num1 = Integer.parseInt(root.getRightChild().getVal());
                    int num2 = Integer.parseInt(root.getLeftChild().getLeftChild().getVal());
                    OpNode mult = new OpNode("*");
                    mult.setLeftChild(new ConstNode(String.valueOf(num1 * num2)));
                    mult.setRightChild(root.getLeftChild().getRightChild());

                    return mult;
                }
        ...
        ...
        ...
...

The function works great, except that I need to call it a few times to ensure the tree is fully reduced (in case one reduction opens up another reduction possibility). However, it is 200 lines long and growing, which leads me to believe there must be a much better way to do this.


Solution

  • One typical approach to this problem is the visitor pattern. Any time you need to walk a recursive structure, applying logic at each node which depends on the "type" of the node, this pattern is a good tool to have handy.

    For this specific problem, and specifically in Java, I'd start by representing your expression "abstract syntax tree" more directly as a type hierarchy.

    I've put together a simple example, assuming your AST handles +, -, *, / and unary negation, as well as literal numbers and named variables. I've called my Visitor a Folder — we sometimes use this name for visitor-like classes that replace ("fold") subtrees. (Think: optimization or de-sugaring passes in compilers.)

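    The trick to handling the "I sometimes need to repeat simplification" problem is to do a depth-first (post-order) traversal: all children are fully simplified before their parent is. When a rule builds a brand-new node (for example, turning `0 - x` into `-x`), that new node must itself be simplified before it is returned.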

    Here's the example (disclaimer: I hate Java, so I don't promise this is the most "idiomatic" implementation in the language):

    /**
     * Visitor over {@link Expression} trees in which every callback returns a
     * replacement expression (possibly the node it was given).
     *
     * <p>All four callbacks could share the single overloaded name
     * {@code fold}; each concrete Expression's dispatch would still select the
     * right overload from the static type of {@code this}. Distinct names are
     * used purely for readability.
     */
    interface Folder {
        /** Returns the replacement for a numeric literal. */
        Expression foldNumber(Number node);

        /** Returns the replacement for a named variable. */
        Expression foldVariable(Variable node);

        /** Returns the replacement for a unary-operator node. */
        Expression foldUnaryOperation(UnaryOperation node);

        /** Returns the replacement for a binary-operator node. */
        Expression foldBinaryOperation(BinaryOperation node);
    }
    
    /** Base type for every node of an expression tree. */
    abstract class Expression {
        /** Renders this node and its subtree as a readable string (used for testing). */
        abstract String repr();

        /**
         * Double-dispatch hook: hands this node to the {@link Folder} callback
         * matching its concrete type and returns whatever that callback returns.
         */
        abstract Expression fold(Folder folder);
    }
    
    /** Operators of a {@link BinaryOperation}: addition, subtraction, multiplication, division. */
    enum BinaryOperator { PLUS, MINUS, MUL, DIV }
    
    /** Operators of a {@link UnaryOperation}; currently only arithmetic negation. */
    enum UnaryOperator { NEGATE }
    
    /** Interior node applying a {@link BinaryOperator} to a left and a right operand. */
    class BinaryOperation extends Expression {
        public BinaryOperator operator;
        public Expression left;
        public Expression right;

        public BinaryOperation(BinaryOperator operator, Expression left, Expression right) {
            this.operator = operator;
            this.left = left;
            this.right = right;
        }

        /** Dispatches to {@link Folder#foldBinaryOperation}. */
        public Expression fold(Folder folder) {
            return folder.foldBinaryOperation(this);
        }

        /** Renders as {@code (left op right)}; always parenthesised for clarity. */
        public String repr() {
            // render the left side first, then pick the operator symbol
            String lhs = left.repr();
            String symbol;
            switch (operator) {
                case PLUS:  symbol = " + "; break;
                case MINUS: symbol = " - "; break;
                case MUL:   symbol = " * "; break;
                case DIV:   symbol = " / "; break;
                default:    symbol = "";    break;
            }
            return "(" + lhs + symbol + right.repr() + ")";
        }
    }
    
    /** Interior node applying a {@link UnaryOperator} to a single operand. */
    class UnaryOperation extends Expression {
        public UnaryOperator operator;
        public Expression operand;

        public UnaryOperation(UnaryOperator operator, Expression operand) {
            this.operator = operator;
            this.operand = operand;
        }

        /** Dispatches to {@link Folder#foldUnaryOperation}. */
        public Expression fold(Folder folder) {
            return folder.foldUnaryOperation(this);
        }

        /** Renders as the operator symbol immediately followed by the operand, e.g. {@code -x}. */
        public String repr() {
            String prefix;
            switch (operator) {
                case NEGATE: prefix = "-"; break;
                default:     prefix = "";  break;
            }
            return prefix + operand.repr();
        }
    }
    
    /** Leaf node holding a numeric literal stored as a {@code double}. */
    class Number extends Expression {
        public double value;

        public Number(double value) {
            this.value = value;
        }

        /** Dispatches to {@link Folder#foldNumber}. */
        public Expression fold(Folder folder) {
            return folder.foldNumber(this);
        }

        /** Renders the literal in Java's standard double format, e.g. {@code 17.3} or {@code 0.0}. */
        public String repr() {
            return String.valueOf(value);
        }
    }
    
    /** Leaf node referring to a variable by name. */
    class Variable extends Expression {
        public String name;

        public Variable(String name) {
            this.name = name;
        }

        /** Dispatches to {@link Folder#foldVariable}. */
        public Expression fold(Folder folder) {
            return folder.foldVariable(this);
        }

        /** Renders as the bare variable name. */
        public String repr() {
            return this.name;
        }
    }
    
    /**
     * Base class providing the "standard" traversal (Folder could instead have
     * been an abstract class holding these defaults).
     *
     * <p>Operator nodes are rebuilt, always as new node objects, from their
     * folded children; leaves are returned unchanged. Subclasses override a
     * callback and call {@code super} to get the recursion into children.
     */
    class DefaultFolder implements Folder {
        public Expression foldNumber(Number expr) {
            // leaf: nothing underneath to visit
            return expr;
        }

        public Expression foldVariable(Variable expr) {
            // leaf: nothing underneath to visit
            return expr;
        }

        public Expression foldUnaryOperation(UnaryOperation expr) {
            Expression operand = expr.operand.fold(this);
            return new UnaryOperation(expr.operator, operand);
        }

        public Expression foldBinaryOperation(BinaryOperation expr) {
            // the left subtree is folded before the right one
            Expression lhs = expr.left.fold(this);
            Expression rhs = expr.right.fold(this);
            return new BinaryOperation(expr.operator, lhs, rhs);
        }
    }
    
    class Simplifier extends DefaultFolder {
        public Expression foldBinaryOperation(BinaryOperation expr) {
            // we want to do a depth-first traversal, ensuring that all
            //   sub-expressions are simplified before their parents...
            // ... so begin by invoking the superclass "default"
            //   traversal logic.
            BinaryOperation folded_expr =
                // this cast is safe because we know the default fold
                //   logic never changes the type of the top-level expression
                (BinaryOperation)super.foldBinaryOperation(expr);
    
            // now apply our "shallow" simplification logic on the result
            switch (folded_expr.operator) {
                case PLUS:
                    // x + 0 => x
                    if (folded_expr.right instanceof Number
                            && ((Number)(folded_expr.right)).value == 0)
                        return folded_expr.left;
    
                    // 0 + x => x
                    if (folded_expr.left instanceof Number
                            && ((Number)(folded_expr.left)).value == 0)
                        return folded_expr.right;
                    break;
    
                case MINUS:
                    // x - 0 => x
                    if (folded_expr.right instanceof Number
                            && ((Number)(folded_expr.right)).value == 0)
                        return folded_expr.left;
    
                    // 0 - x => -x
                    if (folded_expr.left instanceof Number
                            && ((Number)(folded_expr.left)).value == 0) {
                        // a weird case: we need to construct a UnaryOperator
                        //   representing -right, then simplify it
                        UnaryOperation minus_right = new UnaryOperation(
                                UnaryOperator.NEGATE, folded_expr.right);
                        return foldUnaryOperation(minus_right);
                    }
                    break;
    
                case MUL:
                    // 1 * x => x
                    if (folded_expr.left instanceof Number
                            && ((Number)(folded_expr.left)).value == 1)
                        return folded_expr.right;
    
                case DIV:
                    // x * 1 => x
                    // x / 1 => x
                    if (folded_expr.right instanceof Number
                            && ((Number)(folded_expr.right)).value == 1)
                        return folded_expr.left;
                    break;
            }
    
            // no rules applied
            return folded_expr;
        }
    
        public Expression foldUnaryOperation(UnaryOperation expr) {
            // as before, go depth-first:
            UnaryOperation folded_expr =
                // see note in foldBinaryOperation about safety here
                (UnaryOperation)super.foldUnaryOperation(expr);
    
            switch (folded_expr.operator) {
                case NEGATE:
                    // --x => x
                    if (folded_expr.operand instanceof UnaryOperation
                            && ((UnaryOperation)folded_expr).operator ==
                               UnaryOperator.NEGATE)
                        return ((UnaryOperation)folded_expr.operand).operand;
    
                    // -(number) => -number
                    if (folded_expr.operand instanceof Number)
                        return new Number(-((Number)(folded_expr.operand)).value);
                    break;
            }
    
            // no rules applied
            return folded_expr;
        }
    
        // we don't need to implement the other two; the inherited defaults are fine
    }
    
    /** Demo driver: prints sample expressions before and after simplification. */
    public class Simplify {
        /** Prints {@code original}, then the result of folding it with {@code folder}. */
        private static void show(Expression original, Folder folder) {
            System.out.println("Unsimplified: " + original.repr());
            Expression reduced = original.fold(folder);
            System.out.println("Simplified: " + reduced.repr());
        }

        public static void main(String[] args) {
            Folder simplifier = new Simplifier();

            // 0 + x
            Expression zeroPlusX = new BinaryOperation(
                    BinaryOperator.PLUS, new Number(0.0), new Variable("x"));

            // 17.3 + -(-(0 / 1))
            Expression nested = new BinaryOperation(
                    BinaryOperator.PLUS,
                    new Number(17.3),
                    new UnaryOperation(UnaryOperator.NEGATE,
                            new UnaryOperation(UnaryOperator.NEGATE,
                                    new BinaryOperation(BinaryOperator.DIV,
                                            new Number(0.0), new Number(1.0)))));

            Expression[] samples = { zeroPlusX, nested };
            for (int i = 0; i < samples.length; i++) {
                show(samples[i], simplifier);
            }
        }
    }
    

    And the output:

    > java Simplify
    
    Unsimplified: (0.0 + x)
    Simplified: x
    Unsimplified: (17.3 + --(0.0 / 1.0))
    Simplified: 17.3