Tags: c#, linq, lambda, expression-trees, expressionvisitor

Case insensitive string compare in LINQ expression


I'm trying to write an ExpressionVisitor that wraps my LINQ to Objects expressions and automatically makes their string comparisons case-insensitive, just as they would be in LINQ to Entities.

EDIT: I definitely want to use an ExpressionVisitor rather than applying a custom extension method (or something similar) to my expression when it is created, for one important reason: the expression passed to my ExpressionVisitor is generated by the ASP.NET Web API OData layer, so I don't control how it is generated (i.e. I can't lowercase the string it is searching for anywhere except inside this ExpressionVisitor).

It also has to support LINQ to Entities, not just an extension method.

Here's what I have so far. It looks for a call to "Contains" on a string and then calls ToLower on any member access inside that expression.

However, it's not working. When I view the expressions after my changes, they look correct to me, so I'm not sure what I'm doing wrong.

using System;
using System.Linq.Expressions;

public class CaseInsensitiveExpressionVisitor : ExpressionVisitor
{
    // Wraps string-typed member accesses (e.g. $it.LastName) in a ToLower() call,
    // but only while the visitor is inside a "Contains" call.
    protected override Expression VisitMember(MemberExpression node)
    {
        if (insideContains)
        {
            if (node.Type == typeof(String))
            {
                var methodInfo = typeof(String).GetMethod("ToLower", new Type[] {});
                var expression = Expression.Call(node, methodInfo);
                return expression;
            }
        }
        return base.VisitMember(node);
    }

    // True while the operands of a "Contains" call are being visited.
    private Boolean insideContains = false;

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method.Name == "Contains")
        {
            // Nested Contains calls are not supported.
            if (insideContains) throw new NotSupportedException();
            insideContains = true;
            var result = base.VisitMethodCall(node);
            insideContains = false;
            return result;
        }
        return base.VisitMethodCall(node);
    }
}
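Independent of the OData question, two parts of the visitor above are loose: `node.Method.Name == "Contains"` also matches `Enumerable.Contains` and `List<T>.Contains`, and the `insideContains` flag lowercases every string member anywhere under the call (for example, the `t.Name` in `list.Contains(t.Name)` would be lowered while the list items would not). Here is a minimal sketch, assuming only `string.Contains(string)` should be rewritten. It matches that exact `MethodInfo` and lowercases just the instance and the argument. The class name `ContainsToLowerVisitor` and its `ApplyTo` helper are hypothetical. `ApplyTo` shows one way to run the visitor over an already-built `IQueryable<T>` through `IQueryProvider.CreateQuery<T>`.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical alternative to CaseInsensitiveExpressionVisitor: rewrites only
// calls to string.Contains(string) into
// instance.ToLower().Contains(argument.ToLower()).
public class ContainsToLowerVisitor : ExpressionVisitor
{
    private static readonly MethodInfo StringContains =
        typeof(string).GetMethod("Contains", new[] { typeof(string) });

    private static readonly MethodInfo StringToLower =
        typeof(string).GetMethod("ToLower", Type.EmptyTypes);

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        // Compare the MethodInfo itself, so Enumerable.Contains or
        // List<T>.Contains calls are left alone.
        if (node.Method == StringContains)
        {
            // Visit the operands first so that nested calls are rewritten too.
            Expression instance = Visit(node.Object);
            Expression argument = Visit(node.Arguments[0]);
            return Expression.Call(
                Expression.Call(instance, StringToLower),
                StringContains,
                Expression.Call(argument, StringToLower));
        }
        return base.VisitMethodCall(node);
    }

    // Applies the visitor to an already-built query (for example one produced by
    // the OData layer) and asks the same provider for a query over the result.
    public static IQueryable<T> ApplyTo<T>(IQueryable<T> source)
    {
        Expression rewritten = new ContainsToLowerVisitor().Visit(source.Expression);
        return source.Provider.CreateQuery<T>(rewritten);
    }
}
```

Matching on the `MethodInfo` instead of the name also removes the need for mutable state in the visitor, so a single instance can be reused safely.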

If I set a breakpoint on the "return expression" line in the VisitMember method and then call ToString on the "node" and "expression" variables, the breakpoint is hit twice, with these two sets of values:

First hit:

node.ToString()
"$it.LastName"
expression.ToString()
"$it.LastName.ToLower()"

Second hit:

node.ToString()
"value(System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]).TypedProperty"
expression.ToString()
"value(System.Web.Http.OData.Query.Expressions.LinqParameterContainer+TypedLinqParameterContainer`1[System.String]).TypedProperty.ToLower()"

I don't know enough about expressions to figure out what I'm doing wrong at this point. Any ideas?
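The second breakpoint hit is most likely the search term itself. Judging by the `value(...).TypedProperty` text above, OData passes the term as a property read on a `LinqParameterContainer` constant rather than as a plain `ConstantExpression`, much like the C# compiler does for a captured local. The sketch below only demonstrates the compiler case, and assumes OData's container behaves the same way. The class `CapturedVariableDemo` is hypothetical. The captured `term` arrives as a `MemberExpression` over a `ConstantExpression`, so `VisitMember` lowercases it too, which is what a case-insensitive `Contains` needs.

```csharp
using System;
using System.Linq.Expressions;

public static class CapturedVariableDemo
{
    // Describes the argument node of name => name.Contains(term).
    public static string DescribeArgument()
    {
        string term = "smith"; // captured local, hoisted by the compiler into a closure object
        Expression<Func<string, bool>> e = name => name.Contains(term);

        var call = (MethodCallExpression)e.Body;
        // The argument is not a ConstantExpression: it is a field read
        // (MemberExpression) on a ConstantExpression that holds the closure object.
        var arg = (MemberExpression)call.Arguments[0];
        return arg.NodeType + " over " + arg.Expression.NodeType;
    }
}
```

Calling `CapturedVariableDemo.DescribeArgument()` returns `"MemberAccess over Constant"`. That is the same shape as the second hit, so both sides of the `Contains` call end up lowercased.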


Solution

  • I made a sample app from your code, and it seems to work:

    using System;
    using System.Diagnostics;
    using System.Linq;
    using System.Linq.Expressions;

    public class Test
    {
        public string Name;
    }

    public class CaseInsensitiveExpressionVisitor : ExpressionVisitor
    {
        protected override Expression VisitMember(MemberExpression node)
        {
            if (insideContains)
            {
                if (node.Type == typeof(String))
                {
                    var methodInfo = typeof(String).GetMethod("ToLower", new Type[] {});
                    var expression = Expression.Call(node, methodInfo);
                    return expression;
                }
            }
            return base.VisitMember(node);
        }

        private Boolean insideContains = false;

        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
            if (node.Method.Name == "Contains")
            {
                if (insideContains) throw new NotSupportedException();
                insideContains = true;
                var result = base.VisitMethodCall(node);
                insideContains = false;
                return result;
            }
            return base.VisitMethodCall(node);
        }
    }

    class Program
    {
        static void Main(string[] args)
        {
            Expression<Func<Test, bool>> expr = t => t.Name.Contains("a");
            // Rewritten to t => t.Name.ToLower().Contains("a")
            var expr1 = (Expression<Func<Test, bool>>)new CaseInsensitiveExpressionVisitor().Visit(expr);
            var test = new[] { new Test { Name = "A" } };

            // The rewritten expression matches "A"; the original one does not.
            // (Debug.Assert is only evaluated in Debug builds.)
            var length = test.Where(expr1.Compile()).ToArray().Length;
            Debug.Assert(length == 1);
            Debug.Assert(test.Where(expr.Compile()).ToArray().Length == 0);
        }
    }
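The sample above never exercises a null `Name`. In LINQ to Objects, `t.Name.ToLower()` throws `NullReferenceException` when `Name` is null, whereas in SQL `LOWER(NULL)` simply yields NULL. Here is a minimal sketch of a null-guarded replacement for the rewritten call, assuming the search argument itself is never null. `NullSafeContains.Build` is a hypothetical helper. A visitor's `VisitMethodCall` could return `NullSafeContains.Build(Visit(node.Object), Visit(node.Arguments[0]))` instead of the bare `ToLower().Contains(...)` call. For LINQ to Entities, the extra `!= null` test should translate to an `IS NOT NULL` check, but that is an assumption to verify against your provider.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

public static class NullSafeContains
{
    private static readonly MethodInfo ContainsMethod =
        typeof(string).GetMethod("Contains", new[] { typeof(string) });

    private static readonly MethodInfo ToLowerMethod =
        typeof(string).GetMethod("ToLower", Type.EmptyTypes);

    // Builds: instance != null && instance.ToLower().Contains(argument.ToLower())
    // The argument is assumed to be non-null.
    public static Expression Build(Expression instance, Expression argument)
    {
        Expression notNull = Expression.NotEqual(instance, Expression.Constant(null, typeof(string)));
        Expression lowered = Expression.Call(
            Expression.Call(instance, ToLowerMethod),
            ContainsMethod,
            Expression.Call(argument, ToLowerMethod));
        // AndAlso short-circuits, so ToLower() is never called on a null instance.
        return Expression.AndAlso(notNull, lowered);
    }
}
```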