Tags: c#, linq, reflection, lambda, traversal

Traverse an expression tree and extract parameters


I'm writing a mapping tool of sorts. I have a method that looks like this (simplified):

   public void RegisterMapping<TTarget, TSource>(string propertyName, 
                                                 Expression<Func<TSource, object>> memberMap)

The memberMap is an expression defining how to transform a property from TSource to TTarget. For the business logic, I need to extract all references to properties of TSource from it. For example, from

   x => x.Customers.Where(c => c.Orders.Any())

I would like to get Customers, and from

   x => x.FirstName + " " + x.LastName

FirstName and LastName (the result can be a string[], since converting the names to PropertyInfo is trivial).

How would I do this? My first approach was to traverse the tree manually: check the node type and inspect different properties depending on it (e.g. Operand for unary expressions, Arguments for a method call) to determine whether any of them is a property of TSource. Then I discovered the full list of expression kinds and gave up: even if I support only the most common ones, it's still a lot of work. Then I found ExpressionVisitor. It looks better, but overriding the visitor methods is still a lot of work, and before I devote my time to this I'd like to know whether there's another option, perhaps a more specialised framework.
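For comparison, the manual approach described above can be sketched as a recursive switch over node types. This is a minimal sketch, not a complete walker: `ManualWalker`, `FindParameterMembers` and the model classes (`Person`, `Company`, `Customer`, `Order`) are hypothetical names, and only a handful of node kinds are handled. It records a member only when it is read directly from the lambda's own parameter, so `c.Orders` inside the nested lambda is skipped.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical hand-written walker: one case per node kind it understands.
public static class ManualWalker
{
    // Names of members accessed directly on the lambda's first parameter, in visit order.
    public static List<string> FindParameterMembers(LambdaExpression lambda)
    {
        var names = new List<string>();
        Walk(lambda.Body, lambda.Parameters[0], names);
        return names;
    }

    private static void Walk(Expression node, ParameterExpression root, List<string> names)
    {
        switch (node)
        {
            case null:                   // e.g. Object of a static method call
            case ConstantExpression _:
            case ParameterExpression _:
                return;
            case MemberExpression member:
                if (member.Expression == root) names.Add(member.Member.Name);
                Walk(member.Expression, root, names);
                return;
            case UnaryExpression unary:  // e.g. Convert(x.Age, Object) for value types
                Walk(unary.Operand, root, names);
                return;
            case BinaryExpression binary:
                Walk(binary.Left, root, names);
                Walk(binary.Right, root, names);
                return;
            case MethodCallExpression call:
                Walk(call.Object, root, names);
                foreach (var argument in call.Arguments) Walk(argument, root, names);
                return;
            case LambdaExpression inner: // nested lambda such as c => c.Orders.Any()
                Walk(inner.Body, root, names);
                return;
            default:
                // Conditional, New, MemberInit, Invoke, ... are all still unhandled.
                throw new NotSupportedException("Unhandled node: " + node.NodeType);
        }
    }
}

// Hypothetical model types, only here so the example compiles.
public class Order { }
public class Customer { public List<Order> Orders { get; set; } = new List<Order>(); }
public class Company { public List<Customer> Customers { get; set; } = new List<Customer>(); }
public class Person
{
    public string FirstName { get; set; } = "";
    public string LastName { get; set; } = "";
    public int Age { get; set; }
}
```

Every node kind missing from the switch would need its own case, which is the work that ExpressionVisitor's default Visit... implementations already do.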


Solution

  • As you said, ExpressionVisitor is a good approach. You don't need to implement all the Visit... methods, because they already have default implementations that visit every child node. From what I understand, you want to find all accesses to members of a certain type inside a lambda expression:

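    using System;
    using System.Collections.Generic;
    using System.Linq.Expressions;
    
    // Collects the names of members of declaringType (including inherited ones)
    // that are accessed anywhere inside an expression tree.
    public class MemberAccessVisitor : ExpressionVisitor
    {
        private readonly Type declaringType;
        private readonly List<string> propertyNames = new List<string>();
    
        public MemberAccessVisitor(Type declaringType)
        {
            this.declaringType = declaringType;
        }
    
        public IEnumerable<string> PropertyNames { get { return propertyNames; } }
    
        // Called for every MemberExpression (field or property access) in the tree.
        protected override Expression VisitMember(MemberExpression memberExpr)
        {
            // IsAssignableFrom also matches members inherited from a base class,
            // whose DeclaringType is the base class rather than declaringType itself.
            Type memberOwner = memberExpr.Member.DeclaringType;
            if (memberOwner != null && memberOwner.IsAssignableFrom(declaringType))
            {
                propertyNames.Add(memberExpr.Member.Name);
            }
    
            // Keep walking into the instance expression, e.g. x.Customers in x.Customers.Count.
            return base.VisitMember(memberExpr);
        }
    }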
    

    This could be improved further to match what you need, by checking that the member is a property and by collecting PropertyInfo objects instead of strings.

    It could be used as follows:

    var visitor = new MemberAccessVisitor(typeof(TSource));
    
    visitor.Visit(memberMap);
    
    var propertyNames = visitor.PropertyNames;
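Following the suggestion above, one possible refinement records PropertyInfo objects instead of strings, skips fields, and only counts members read directly from the lambda's parameter, so a nested `c => c.Orders.Any()` is ignored even if `c` happens to have the same type as TSource. This is a sketch under those assumptions: `PropertyAccessVisitor`, `GetSourceProperties` and the model classes are hypothetical names, not part of any library.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical refinement of MemberAccessVisitor: records PropertyInfo objects for
// properties read directly from the lambda's own parameter (x.Foo), ignoring fields
// and members of other objects such as c.Orders inside a nested lambda.
public class PropertyAccessVisitor : ExpressionVisitor
{
    private readonly ParameterExpression source;
    private readonly List<PropertyInfo> properties = new List<PropertyInfo>();

    public PropertyAccessVisitor(ParameterExpression source)
    {
        this.source = source;
    }

    public IEnumerable<PropertyInfo> Properties => properties;

    protected override Expression VisitMember(MemberExpression node)
    {
        // Reference comparison: the compiler reuses one ParameterExpression for x.
        if (node.Expression == source && node.Member is PropertyInfo property
            && !properties.Contains(property))
        {
            properties.Add(property);
        }

        // Keep walking so that chains like x.Parent.Name still reach x.Parent.
        return base.VisitMember(node);
    }

    // Hypothetical helper with the same parameter type as RegisterMapping's memberMap.
    public static PropertyInfo[] GetSourceProperties<TSource>(
        Expression<Func<TSource, object>> memberMap)
    {
        var visitor = new PropertyAccessVisitor(memberMap.Parameters[0]);
        visitor.Visit(memberMap.Body);
        return visitor.Properties.ToArray();
    }
}

// Hypothetical model types, only here so the example compiles.
public class Order { }
public class Customer { public List<Order> Orders { get; set; } = new List<Order>(); }
public class Company { public List<Customer> Customers { get; set; } = new List<Customer>(); }
public class Person
{
    public string FirstName { get; set; } = "";
    public string LastName { get; set; } = "";
    public int Age { get; set; }
}
```

Inside RegisterMapping this could be called as `PropertyAccessVisitor.GetSourceProperties(memberMap)`, with TSource inferred from the argument. Comparing against the parameter instance rather than the declaring type is a design choice: for `x => x.Parent.Name` it reports only `Parent`, the property actually read from the source object.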