Tags: c#, .net, algorithm, optimization, expression-trees

Efficiently eliminate common sub-expressions in a .NET Expression Tree


I've written a DSL and a compiler that generates a .NET expression tree from it. All expressions within the tree are side-effect free, and the expression is guaranteed to be a "non-statement" expression (no locals, loops, blocks, etc.). (Edit: The tree may include literals, property accesses, standard operators and function calls, which may do fancy things like memoization internally but are externally side-effect free.)

Now I would like to perform the "Common sub-expression elimination" optimization on it.

For example, given a tree corresponding to the C# lambda:

foo =>      (foo.Bar * 5 + foo.Baz * 2 > 7) 
         || (foo.Bar * 5 + foo.Baz * 2 < 3)  
         || (foo.Bar * 5 + 3 == foo.Xyz)

...I would like to generate the tree-equivalent of the following (setting aside the fact that some of the short-circuiting semantics are lost):

foo =>
{
     var local1 = foo.Bar * 5;

     // Notice that this local depends on the first one.        
     var local2 = local1 + foo.Baz * 2; 

     // Notice that no unnecessary locals have been generated.
     return local2 > 7 || local2 < 3 || (local1 + 3 == foo.Xyz);
}
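The rewrite above is valid only because every sub-expression is side-effect free. Evaluating the locals eagerly, before the `||` chain, therefore cannot change the result. One quick way to check this is to run both forms on many inputs. The sketch below is in Java, the language of the answer's example code, and is illustrative only. `Foo` is a hypothetical stand-in for the DSL's input type, and its `int` fields are an assumption.

```java
import java.util.Random;

// Hypothetical stand-in for the DSL input: three int "properties".
final class Foo {
    final int bar, baz, xyz;
    Foo(int bar, int baz, int xyz) { this.bar = bar; this.baz = baz; this.xyz = xyz; }
}

final class CseExample {
    // Original predicate: bar * 5 appears three times, bar * 5 + baz * 2 twice.
    static boolean original(Foo foo) {
        return (foo.bar * 5 + foo.baz * 2 > 7)
            || (foo.bar * 5 + foo.baz * 2 < 3)
            || (foo.bar * 5 + 3 == foo.xyz);
    }

    // Hand-eliminated form: each common sub-expression is computed once, eagerly.
    // This is only safe because evaluating the sub-expressions has no side effects.
    static boolean rewritten(Foo foo) {
        int local1 = foo.bar * 5;
        int local2 = local1 + foo.baz * 2;   // depends on local1
        return local2 > 7 || local2 < 3 || (local1 + 3 == foo.xyz);
    }

    // Returns true if both forms agree on every pseudo-random sample.
    static boolean agreeOn(int samples, long seed) {
        Random rnd = new Random(seed);
        for (int i = 0; i < samples; i++) {
            Foo foo = new Foo(rnd.nextInt(21) - 10, rnd.nextInt(21) - 10, rnd.nextInt(41) - 20);
            if (original(foo) != rewritten(foo)) return false;
        }
        return true;
    }
}
```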

I'm familiar with writing expression visitors, but the algorithm for this optimization isn't immediately obvious to me. I could of course find "duplicates" within a tree, but there's obviously some trick to analyzing the dependencies within and between sub-trees to eliminate sub-expressions efficiently and correctly.

I looked for algorithms on Google, but the ones I found seem quite complicated to implement quickly. They also seem very "general" and don't take advantage of the simplicity of my trees.


Solution

  • You're correct in noting this is not a trivial problem.

    The classical way that compilers handle it is a Directed Acyclic Graph (DAG) representation of the expression. The DAG is built in the same manner as the abstract syntax tree, and it can be built by traversing the AST (perhaps a job for the expression visitor; I don't know the C# libraries well). The difference is that a dictionary of previously emitted subgraphs is maintained. Before creating a node of a given type with given children, the dictionary is consulted to see whether one already exists. Only if this check fails is a new node created and then added to the dictionary.
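    A minimal Java sketch of this dictionary-driven construction (often called hash-consing) follows. It uses hypothetical `Node` and `DagBuilder` classes, not .NET's `Expression` types, and it is an illustration rather than the answer's own code. The dictionary key is the operator plus the ids of the children, which are already canonical. Identical sub-expressions therefore map to one node, and sharing propagates upward automatically.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical DAG node: operator name ("Id", "Number", "MULTIPLY", ...),
// leaf text (only for Id/Number leaves) and references to shared children.
final class Node {
    final int id;
    final String op;
    final String text;
    final List<Node> children;

    Node(int id, String op, String text, List<Node> children) {
        this.id = id;
        this.op = op;
        this.text = text;
        this.children = children;
    }
}

// Builds a DAG instead of a tree by looking up every candidate node
// in a dictionary of nodes created so far before allocating a new one.
final class DagBuilder {
    private final Map<String, Node> existing = new HashMap<>();
    private int nextId = 1;

    Node leaf(String op, String text) {
        return intern(op + ":" + text, op, text, List.of());
    }

    // The children are already canonical, so equal sub-expressions carry equal ids.
    // This is what lets (Bar * 5) + (Baz * 2) reuse the shared (Bar * 5) node.
    Node binary(String op, Node lhs, Node rhs) {
        return intern(op + "(" + lhs.id + "," + rhs.id + ")", op, null, List.of(lhs, rhs));
    }

    private Node intern(String key, String op, String text, List<Node> children) {
        Node found = existing.get(key);
        if (found != null) return found;          // seen before: reuse it
        Node created = new Node(nextId++, op, text, children);
        existing.put(key, created);
        return created;
    }

    int nodeCount() { return existing.size(); }
}
```

    If both occurrences of `foo.Bar * 5 + foo.Baz * 2` go through `binary`, the result is the same `Node` object, so the DAG holds one node where the tree held two.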

    Since now a node may descend from multiple parents, the result is a DAG.
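    Because a node can now have several parents, code generation needs to know each node's parent count. The Java sketch below computes it in one pass. It assumes a hypothetical `Node` class holding a list of children. It counts incoming edges, but it expands each node's children only once, so a shared subgraph is not counted again for every path that reaches it. In the question's expression, for example, the literal `3` ends up with a count of 2.

```java
import java.util.ArrayDeque;
import java.util.Collections;
import java.util.Deque;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

// Hypothetical DAG node; several parents may reference the same child object.
final class Node {
    final String label;
    final List<Node> children;

    Node(String label, Node... children) {
        this.label = label;
        this.children = List.of(children);
    }
}

final class ParentCounts {
    // Maps every node reachable from root to its number of incoming edges.
    static Map<Node, Integer> of(Node root) {
        Map<Node, Integer> counts = new IdentityHashMap<>();
        Set<Node> expanded = Collections.newSetFromMap(new IdentityHashMap<>());
        Deque<Node> pending = new ArrayDeque<>();
        counts.put(root, 0);
        pending.push(root);
        while (!pending.isEmpty()) {
            Node n = pending.pop();
            if (!expanded.add(n)) continue;           // shared node: edges already counted
            for (Node child : n.children) {
                counts.merge(child, 1, Integer::sum); // one more parent edge into child
                pending.push(child);
            }
        }
        return counts;
    }
}
```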

    Then the DAG is traversed depth-first to generate code. Since common sub-expressions are now represented by a single node, the value is computed only once and stored in a temp for expressions emitted later in code generation to use. If the original code contains assignments, this phase gets complicated. Since your trees are side-effect free, the DAG ought to be the most straightforward way to solve your problem.
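    The following Java sketch shows this depth-first emission. It is an illustration under stated assumptions, not the answer's own code. It assumes a hypothetical binary `Node` type with hand-assigned ids. Any interior node with more than one parent gets a temp named `t<id>`. Leaves are always emitted inline, as in the output shown further down.

```java
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical binary DAG node: leaves carry source text, interior nodes an operator.
final class Node {
    final int id;
    final String text;      // operator symbol, or identifier/literal for a leaf
    final Node lhs, rhs;

    Node(int id, String text) { this(id, text, null, null); }
    Node(int id, String op, Node lhs, Node rhs) {
        this.id = id; this.text = op; this.lhs = lhs; this.rhs = rhs;
    }
    boolean isLeaf() { return lhs == null; }
}

final class TempCodeGen {
    private final Map<Node, Integer> parents = new IdentityHashMap<>();
    private final Set<Node> inTemp = Collections.newSetFromMap(new IdentityHashMap<>());
    private final StringBuilder out = new StringBuilder();

    // Emits one "tN = ...;" line per shared interior node, followed by a return.
    static String generate(Node root) {
        TempCodeGen g = new TempCodeGen();
        Set<Node> visited = Collections.newSetFromMap(new IdentityHashMap<>());
        g.countParents(root, visited);
        String result = g.emit(root);
        return g.out.append("return ").append(result).append(";").toString();
    }

    // Counts each parent->child edge once by expanding every node only once.
    private void countParents(Node n, Set<Node> visited) {
        if (n.isLeaf() || !visited.add(n)) return;
        for (Node child : new Node[] { n.lhs, n.rhs }) {
            parents.merge(child, 1, Integer::sum);
            countParents(child, visited);
        }
    }

    // Post-order walk: children first, then this node. Returns the text standing for n.
    private String emit(Node n) {
        if (n.isLeaf()) return n.text;                 // leaves are cheap: always inline
        if (inTemp.contains(n)) return "t" + n.id;     // already computed once
        String expr = "(" + emit(n.lhs) + " " + n.text + " " + emit(n.rhs) + ")";
        if (parents.getOrDefault(n, 0) > 1) {          // shared: compute once into a temp
            out.append("t").append(n.id).append(" = ").append(expr).append(";\n");
            inTemp.add(n);
            return "t" + n.id;
        }
        return expr;
    }
}
```

    All temps are emitted before the `return`, so the short-circuiting of `||` is lost. That trade-off is acceptable here only because the sub-expressions are side-effect free.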

    As I recall, the coverage of DAGs in the Dragon Book is particularly good.

    As others have noted, if your trees will ultimately be compiled by an existing compiler that already performs this optimization, it's largely futile to redo it yourself.

    Addition

    I had some Java code lying around from a student project (I teach), so I hacked up a little example of how this works. It's too long to post, but see the Gist here.

    Running it on your input prints the DAG below. The numbers in parens are (unique id, DAG parent count). The parent count is needed to decide when to compute the local temp variables and when to just use the expression for a node.

    Binary OR (27,1)
      lhs:
        Binary OR (19,1)
          lhs:
            Binary GREATER (9,1)
              lhs:
                Binary ADD (7,2)
                  lhs:
                    Binary MULTIPLY (3,2)
                      lhs:
                        Id 'Bar' (1,1)
                      rhs:
                        Number 5 (2,1)
                  rhs:
                    Binary MULTIPLY (6,1)
                      lhs:
                        Id 'Baz' (4,1)
                      rhs:
                        Number 2 (5,1)
              rhs:
                Number 7 (8,1)
          rhs:
            Binary LESS (18,1)
              lhs:
                ref to Binary ADD (7,2)
              rhs:
                Number 3 (17,2)
      rhs:
        Binary EQUALS (26,1)
          lhs:
            Binary ADD (24,1)
              lhs:
                ref to Binary MULTIPLY (3,2)
              rhs:
                ref to Number 3 (17,2)
          rhs:
            Id 'Xyz' (25,1)
    

    Then it generates this code:

    t3 = (Bar) * (5);
    t7 = (t3) + ((Baz) * (2));
    return (((t7) > (7)) || ((t7) < (3))) || (((t3) + (3)) == (Xyz));
    

    You can see that the temp variable numbers correspond to DAG node ids. You could make the code generator more complex to get rid of the unnecessary parentheses, but I'll leave that for others.
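    One common way to drop the redundant parentheses is to compare operator precedence. A child gets parentheses only when it binds more loosely than its parent, or equally loosely on the right-hand side of a left-associative operator. The Java sketch below illustrates this. The `Expr` type and the C-family precedence table are assumptions for the example, not part of the Gist.

```java
import java.util.Map;

// Hypothetical expression node: op == null marks a leaf that carries text.
final class Expr {
    final String op;
    final String text;
    final Expr lhs, rhs;

    private Expr(String op, String text, Expr lhs, Expr rhs) {
        this.op = op; this.text = text; this.lhs = lhs; this.rhs = rhs;
    }
    static Expr leaf(String text) { return new Expr(null, text, null, null); }
    static Expr bin(Expr lhs, String op, Expr rhs) { return new Expr(op, null, lhs, rhs); }
}

final class MinimalParens {
    // Binding strength of some C-family binary operators (higher binds tighter).
    private static final Map<String, Integer> PRECEDENCE = Map.of(
        "||", 1, "&&", 2, "==", 3, "!=", 3, "<", 4, ">", 4,
        "+", 5, "-", 5, "*", 6, "/", 6);

    // Leaves never need parentheses; unknown operators throw a NullPointerException.
    private static int precedence(Expr e) {
        return e.op == null ? Integer.MAX_VALUE : PRECEDENCE.get(e.op);
    }

    // All operators in the table are left-associative, so a right operand with
    // equal precedence keeps its parentheses: a - (b - c) differs from a - b - c.
    static String print(Expr e) {
        if (e.op == null) return e.text;
        int p = precedence(e);
        String left = print(e.lhs), right = print(e.rhs);
        if (precedence(e.lhs) < p) left = "(" + left + ")";
        if (precedence(e.rhs) <= p) right = "(" + right + ")";
        return left + " " + e.op + " " + right;
    }
}
```

    Applied to the `return` expression above, this rule prints `t7 > 7 || t7 < 3 || t3 + 3 == Xyz`.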