Tags: c#, linq, expression, expression-trees

Expression Tree: replace expressions like <variable> + 1 / <variable> - 1 with increment / decrement operations respectively


I would like to use an expression tree to convert the lambda () => a - 1 + b + 1 into something like () => a-- + b++.
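One caveat about the target form: in System.Linq.Expressions, Expression.Decrement and Expression.Increment only compute the value x - 1 / x + 1 and leave x unchanged. The C# postfix a-- corresponds to Expression.PostDecrementAssign, which decrements the variable but yields its old value, so a-- on its own evaluates to a, not a - 1 (the prefix --a, Expression.PreDecrementAssign, yields a - 1). A small sketch comparing the three on one int parameter; the class and member names are made up for illustration:

```csharp
using System;
using System.Linq.Expressions;

public static class DecrementSemantics
{
    // Compiles x => <body>, where makeBody builds the body from the parameter x.
    public static Func<int, int> Build(Func<ParameterExpression, Expression> makeBody)
    {
        var x = Expression.Parameter(typeof(int), "x");
        return Expression.Lambda<Func<int, int>>(makeBody(x), x).Compile();
    }

    // Expression.Decrement: the value x - 1; x itself is not modified.
    public static readonly Func<int, int> Pure = Build(x => Expression.Decrement(x));

    // Expression.PostDecrementAssign: like C# x--, decrements x but yields the old value.
    public static readonly Func<int, int> Postfix = Build(x => Expression.PostDecrementAssign(x));

    // Expression.PreDecrementAssign: like C# --x, decrements x and yields the new value.
    public static readonly Func<int, int> Prefix = Build(x => Expression.PreDecrementAssign(x));
}
```

For an input of 5, Pure and Prefix return 4 while Postfix returns 5.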

I implemented a class ExpressionTreeTransformer that inherits from ExpressionVisitor and overrides the VisitBinary method:

protected override Expression VisitBinary(BinaryExpression node)
{
    // The left operand must evaluate to a number (a variable or a constant).
    if (!TryGetNumberValueFromExpressionNode(node.Left, out var leftNodeValue))
    {
        return base.VisitBinary(node);
    }

    // The right operand must be the number 1.
    if (!TryGetNumberValueFromExpressionNode(node.Right, out var rightNodeValue) || rightNodeValue != 1)
    {
        return base.VisitBinary(node);
    }

    // Note: Expression.Constant(leftNodeValue) replaces the variable with its current value.
    var resultedExpression = node.NodeType switch
    {
        ExpressionType.Add => Expression.Increment(Expression.Constant(leftNodeValue)),
        ExpressionType.Subtract => Expression.Decrement(Expression.Constant(leftNodeValue)),
        _ => base.VisitBinary(node)
    };

    return resultedExpression;
}

It works fine if parentheses are applied, as in () => (a - 1) + (b + 1), but not for () => a - 1 + b + 1.

After some investigation I found out that the reason is the way the expression tree builds its nodes. Without parentheses, the steps look like this:

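  1. left(a - 1) + right(b)
  2. left(result of step 1) + right(1)

In step 2 the left operand is the BinaryExpression produced in step 1, which is neither a variable nor a constant, so it is not recognized as a number and the trailing + 1 is never rewritten.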
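The left-associative grouping can be confirmed by printing the body of an equivalent lambda. This sketch uses lambda parameters instead of captured locals so that ToString() stays readable; the class name TreeShapeDemo is hypothetical:

```csharp
using System;
using System.Linq.Expressions;

public static class TreeShapeDemo
{
    // C# + and - are left-associative, so a - 1 + b + 1 is built as ((a - 1) + b) + 1.
    public static string WithoutParentheses()
    {
        Expression<Func<int, int, int>> e = (a, b) => a - 1 + b + 1;
        return e.Body.ToString(); // "(((a - 1) + b) + 1)"
    }

    // Explicit parentheses give two independent "x op 1" subtrees.
    public static string WithParentheses()
    {
        Expression<Func<int, int, int>> e = (a, b) => (a - 1) + (b + 1);
        return e.Body.ToString(); // "((a - 1) + (b + 1))"
    }
}
```

In the first tree the outer + 1 has ((a - 1) + b) as its left operand, which is exactly the case the VisitBinary override above does not handle.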

Expression nodes are processed by a chain of handlers:

_expressionHandlers = new MemberExpressionHandler();
_expressionHandlers.SetSuccessor(new ConstantExpressionHandler());

Variables handler:

using System.Linq.Expressions;
using System.Reflection;

public class MemberExpressionHandler : AbstractTreeExpressionHandler
{
    public override bool Handle(Expression expressionNode, out int nodeValue)
    {
        if (expressionNode is MemberExpression memberExpression)
        {
            // A captured local is a field on the compiler-generated closure object,
            // which appears in the tree as a ConstantExpression. Properties are skipped.
            if (memberExpression.Expression is ConstantExpression constantExpression
                && memberExpression.Member is FieldInfo field)
            {
                var value = field.GetValue(constantExpression.Value);

                if (value != null && int.TryParse(value.ToString(), out nodeValue))
                {
                    return true;
                }
            }
        }
        else if (_successor != null)
        {
            // Not a member access: let the next handler in the chain try.
            return _successor.Handle(expressionNode, out nodeValue);
        }

        nodeValue = 0;

        return false;
    }
}

Constants handler:

using System.Linq.Expressions;

public class ConstantExpressionHandler : AbstractTreeExpressionHandler
{
    public override bool Handle(Expression expressionNode, out int nodeValue)
    {
        // Check the node type before reading Value: casting a non-constant node
        // (e.g. the BinaryExpression (a - 1) + b) would throw InvalidCastException.
        if (expressionNode is ConstantExpression constantExpression
            && constantExpression.Value != null
            && int.TryParse(constantExpression.Value.ToString(), out nodeValue))
        {
            return true;
        }

        nodeValue = 0;

        return false;
    }
}
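As an aside, the two handlers can also be collapsed into a single method using C# type and property patterns. A minimal sketch, assuming only values that are already int should count as numbers (the handlers above instead accept anything whose ToString() parses as an int); NumberValueReader and TryGetIntValue are hypothetical names:

```csharp
using System.Linq.Expressions;
using System.Reflection;

public static class NumberValueReader
{
    // Reads an int from a constant node or from a captured local (closure field).
    public static bool TryGetIntValue(Expression node, out int value)
    {
        object raw = node switch
        {
            ConstantExpression c => c.Value,
            MemberExpression { Expression: ConstantExpression closure, Member: FieldInfo field }
                => field.GetValue(closure.Value),
            _ => null // any other node kind, e.g. a nested BinaryExpression
        };

        if (raw is int i)
        {
            value = i;
            return true;
        }

        value = 0;
        return false;
    }
}
```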

Q: I'm stuck. How can this task be solved the right way?

P.S. Results:

  • With parentheses: () => (Decrement(0) + Increment(1))
  • Without: () => ((Decrement(0) + value(ExpressionTreeModule.Program+<>c__DisplayClass0_0).b) + 1)

Solution

  • I think that when you find an addition or subtraction of the constant 1, you need to examine the left-hand side and determine whether you can hoist the operation into an increment/decrement. Here is my sample code:

    using System.Linq.Expressions;

    public class UnaryConstant1 : ExpressionVisitor {
        protected override Expression VisitBinary(BinaryExpression node) {
            // Match only an int constant 1 (c.Value is boxed; unboxing a non-int would throw).
            if (node.Right is ConstantExpression c && c.Value is int one && one == 1) {
                if (node.NodeType == ExpressionType.Add || node.NodeType == ExpressionType.Subtract) {
                    // A captured variable (closure field) or a lambda parameter.
                    if (node.Left is MemberExpression || node.Left is ParameterExpression) {
                        if (node.NodeType == ExpressionType.Add)
                            return Expression.Increment(node.Left);
                        else
                            return Expression.Decrement(node.Left);
                    }
                    else if (node.Left is BinaryExpression left && (left.NodeType == ExpressionType.Add || left.NodeType == ExpressionType.Subtract)) {
                        // Hoist the +/- 1 onto left.Right: (x + y) + 1 == x + (y + 1),
                        // but (x - y) + 1 == x - (y - 1), so the direction flips under Subtract.
                        bool increment = (node.NodeType == ExpressionType.Add) == (left.NodeType == ExpressionType.Add);
                        Expression right = increment
                            ? Expression.Increment(left.Right)
                            : Expression.Decrement(left.Right);

                        if (left.NodeType == ExpressionType.Add)
                            return Expression.Add(Visit(left.Left), right);
                        else
                            return Expression.Subtract(Visit(left.Left), right);
                    }
                }
            }
            return base.VisitBinary(node);
        }
    }
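The hoisting above looks one level into the left operand. A different technique, sketched here as an alternative rather than as the answer's method, is to flatten the whole +/- chain into a list of signed terms, attach each constant 1 to the closest preceding non-constant term, and rebuild the chain. It is restricted to unchecked int Add/Subtract, where regrouping is exact because int arithmetic wraps; the class name UnitConstantFolder is hypothetical:

```csharp
using System.Collections.Generic;
using System.Linq.Expressions;

public class UnitConstantFolder : ExpressionVisitor
{
    private static bool IsIntAddOrSubtract(Expression e) =>
        e.Type == typeof(int) &&
        (e.NodeType == ExpressionType.Add || e.NodeType == ExpressionType.Subtract);

    protected override Expression VisitBinary(BinaryExpression node)
    {
        if (!IsIntAddOrSubtract(node))
            return base.VisitBinary(node);

        var terms = new List<(bool Negative, Expression Term)>();
        Flatten(node, false, terms);

        var folded = new List<(bool Negative, Expression Term)>();
        foreach (var (negative, term) in terms)
        {
            bool isOne = term is ConstantExpression c && c.Value is int v && v == 1;
            int last = folded.Count - 1;
            if (isOne && last >= 0 && !(folded[last].Term is ConstantExpression))
            {
                var (prevNegative, prev) = folded[last];
                // s*prev + t*1 == s*(prev + s*t): same signs => +1, different signs => -1.
                Expression adjusted = negative == prevNegative
                    ? Expression.Increment(prev)
                    : Expression.Decrement(prev);
                folded[last] = (prevNegative, adjusted);
            }
            else
            {
                folded.Add((negative, term));
            }
        }

        // The leftmost term of a +/- chain is never negated, so it starts the rebuild.
        Expression result = folded[0].Term;
        for (int i = 1; i < folded.Count; i++)
            result = folded[i].Negative
                ? Expression.Subtract(result, folded[i].Term)
                : Expression.Add(result, folded[i].Term);
        return result;
    }

    // Collects the leaves of nested int Add/Subtract nodes with their effective sign.
    private void Flatten(Expression e, bool negative, List<(bool Negative, Expression Term)> terms)
    {
        if (IsIntAddOrSubtract(e))
        {
            var b = (BinaryExpression)e;
            Flatten(b.Left, negative, terms);
            Flatten(b.Right, b.NodeType == ExpressionType.Subtract ? !negative : negative, terms);
        }
        else
        {
            terms.Add((negative, Visit(e))); // still rewrite anything nested inside the leaf
        }
    }
}
```

With parameters a and b, both (a, b) => a - 1 + b + 1 and (a, b) => (a - 1) + (b + 1) come out as (Decrement(a) + Increment(b)), and (a, b) => a - b + 1 becomes (a - Decrement(b)). Like the answer's code, it produces Increment/Decrement values, not the assigning a--/b++ forms.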