Search code examples

Expanding a math expression AST

I parsed an expression like `a*(b+c)` into an AST, which ends up looking like this: [image: the resulting AST]

I'm trying to expand the expression so that it becomes `ab + ac`, but I have no idea how to do it.

I would like to find an algorithm to expand the expression, or perhaps a library that does it, preferably for Java.


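  • Try this: expand each operand recursively, then distribute multiplication over any sum that results, so `x * (y + z)` becomes `(x * y) + (x * z)`.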

    /** A node of an arithmetic expression tree. */
    interface AST {
        /**
         * Returns the expanded form of this expression, i.e. with multiplication
         * distributed over addition by the implementing node types.
         *
         * @return the expanded tree (may be this node itself)
         */
        AST expand();
    }
    public class Var implements AST {
        public final String name;
        public Var(String name) { = name; }
        @Override public AST expand() { return this; }
        @Override public String toString() { return name; }
    public abstract class Binary implements AST {
        public final AST left, right;
        Binary(AST left, AST right) { this.left = left; this.right = right; }
    public class Plus extends Binary {
        public Plus(AST left, AST right) { super(left, right); }
        @Override public AST expand() { return this; }
        @Override public String toString() { return "(%s + %s)".formatted(left, right); }
    public class Mult extends Binary {
        public Mult(AST left, AST right) { super(left, right); }
        public AST expand() {
            AST l = left.expand(), r = right.expand();
            if (l instanceof Plus lp && r instanceof Plus rp)
                return new Plus(
                    new Plus(new Mult(lp.left, rp.left), new Mult(lp.left, rp.right)),
                    new Plus(new Mult(lp.right, rp.left), new Mult(lp.right, rp.right)));
            else if (l instanceof Plus lp)
                return new Plus(new Mult(lp.left, r), new Mult(lp.right, r));
            else if (r instanceof Plus rp)
                return new Plus(new Mult(l, rp.left), new Mult(l, rp.right));
                return new Mult(l, r);
        @Override public String toString() { return "(%s * %s)".formatted(left, right); }


        // Build a * (b + c) and (a + b) * (c + d), then print each next to its expansion.
        AST e = new Mult(new Var("a"), new Plus(new Var("b"), new Var("c")));
        AST f = new Mult(
                new Plus(new Var("a"), new Var("b")),
                new Plus(new Var("c"), new Var("d")));
        for (AST expr : new AST[] {e, f}) {
            System.out.println(expr + " -> " + expr.expand());
        }


    (a * (b + c)) -> ((a * b) + (a * c))
    ((a + b) * (c + d)) -> (((a * c) + (a * d)) + ((b * c) + (b * d)))