c# tree expression expression-trees visitor-pattern

Exhaustive search / generate each combination of an expression tree


I'm playing with a basic expression tree optimiser to build query plans. While parsing a tree, I can decide how "best" to construct it, based on a weighting I assign to each operation.

If I have a simple tree with two choices for how to perform an action, I'd like to generate both variations of the tree so that I can compare the weighting of each and see which is the most efficient.

For example, the code below should let me construct two variations of the expression tree's Join operation: one with a MergeJoinExpression and one with a NestedLoopJoinExpression.

class Customer
{
    public int Id { get; set; }
}
class Orders
{
    public int Id { get; set; }
    public int CustomerId { get; set; }
}

class MergeJoinExpresion : JoinExpression
{
}

class NestLoopJoinExpresion : JoinExpression
{
}

class Visitor : ExpressionVisitor
{
    public List<Expression> GetPlans(Expression expr)
    {
        // ???
    }

    protected override Expression VisitJoin(JoinExpression join)
    {
        // For this join, I can return the following (trite example)
        // return MergeJoinExpresion
        // return NestLoopJoinExpresion

        return base.VisitJoin(join);
    }
}

How can I construct a method that will generate each variation of tree and return them to me?

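class Program
{
    static void Main(string[] args)
    {
        var query = from c in customers
                    join o in orders on c.Id equals o.CustomerId
                    select new
                    {
                        CustomerId = c.Id,
                        OrderId = o.Id
                    };

        // GetPlans takes an Expression, so pass the query's expression tree
        // (this assumes customers and orders are IQueryable sources).
        var plans = new Visitor().GetPlans(query.Expression);
    }
}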

Can anyone show me how I can modify the Visitor Class GetPlans method to generate these variations?

EDIT - something like:

class Visitor : ExpressionVisitor
{
    private List<Expression> exprs = new List<Expression>();

    public List<Expression> GetPlans(Expression expr)
    {
        Visit(expr);    
        return exprs;
    }

    protected override Expression VisitJoin(JoinExpression join)
    {
        // For this join, I can return the following (trite example)
        // return MergeJoinExpresion
        // return NestLoopJoinExpresion      
        var choices = new Expression[] { MergeJoinExpresion.Create(join), NestLoopJoinExpresion.Create(join) };

        foreach(var choice in choices)
        {
             var cloned = Cloner.Clone(choice);
             var newTree = base.VisitJoin(cloned);
             exprs.Add(newTree);
        }

        return base.VisitJoin(join);
    }
}

Solution

  • To start with, we'll create a visitor that extracts a list of JoinExpression objects from an Expression:

    internal class FindJoinsVisitor : ExpressionVisitor
    {
        private List<JoinExpression> expressions = new List<JoinExpression>();
        protected override Expression VisitJoin(JoinExpression join)
        {
            expressions.Add(join);
            return base.VisitJoin(join);
        }
        public IEnumerable<JoinExpression> JoinExpressions
        {
            get
            {
                return expressions;
            }
        }
    }
    public static IEnumerable<JoinExpression> FindJoins(
        this Expression expression)
    {
        var visitor = new FindJoinsVisitor();
        visitor.Visit(expression);
        return visitor.JoinExpressions;
    }
    

    Next we'll use the following method, taken from Eric Lippert's blog post on computing a Cartesian product with LINQ, to get the Cartesian product of a sequence of sequences:

    static IEnumerable<IEnumerable<T>> CartesianProduct<T>(
        this IEnumerable<IEnumerable<T>> sequences) 
    { 
        IEnumerable<IEnumerable<T>> emptyProduct = new[] { Enumerable.Empty<T>() }; 
        return sequences.Aggregate( 
            emptyProduct, 
            (accumulator, sequence) => 
                from accseq in accumulator 
                from item in sequence 
                select accseq.Concat(new[] {item})); 
    }
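
To make the behaviour of CartesianProduct concrete, here is a small sketch that runs it over plain strings instead of expressions. The CartesianDemo class, the Describe helper and the "A:merge"-style labels are made up for illustration; only the CartesianProduct body comes from the answer.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

static class CartesianDemo
{
    // Same algorithm as above, repeated so this sketch compiles on its own.
    public static IEnumerable<IEnumerable<T>> CartesianProduct<T>(
        this IEnumerable<IEnumerable<T>> sequences)
    {
        IEnumerable<IEnumerable<T>> emptyProduct = new[] { Enumerable.Empty<T>() };
        return sequences.Aggregate(
            emptyProduct,
            (accumulator, sequence) =>
                from accseq in accumulator
                from item in sequence
                select accseq.Concat(new[] { item }));
    }

    // Hypothetical helper: renders each combination as one comma-separated line.
    public static List<string> Describe(IEnumerable<IEnumerable<string>> choicesPerJoin)
    {
        return choicesPerJoin.CartesianProduct()
            .Select(combination => string.Join(", ", combination))
            .ToList();
    }
}
```

Passing two inner sequences of two options each (one inner sequence per join) should give four lines, one per way of picking a single option from every sequence. That is the shape AllJoinCombinations relies on: each inner sequence holds the alternatives for one join.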
    

    Next we'll create a visitor that takes a sequence of pairs of expressions, and replaces all instances of the first expression in the pair with the second:

    internal class ReplaceVisitor : ExpressionVisitor
    {
        private readonly Dictionary<Expression, Expression> lookup;
        public ReplaceVisitor(Dictionary<Expression, Expression> pairsToReplace)
        {
            lookup = pairsToReplace;
        }
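        public override Expression Visit(Expression node)
        {
            // Visit can be called with null (e.g. the target of a static call),
            // and a Dictionary lookup throws on a null key.
            if (node == null)
                return null;
            Expression replacement;
            if (lookup.TryGetValue(node, out replacement))
                return base.Visit(replacement);
            return base.Visit(node);
        }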
    }
    
    public static Expression ReplaceAll(this Expression expression,
        Dictionary<Expression, Expression> pairsToReplace)
    {
        return new ReplaceVisitor(pairsToReplace).Visit(expression);
    }
    
    public static Expression ReplaceAll(this Expression expression,
        IEnumerable<Tuple<Expression, Expression>> pairsToReplace)
    {
        var lookup = pairsToReplace.ToDictionary(pair => pair.Item1, pair => pair.Item2);
        return new ReplaceVisitor(lookup).Visit(expression);
    }
    

    Finally, we put everything together. We find all of the join expressions in our expression, then project each one to a sequence of pairs in which the JoinExpression is the first item and one possible replacement is the second. Taking the Cartesian product of those sequences gives every combination of replacements, so with two options per join, n joins produce 2^n combinations. Each combination is then projected into the expression that results from applying all of its replacements to the original expression:

    public static IEnumerable<Expression> AllJoinCombinations(Expression expression)
    {
        var combinations = expression.FindJoins()
            .Select(join => new Tuple<Expression, Expression>[]
            {
                Tuple.Create<Expression, Expression>(join, new NestLoopJoinExpresion(join)), 
                Tuple.Create<Expression, Expression>(join, new MergeJoinExpresion(join)),
            })
            .CartesianProduct();
    
        return combinations.Select(combination => expression.ReplaceAll(combination));
    }
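
JoinExpression and VisitJoin belong to the asker's own visitor, so the code above cannot be run as posted. The following is a minimal self-contained sketch of the same idea against the stock System.Linq.Expressions visitor. Everything in it (JoinNode, FindJoins, Replace, Planner, the strategy names and the cost weights) is a hypothetical stand-in, not the asker's types, and the combinations are built with a simple loop rather than the CartesianProduct helper.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical stand-in for the question's JoinExpression (not a .NET type).
class JoinNode : Expression
{
    public JoinNode(string strategy, Expression left, Expression right)
    {
        Strategy = strategy; Left = left; Right = right;
    }
    public string Strategy { get; }
    public Expression Left { get; }
    public Expression Right { get; }
    public override ExpressionType NodeType => ExpressionType.Extension;
    public override Type Type => typeof(object);

    // Rebuild the node when a child changes so replacements propagate upward.
    protected override Expression VisitChildren(ExpressionVisitor visitor)
    {
        var left = visitor.Visit(Left);
        var right = visitor.Visit(Right);
        return left == Left && right == Right ? this : new JoinNode(Strategy, left, right);
    }
}

// Collects every JoinNode in the tree, outermost first.
class FindJoins : ExpressionVisitor
{
    public List<JoinNode> Found { get; } = new List<JoinNode>();
    protected override Expression VisitExtension(Expression node)
    {
        if (node is JoinNode join) Found.Add(join);
        return base.VisitExtension(node);
    }
}

// Swaps mapped nodes, then keeps visiting so nested joins are replaced too.
class Replace : ExpressionVisitor
{
    private readonly Dictionary<Expression, Expression> map;
    public Replace(Dictionary<Expression, Expression> map) { this.map = map; }
    public override Expression Visit(Expression node)
    {
        if (node == null) return null;
        return base.Visit(map.TryGetValue(node, out var replacement) ? replacement : node);
    }
}

static class Planner
{
    static readonly string[] Strategies = { "Merge", "NestedLoop" };

    public static List<Expression> AllPlans(Expression root)
    {
        var finder = new FindJoins();
        finder.Visit(root);

        // Start with one empty assignment; each join multiplies the set by its options.
        var assignments = new List<Dictionary<Expression, Expression>>
        {
            new Dictionary<Expression, Expression>()
        };
        foreach (var join in finder.Found)
        {
            assignments = (from assigned in assignments
                           from strategy in Strategies
                           select new Dictionary<Expression, Expression>(assigned)
                           {
                               [join] = new JoinNode(strategy, join.Left, join.Right)
                           }).ToList();
        }
        return assignments.Select(a => new Replace(a).Visit(root)).ToList();
    }

    // Leaves are assumed to be ParameterExpressions in this sketch.
    public static string Describe(Expression e) =>
        e is JoinNode j ? $"{j.Strategy}({Describe(j.Left)}, {Describe(j.Right)})"
                        : ((ParameterExpression)e).Name;

    // Made-up weights; the question's real cost model is not shown.
    public static int Cost(Expression e) =>
        e is JoinNode j ? (j.Strategy == "Merge" ? 1 : 3) + Cost(j.Left) + Cost(j.Right) : 0;
}
```

For a tree such as `new JoinNode("Join", new JoinNode("Join", customers, orders), items)`, where the three leaves are `Expression.Parameter(typeof(object), "customers")` and so on, AllPlans should return four trees, and `plans.OrderBy(p => Planner.Cost(p)).First()` picks the one with the lowest total weight. Replacing a node and then continuing to visit the replacement's children is what lets an inner join be swapped even after its parent has already been swapped.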