Tags: c#, linq, lambda, f#, expression-trees

How to 'unquote' when creating Expression tree from lambda?


Let's suppose I have some function c that returns an Expression:

Func<int, Expression<Func<int>>> c = (int a) => () => a + 3;

Now I want to create another Expression, but while creating it I'd like to call the function c and embed its result as part of the new expression:

Expression<Func<int>> d = () => 2 + c(3);

I can't do it this way: the compiler treats c(3) as an ordinary function call to be converted into the expression tree, so I get an error that I can't add int and Expression<Func<int>>.

I'd like d to have a value of:

(Expression<Func<int>>)( () => 2 + 3 + 3 )
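For reference, this target value can be assembled by hand with the System.Linq.Expressions factory methods: call c(3) while building the tree and splice its Body into an Add node, which is exactly the step a lambda literal won't let you write. A minimal sketch; ManualSplice, C and BuildD are hypothetical names:

```csharp
using System;
using System.Linq.Expressions;

static class ManualSplice
{
    // The function c from the question: it returns an expression tree, not an int.
    public static readonly Func<int, Expression<Func<int>>> C = a => () => a + 3;

    // Builds the equivalent of () => 2 + c(3) by calling c(3) up front
    // and splicing its body into a new Add node.
    public static Expression<Func<int>> BuildD()
    {
        Expression inner = C(3).Body; // a + 3, where a is a captured variable holding 3
        Expression sum = Expression.Add(Expression.Constant(2), inner);
        return Expression.Lambda<Func<int>>(sum);
    }
}
```

`ManualSplice.BuildD().Compile()()` returns 8. The printed tree contains a reference to the captured `a` rather than a literal 3, because c(3) closes over its argument.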

I'm also interested in getting this to work on more complex expressions, not just this toy example.

How would you do it in C#?

Alternatively, how would you do it in any other CLR language that I could use in my C# project with as little hassle as possible?


A more complex example:

Func<int, Expression<Func<int>>> c = (int a) => () => a*(a + 3);
Expression<Func<int, int>> d = (x) => 2 + c(3 + x);

3 + x should be evaluated just once in the resulting expression, even though it occurs twice in the body of c.


I have a strong feeling that this cannot be achieved in C#, because converting a lambda to an Expression is done by the compiler and is a sort of compile-time constant expression literal. It would be akin to making a compiler that understands the plain string literal "test" also understand the template string literal "test ${a+b} other", and the C# compiler is not at that stage of development yet.

So my main question actually is:

What CLR language supports syntax that would allow me to conveniently build Expression trees embedding parts that are constructed by other functions?

Another possibility is a library that helps build expression trees this way using some sort of run-time compiled templates, but I'm guessing that way I'd lose code completion for my expression code.


It seems that F# has the ability to 'quote' and 'unquote' (splice) code:

https://learn.microsoft.com/en-us/dotnet/articles/fsharp/language-reference/code-quotations


Solution

  • For both of your examples this can actually be done with two expression visitors (code is commented):

    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Linq.Expressions;

    static class Extensions {
        public static TResult FakeInvoke<TResult>(this Delegate instance, params object[] parameters)
        {
            // this is not intended to be called directly
            throw new NotImplementedException();
        }
    
        public static TExpression Unwrap<TExpression>(this TExpression exp) where TExpression : Expression {
            return (TExpression) new FakeInvokeVisitor().Visit(exp);
        }
    
        class FakeInvokeVisitor : ExpressionVisitor {
            protected override Expression VisitMethodCall(MethodCallExpression node) {
                // replace FakeInvoke call
                if (node.Method.DeclaringType == typeof(Extensions) && node.Method.Name == nameof(FakeInvoke)) {
                    // first obtain reference to method being called (so, for c.FakeInvoke(...) that will be "c")
                    var func = (Delegate)Expression.Lambda(node.Arguments[0]).Compile().DynamicInvoke();
                    // explore method argument names and types
                    var argumentNames = new List<string>();
                    var dummyArguments = new List<object>();
                    foreach (var arg in func.Method.GetParameters()) {
                        argumentNames.Add(arg.Name);
                        // create default value for each argument
                        dummyArguments.Add(arg.ParameterType.IsValueType ? Activator.CreateInstance(arg.ParameterType) : null);
                    }
                    // now, invoke function with default arguments to obtain expression (for example, this one () => a*(a + 3)).
                    // all arguments will have default value (0 in this case), but they are not literal "0" but a reference to "a" member with value 0
                    var exp = (Expression) func.DynamicInvoke(dummyArguments.ToArray());
                    // this is expressions representing what we passed to FakeInvoke (for example expression (x + 3))
                    var argumentExpressions = (NewArrayExpression)node.Arguments[1];
                    // now invoke second visitor
                    exp = new InnerFakeInvokeVisitor(argumentExpressions, argumentNames.ToArray()).Visit(exp);
                    return ((LambdaExpression)exp).Body;
                }
                return base.VisitMethodCall(node);
            }
        }
    
        class InnerFakeInvokeVisitor : ExpressionVisitor {
            private readonly NewArrayExpression _args;
            private readonly string[] _argumentNames;
            public InnerFakeInvokeVisitor(NewArrayExpression args, string[] argumentNames) {
                _args =  args;
                _argumentNames = argumentNames;
            }
            protected override Expression VisitMember(MemberExpression node) {
                // if that is a reference to one of our arguments (for example, reference to "a")
                if (node.Expression is ConstantExpression && _argumentNames.Contains(node.Member.Name)) {
                    // find related expression
                    var idx = Array.IndexOf(_argumentNames, node.Member.Name);
                    var argument = _args.Expressions[idx];
                    var unary = argument as UnaryExpression;
                    // and replace it. So "a" is replaced with expression "x + 3"
                    return unary?.Operand ?? argument;
                }
                return base.VisitMember(node);
            }
        }
    }
    

    Can be used like this:

    Func<int, Expression<Func<int>>> c = (int a) => () => a * (a + 3);
    Expression<Func<int, int>> d = (x) => 2 + c.FakeInvoke<int>(3 + x);
    d = d.Unwrap(); // this is now "x => (2 + ((3 + x) * ((3 + x) + 3)))"
    

    Simple case:

    Func<int, Expression<Func<int>>> c = (int a) => () => a + 3;
    Expression<Func<int>> d = () => 2 + c.FakeInvoke<int>(3);
    d = d.Unwrap(); // this is now "() => (2 + (3 + 3))"
    

    With multiple arguments:

    Func<int, int, Expression<Func<int>>> c = (int a, int b) => () => a * (a + 3) + b;
    Expression<Func<int, int>> d = (x) => 2 + c.FakeInvoke<int>(3 + x, x + 5);
    d = d.Unwrap(); // "x => (2 + (((3 + x) * ((3 + x) + 3)) + (x + 5)))"
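In the trees above, 3 + x is copied into every place a occurs, so it is evaluated twice, which the question wanted to avoid. One alternative, not part of this answer, is to bind the argument to a local variable with Expression.Block and substitute that variable for the captured field. The sketch below shows only that substitution and builds the outer tree by hand; SpliceOnce, Splice and CapturedToVariable are hypothetical names. Like the answer, it assumes the C# compiler stores the captured parameter in a closure field with the same name, which is an implementation detail. Block expressions compile fine with Compile(), but query providers that translate trees (to SQL, for instance) may not accept them.

```csharp
using System;
using System.Linq.Expressions;

static class SpliceOnce
{
    // Hypothetical helper: expands template(argument) so that `argument` is evaluated once.
    // Assumes the captured parameter appears as a closure field named after the parameter.
    public static Expression Splice<TArg, TResult>(
        Func<TArg, Expression<Func<TResult>>> template, Expression argument)
    {
        string name = template.Method.GetParameters()[0].Name;      // e.g. "a"
        Expression<Func<TResult>> lambda = template(default(TArg));  // captured "a" holds default(TArg)
        ParameterExpression local = Expression.Variable(typeof(TArg), name);
        Expression body = new CapturedToVariable(name, local).Visit(lambda.Body);

        // Equivalent to: { TArg a = argument; body }
        return Expression.Block(new[] { local }, Expression.Assign(local, argument), body);
    }

    sealed class CapturedToVariable : ExpressionVisitor
    {
        private readonly string _name;
        private readonly ParameterExpression _local;

        public CapturedToVariable(string name, ParameterExpression local)
        {
            _name = name;
            _local = local;
        }

        // Replace closure-field reads such as value(<>c__DisplayClass).a with the block variable.
        protected override Expression VisitMember(MemberExpression node) =>
            node.Member.Name == _name && node.Expression is ConstantExpression
                ? _local
                : base.VisitMember(node);
    }

    // Builds x => 2 + c(3 + x) by hand, with 3 + x evaluated once.
    public static Expression<Func<int, int>> Example()
    {
        Func<int, Expression<Func<int>>> c = a => () => a * (a + 3);
        ParameterExpression x = Expression.Parameter(typeof(int), "x");
        Expression argument = Expression.Add(Expression.Constant(3), x);
        Expression spliced = Splice(c, argument);
        return Expression.Lambda<Func<int, int>>(Expression.Add(Expression.Constant(2), spliced), x);
    }
}
```

`SpliceOnce.Example().Compile()(1)` computes 3 + 1 once, binds it to `a`, and returns 2 + 4 * 7 = 30.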
    

    Note that FakeInvoke is not type-safe: you have to specify the return type explicitly, and the arguments are not checked. That's acceptable for an example; in real use you can create typed overloads of FakeInvoke, like this:

    public static TResult FakeInvoke<TArg, TResult>(this Func<TArg, Expression<Func<TResult>>> instance, TArg argument) {
        // this is not intended to be called directly
        throw new NotImplementedException();
    }
    

    The code above has to be modified slightly to handle such calls (the arguments are no longer wrapped in a single NewArrayExpression), but that's easy to do. With such overloads you can just write:

    Expression<Func<int, int>> d = (x) => 2 + c.FakeInvoke(3 + x); // this is type-safe now: you cannot pass a non-integer as "3 + x", nor can you pass more or fewer arguments than required.
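For the single-argument overload, the modification could look like the sketch below; TypedSplice and ReplaceCaptured are hypothetical names, and only one argument is handled. Because the argument is no longer passed through params object[], node.Arguments[1] is the argument expression itself, with no NewArrayExpression and no boxing Convert to strip. The sketch also matches the marker by declaring type rather than by name alone, and it keeps the original's assumption that the captured parameter appears as a closure field with the same name (a compiler implementation detail).

```csharp
using System;
using System.Linq.Expressions;

static class TypedSplice
{
    // Typed marker: only meaningful inside an expression tree, never called directly.
    public static TResult FakeInvoke<TArg, TResult>(
        this Func<TArg, Expression<Func<TResult>>> instance, TArg argument) =>
        throw new NotImplementedException();

    public static TExpression Unwrap<TExpression>(this TExpression exp) where TExpression : Expression =>
        (TExpression)new FakeInvokeVisitor().Visit(exp);

    sealed class FakeInvokeVisitor : ExpressionVisitor
    {
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            // Match only the marker defined above, not any method that happens to be called FakeInvoke.
            if (node.Method.DeclaringType != typeof(TypedSplice) || node.Method.Name != nameof(FakeInvoke))
                return base.VisitMethodCall(node);

            // Arguments[0] evaluates to the Func (e.g. the captured "c").
            var func = (Delegate)Expression.Lambda(node.Arguments[0]).Compile().DynamicInvoke();
            var parameter = func.Method.GetParameters()[0];
            object dummy = parameter.ParameterType.IsValueType
                ? Activator.CreateInstance(parameter.ParameterType)
                : null;
            var template = (LambdaExpression)func.DynamicInvoke(dummy);

            // Arguments[1] is the argument expression itself. Visiting it first
            // also unwraps FakeInvoke calls nested inside the argument.
            Expression argument = Visit(node.Arguments[1]);
            return new ReplaceCaptured(parameter.Name, argument).Visit(template.Body);
        }
    }

    sealed class ReplaceCaptured : ExpressionVisitor
    {
        private readonly string _name;
        private readonly Expression _replacement;

        public ReplaceCaptured(string name, Expression replacement)
        {
            _name = name;
            _replacement = replacement;
        }

        // Replace closure-field reads of the lambda parameter with the caller's argument expression.
        protected override Expression VisitMember(MemberExpression node) =>
            node.Member.Name == _name && node.Expression is ConstantExpression
                ? _replacement
                : base.VisitMember(node);
    }
}
```

With this in place, `x => 2 + c.FakeInvoke(3 + x)` followed by `Unwrap()` should give the same tree as the first usage example, and the compiler infers TArg and TResult from c, so a non-int argument is rejected at compile time.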