I want to remove any cast expression in an expression tree. We can assume that the cast is redundant.
For example, both these expressions:

```csharp
IFoo t => ((t as Foo).Bar as IExtraBar).Baz;
IFoo t => ((IExtraBar)(t as Foo).Bar).Baz;
```

become this:

```csharp
IFoo t => t.Bar.Baz
```
How can you accomplish this?
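The two inputs contain equivalent casts, but they are not represented the same way in the tree. An `as` becomes a `UnaryExpression` whose `NodeType` is `TypeAs`, while an explicit cast becomes one whose `NodeType` is `Convert`, so a cast remover has to check for both. A minimal sketch to confirm this, using stand-ins for the question's nested types (the helper name `OuterCastKind` is hypothetical):

```csharp
using System;
using System.Linq.Expressions;

// Stand-ins mirroring the question's nested types.
public interface IBar { string Baz { get; } }
public interface IExtraBar : IBar { }
public interface IFoo { IBar Bar { get; } }
public class Foo : IFoo { public IBar Bar { get; } }

public static class CastNodeTypes
{
    // Hypothetical helper: returns the NodeType of the node that the outer ".Baz" access is applied to.
    public static ExpressionType OuterCastKind(Expression<Func<IFoo, string>> expr)
    {
        var bazAccess = (MemberExpression)expr.Body;
        return bazAccess.Expression.NodeType;
    }

    public static void Main()
    {
        // "x as T" is recorded as ExpressionType.TypeAs ...
        Console.WriteLine(OuterCastKind(t => ((t as Foo).Bar as IExtraBar).Baz)); // TypeAs
        // ... while "(T)x" is recorded as ExpressionType.Convert.
        Console.WriteLine(OuterCastKind(t => ((IExtraBar)(t as Foo).Bar).Baz)); // Convert
    }
}
```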
The sample below illustrates a pretty simple scenario. However, my code fails with an exception:
Unhandled Exception: System.ArgumentException: Property 'IBar Bar' is not defined for type 'ExpressionTest.Program+IFoo'
```csharp
using System;
using System.Linq.Expressions;

namespace ExpressionTest
{
    class Program
    {
        public interface IBar
        {
            string Baz { get; }
        }

        public interface IExtraBar : IBar
        {
        }

        public interface IFoo
        {
            IBar Bar { get; }
        }

        public class Foo : IFoo
        {
            public IBar Bar { get; }
        }

        static void Main(string[] args)
        {
            Expression<Func<IFoo, string>> expr = t => ((t as Foo).Bar as IExtraBar).Baz;
            Expression<Func<IFoo, string>> expr2 = t => ((IExtraBar)(t as Foo).Bar).Baz;

            // Wanted: IFoo t => t.Bar.Baz

            var visitor = new CastRemoverVisitor();
            visitor.Visit(expr);

            Console.WriteLine(visitor.Expression.ToString());
        }

        public class CastRemoverVisitor : ExpressionVisitor
        {
            public Expression Expression { get; private set; }

            public override Expression Visit(Expression node)
            {
                Expression ??= node;
                return base.Visit(node);
            }

            protected override Expression VisitUnary(UnaryExpression node)
            {
                Expression = node.Operand;
                return Visit(node.Operand);
            }
        }
    }
}
```
The accepted answer pinpoints the use of `Expression.MakeMemberAccess` and some interface tricks. We can "improve" the code a bit by using the interface's `PropertyInfo` instead of going through the interface's getter. I ended up with the following:
```csharp
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

public class CastRemoverVisitor : ExpressionVisitor
{
    // Replace any TypeAs/Convert node with its (visited) operand.
    protected override Expression VisitUnary(UnaryExpression node)
    {
        return node.IsCastExpression() ? Visit(node.Operand) : base.VisitUnary(node);
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression is UnaryExpression unaryExpression &&
            unaryExpression.IsCastExpression())
        {
            // Find the same property on an interface, so it can be accessed on the uncast operand.
            var propertyInfo = node.Member.ToInterfacedProperty();
            if (propertyInfo != null)
            {
                return base.Visit(
                    Expression.MakeMemberAccess(
                        unaryExpression.Operand,
                        propertyInfo
                    ));
            }
        }

        return base.VisitMember(node);
    }
}

// And some useful extension methods...
public static class MemberInfoExtensions
{
    // Returns the first property with the same name on any interface implemented by the
    // member's declaring type, or null if there is none (the match is by name only).
    public static MemberInfo ToInterfacedProperty(this MemberInfo member)
    {
        var interfaces = member.DeclaringType!.GetInterfaces();
        var property = interfaces.Select(i => i.GetProperty(member.Name))
            .FirstOrDefault(p => p != null);
        return property;
    }
}

public static class ExpressionExtensions
{
    // Both "x as T" (TypeAs) and "(T)x" (Convert) count as casts.
    public static bool IsCastExpression(this Expression expression) =>
        expression.NodeType == ExpressionType.TypeAs ||
        expression.NodeType == ExpressionType.Convert;
}
```
And then we use it like this:

```csharp
var visitor = new CastRemoverVisitor();
var cleanExpr = visitor.Visit(expr);
Console.WriteLine(cleanExpr.ToString());
```
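The difference between using the interface's `PropertyInfo` and going through its getter shows up directly in how the rewritten tree prints. A small sketch with stand-in interfaces (the class and method names are hypothetical) that builds the same access three ways: `Expression.Call` on the getter, `Expression.MakeMemberAccess` on the `PropertyInfo`, and the `Expression.Property(Expression, MethodInfo)` overload, which takes a getter but still produces a member access:

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

// Stand-ins mirroring the question's interfaces.
public interface IBar { string Baz { get; } }
public interface IFoo { IBar Bar { get; } }

public static class GetterVersusProperty
{
    static readonly PropertyInfo BarProperty = typeof(IFoo).GetProperty(nameof(IFoo.Bar))!;
    static readonly MethodInfo BarGetter = BarProperty.GetMethod!;

    // A MethodCallExpression: prints as a call to the compiler-generated accessor.
    public static string ViaGetterCall() =>
        Expression.Call(Expression.Parameter(typeof(IFoo), "t"), BarGetter).ToString();

    // A MemberExpression built from the PropertyInfo: prints as an ordinary property access.
    public static string ViaProperty() =>
        Expression.MakeMemberAccess(Expression.Parameter(typeof(IFoo), "t"), BarProperty).ToString();

    // A MemberExpression built from the getter MethodInfo via the accessor overload.
    public static string ViaAccessorOverload() =>
        Expression.Property(Expression.Parameter(typeof(IFoo), "t"), BarGetter).ToString();

    public static void Main()
    {
        Console.WriteLine(ViaGetterCall());       // t.get_Bar()
        Console.WriteLine(ViaProperty());         // t.Bar
        Console.WriteLine(ViaAccessorOverload()); // t.Bar
    }
}
```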
First off, nice repro. Thank you.
The problem is that the property access `(t as Foo).Bar` is calling the getter for `Foo.Bar`, and not the getter for `IFoo.Bar` (yes, those are different things with different `MethodInfo`s). You can see this by overriding `VisitMember` and inspecting the `MethodInfo` being passed.
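The distinction can also be checked without a visitor. This sketch (stand-in types, hypothetical helper names) compares the two getters, uses `GetInterfaceMap` to recover which `Foo` method implements `IFoo.get_Bar`, and reproduces the question's `ArgumentException` by asking for `Foo.Bar` on an expression typed as `IFoo`:

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

// Stand-ins mirroring the question's types.
public interface IBar { string Baz { get; } }
public interface IFoo { IBar Bar { get; } }
public class Foo : IFoo { public IBar Bar { get; } }

public static class GetterIdentity
{
    public static MethodInfo ClassGetter => typeof(Foo).GetProperty(nameof(Foo.Bar))!.GetMethod!;
    public static MethodInfo InterfaceGetter => typeof(IFoo).GetProperty(nameof(IFoo.Bar))!.GetMethod!;

    // Uses the interface map to find which Foo method implements IFoo.get_Bar.
    public static MethodInfo ImplementationOfInterfaceGetter()
    {
        InterfaceMapping map = typeof(Foo).GetInterfaceMap(typeof(IFoo));
        int index = Array.IndexOf(map.InterfaceMethods, InterfaceGetter);
        return map.TargetMethods[index];
    }

    // Re-creates the failing step: accessing Foo.Bar on an expression whose static type is IFoo.
    public static bool AccessingClassPropertyOnInterfaceThrows()
    {
        var t = Expression.Parameter(typeof(IFoo), "t");
        try
        {
            Expression.MakeMemberAccess(t, typeof(Foo).GetProperty(nameof(Foo.Bar))!);
            return false;
        }
        catch (ArgumentException)
        {
            // "Property 'IBar Bar' is not defined for type 'IFoo'"
            return true;
        }
    }

    public static void Main()
    {
        Console.WriteLine(ClassGetter == InterfaceGetter);                   // False
        Console.WriteLine(ImplementationOfInterfaceGetter() == ClassGetter); // True
        Console.WriteLine(AccessingClassPropertyOnInterfaceThrows());        // True
    }
}
```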
However, an approach like this seems to work. We have to unwrap things at the point of the member access, since we can only proceed if we can find an equivalent member to access on the uncast type:
```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

public class CastRemoverVisitor : ExpressionVisitor
{
    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression is UnaryExpression { NodeType: ExpressionType.TypeAs or ExpressionType.Convert, Operand: var operand } &&
            node.Member is PropertyInfo propertyInfo &&
            operand.Type.IsInterface)
        {
            // Is the property declared on the operand's own interface (e.g. IBar.Baz, which a
            // cast to IExtraBar merely inherits)? Then get rid of the cast and access the
            // property on the uncast operand.
            if (propertyInfo.DeclaringType == operand.Type)
            {
                return base.Visit(Expression.MakeMemberAccess(operand, propertyInfo));
            }

            // Is node.Expression a concrete type which implements this interface method?
            var methodInfo = GetInterfaceMethodInfo(operand.Type, node.Expression.Type, propertyInfo.GetMethod);
            if (methodInfo != null)
            {
                return base.Visit(Expression.Call(operand, methodInfo));
            }
        }

        return base.VisitMember(node);
    }

    private static MethodInfo GetInterfaceMethodInfo(Type interfaceType, Type implementationType, MethodInfo implementationMethodInfo)
    {
        // GetInterfaceMap throws if implementationType does not implement interfaceType.
        if (!implementationType.IsClass || !interfaceType.IsAssignableFrom(implementationType))
            return null;

        var map = implementationType.GetInterfaceMap(interfaceType);
        for (int i = 0; i < map.InterfaceMethods.Length; i++)
        {
            if (map.TargetMethods[i] == implementationMethodInfo)
            {
                return map.InterfaceMethods[i];
            }
        }

        return null;
    }
}
```
I'm sure there are cases which will break this (fields come to mind, and I know `GetInterfaceMap` doesn't play well with generics in some situations), but it's a starting point.
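As a hedged sketch (not the answerer's code), here is a self-contained variant of the visitor with two tweaks: it builds the rewritten access with `Expression.Property(Expression, MethodInfo)` instead of `Expression.Call`, so the result prints as `t.Bar` rather than `t.get_Bar()`, and it checks `IsAssignableFrom` before calling `GetInterfaceMap`, which throws when the class does not implement the interface. The interfaces are stand-ins for the question's nested types and `Clean` is a hypothetical helper. Like the original, it only handles properties, so fields are left alone.

```csharp
#nullable enable
using System;
using System.Linq.Expressions;
using System.Reflection;

// Stand-ins mirroring the question's types.
public interface IBar { string Baz { get; } }
public interface IExtraBar : IBar { }
public interface IFoo { IBar Bar { get; } }
public class Foo : IFoo { public IBar Bar { get; } }

public class CastRemoverVisitor : ExpressionVisitor
{
    protected override Expression VisitMember(MemberExpression node)
    {
        if (node.Expression is UnaryExpression { NodeType: ExpressionType.TypeAs or ExpressionType.Convert, Operand: var operand } cast &&
            node.Member is PropertyInfo propertyInfo &&
            operand.Type.IsInterface)
        {
            // Property declared on the operand's own interface: just drop the cast.
            if (propertyInfo.DeclaringType == operand.Type)
                return Visit(Expression.MakeMemberAccess(operand, propertyInfo));

            // Property declared on a class: map its getter back to the interface getter.
            var interfaceGetter = GetInterfaceMethodInfo(operand.Type, cast.Type, propertyInfo.GetMethod);
            if (interfaceGetter != null)
                // Expression.Property(expr, accessor) yields a MemberExpression, not a method call.
                return Visit(Expression.Property(operand, interfaceGetter));
        }
        return base.VisitMember(node);
    }

    private static MethodInfo? GetInterfaceMethodInfo(Type interfaceType, Type implementationType, MethodInfo? implementationMethod)
    {
        // Guard: GetInterfaceMap throws ArgumentException if the class does not implement the interface.
        if (implementationMethod == null || !implementationType.IsClass || !interfaceType.IsAssignableFrom(implementationType))
            return null;

        InterfaceMapping map = implementationType.GetInterfaceMap(interfaceType);
        for (int i = 0; i < map.TargetMethods.Length; i++)
        {
            if (map.TargetMethods[i] == implementationMethod)
                return map.InterfaceMethods[i];
        }
        return null;
    }

    // Hypothetical helper: rewrite the expression and return its string form.
    public static string Clean(Expression<Func<IFoo, string>> expr) =>
        new CastRemoverVisitor().Visit(expr).ToString();

    public static void Main()
    {
        Console.WriteLine(Clean(t => ((t as Foo).Bar as IExtraBar).Baz)); // t => t.Bar.Baz
        Console.WriteLine(Clean(t => ((IExtraBar)(t as Foo).Bar).Baz));   // t => t.Bar.Baz
    }
}
```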