Tags: c#, entity-framework, linq, expression-trees

Get Non-Static MethodInfo for IEnumerable<T>.First() (Or make the static method work with EF)


I have a method, GetSearchExpression, defined as:

    private Expression<Func<T, bool>> GetSearchExpression(
        string targetField, ExpressionType comparison, object value, IEnumerable<EnumerableResultQualifier> qualifiers = null);

At a high level, the method takes a field or property path (such as Order.Customer.Name), a comparison type (like ExpressionType.Equal), and a value (like "Billy"). It then returns a lambda expression suitable for passing to a Where call, such as o => o.Customer.Name == "Billy".
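
For reference, here is the shape of tree such a method has to produce, as the C# compiler builds it from a lambda literal. This is a minimal sketch: `DateTime` and the path `Date.Year` are assumed stand-ins for `Order` and `Customer.Name` (the entity classes aren't shown), chosen only so the snippet compiles on its own.

```csharp
using System;
using System.Linq.Expressions;

// Let the compiler build the kind of tree GetSearchExpression should return.
// DateTime stands in for Order; "Date.Year" plays the role of "Customer.Name".
Expression<Func<DateTime, bool>> target = d => d.Date.Year == 2020;

// The body is a binary Equal node...
var body = (BinaryExpression)target.Body;

// ...whose left side is a chain of member accesses: (d.Date).Year
var year = (MemberExpression)body.Left;
var date = (MemberExpression)year.Expression;

Console.WriteLine(body.NodeType);                            // Equal
Console.WriteLine(year.Member.Name);                         // Year
Console.WriteLine(date.Member.Name);                         // Date
Console.WriteLine(date.Expression == target.Parameters[0]);  // True
```

Inspecting a compiler-built tree like this is a quick way to check, node by node, what a hand-built expression should look like.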

Recently, I discovered an issue. Sometimes, the field I need is actually the field of an item in a collection (like Order.StatusLogs.First().CreatedDate).

I feel like that should be easy. The code that creates the left side of the expression (above, o => o.Customer.Name) is as follows:

    var param = Expression.Parameter(typeof(T), "t");
    Expression left = null;
    //turn "Order.Customer.Name" into List<string> { "Customer", "Name" }
    var deQualifiedFieldName = DeQualifyFieldName(targetField, typeof(T));

    //loop through each part and grab the specified field or property
    foreach (var part in deQualifiedFieldName)
        left = Expression.PropertyOrField(left == null ? param : left, part);
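
A self-contained sketch of the same loop, extended to a complete predicate with `Expression.MakeBinary` and `Expression.Lambda`. `BuildPredicate` is a hypothetical name and `DateTime` again stands in for `T`. Two details differ from the snippet above: `left` starts at the parameter (so no null check is needed), and the constant is created with `left.Type` as its type, so the value must match the member's type.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical helper: fold a member path into t => t.A.B <op> value.
static Expression<Func<T, bool>> BuildPredicate<T>(
    string[] path, ExpressionType comparison, object value)
{
    var param = Expression.Parameter(typeof(T), "t");
    Expression left = param;
    foreach (var part in path)
        left = Expression.PropertyOrField(left, part);   // t.A, then t.A.B, ...

    // Type the constant as the member's type so both operands of the comparison agree.
    var right = Expression.Constant(value, left.Type);
    return Expression.Lambda<Func<T, bool>>(
        Expression.MakeBinary(comparison, left, right), param);
}

// Equivalent to t => t.Date.Year == 2020
var pred = BuildPredicate<DateTime>(new[] { "Date", "Year" }, ExpressionType.Equal, 2020);
Console.WriteLine(pred);
Console.WriteLine(pred.Compile()(new DateTime(2020, 6, 1))); // True
```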

It seems like I should be able to revise this to check if the field/property exists, and if not, try to call a method by that name instead. It would look like this:

    var param = Expression.Parameter(typeof(T), "t");
    Expression left = param;
    var deQualifiedFieldName = DeQualifyFieldName(targetField, typeof(T));
    var currentType = typeof(T);
    foreach (var part in deQualifiedFieldName)
    {
        //this gets the Type of the current "level" we're at in the hierarchy passed via TargetField
        var memberType = SingleLevelFieldType(currentType, part);
        if (memberType != null) //if the field/property was found
        {
            left = Expression.PropertyOrField(left, part);
            currentType = memberType;
        }
        else
        {   //if the field or property WASN'T found, it might be a parameterless method
            var method = currentType.GetMethod(part, Type.EmptyTypes);
            left = Expression.Call(left, method);
            currentType = method.ReturnType;
        }
    }

The problem is the statement near the end (var method = currentType.GetMethod(part, Type.EmptyTypes);). It turns out "First" and "Last" don't exist as methods on IEnumerable types, so GetMethod() returns null and I get a null exception when I try to use my MethodInfo. In fact, the only way I can EVER get them to show up in a GetMethod() call is through typeof(Enumerable).GetMethod(). That's useless, of course, because then I get a static method in return rather than the instance method I need.
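
The lookup comes back empty because First and Last are extension methods: they are static methods on System.Linq.Enumerable, and `xs.First()` is just syntax for `Enumerable.First(xs)`, so there is no instance MethodInfo to find. The sketch below (with `List<int>` as an assumed stand-in for the StatusLogs collection) checks that, shows that the C# compiler itself represents `xs.First()` in an expression tree as a static call with no instance, and then builds the same call by hand.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// First() is an extension method: it lives on System.Linq.Enumerable, not on the collection type.
var instanceFirst = typeof(List<int>).GetMethod("First", Type.EmptyTypes);
Console.WriteLine(instanceFirst == null); // True

// What the compiler emits for xs => xs.First(): a static call whose Object (instance) is null.
Expression<Func<List<int>, int>> fromCompiler = xs => xs.First();
var call = (MethodCallExpression)fromCompiler.Body;
Console.WriteLine(call.Object == null);                             // True
Console.WriteLine(call.Method.DeclaringType == typeof(Enumerable)); // True

// The same call built by hand: pick the one-parameter overload and close it over the element type.
MethodInfo first = typeof(Enumerable).GetMethods()
    .Single(m => m.Name == nameof(Enumerable.First) && m.GetParameters().Length == 1)
    .MakeGenericMethod(typeof(int));

var xsParam = Expression.Parameter(typeof(List<int>), "xs");
var manual = Expression.Lambda<Func<List<int>, int>>(Expression.Call(first, xsParam), xsParam);
Console.WriteLine(manual.Compile()(new List<int> { 7, 8, 9 })); // 7
```

So a static MethodInfo passed to `Expression.Call` with no instance expression is the same node EF would receive if the lambda were written out by hand.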

As a side-note: I tried using the static method, but Entity Framework throws a fit and won't accept it as part of the lambda.
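
Putting that together, the loop from the question can fall back to Enumerable's one-argument generic methods when no field or property matches. This is a sketch under stated assumptions: `BuildMemberPath` and `ElementType` are hypothetical helpers, `List<string>` stands in for a collection navigation property, and it has only been exercised through `Compile()` (LINQ to Objects). Whether a given EF version translates a particular call, such as First inside a subquery, is a separate question this sketch doesn't settle.

```csharp
#nullable enable
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical helper: the T of IEnumerable<T> for a type that is or implements it, else null.
// (GetInterfaces() on IEnumerable<T> itself does not list IEnumerable<T>, hence the first check.)
static Type? ElementType(Type type)
{
    if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(IEnumerable<>))
        return type.GetGenericArguments()[0];
    return type.GetInterfaces()
        .FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IEnumerable<>))
        ?.GetGenericArguments()[0];
}

// Hypothetical helper: walk a path such as "StatusLogs.First.CreatedDate", using a property or
// field when one exists and otherwise a one-argument generic Enumerable method (First, Last, ...).
static Expression BuildMemberPath(Expression root, IEnumerable<string> parts)
{
    Expression left = root;
    foreach (var part in parts)
    {
        var member = (MemberInfo?)left.Type.GetProperty(part) ?? left.Type.GetField(part);
        if (member != null)
        {
            left = Expression.MakeMemberAccess(left, member);
            continue;
        }

        var itemType = ElementType(left.Type)
            ?? throw new ArgumentException($"'{part}' is not a member of {left.Type}");

        // Enumerable.First<TSource>(IEnumerable<TSource>) etc. are static: no instance argument.
        var generic = typeof(Enumerable).GetMethods()
            .FirstOrDefault(m => m.Name == part && m.IsGenericMethodDefinition && m.GetParameters().Length == 1)
            ?? throw new ArgumentException($"'{part}' is not a one-argument Enumerable method");
        left = Expression.Call(generic.MakeGenericMethod(itemType), left);
    }
    return left;
}

// Usage: l => l.First().Length == 5
var p = Expression.Parameter(typeof(List<string>), "l");
var pred = Expression.Lambda<Func<List<string>, bool>>(
    Expression.Equal(BuildMemberPath(p, "First.Length".Split('.')), Expression.Constant(5)), p);
Console.WriteLine(pred.Compile()(new List<string> { "Billy", "Bob" })); // True
```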

I need help getting the instance MethodInfo of IEnumerable.First() & Last(). Please help!


Solution

  • Thank you to Marc and Ivan for their input. They deserve credit as without their help I would have spent much longer finding a solution. However, as neither answer solved the issue I was having, I'm posting the solution that worked for me (successfully applying criteria as well as successfully querying against an EF data source):

        private Expression<Func<T, bool>> GetSearchExpression(string targetField, ExpressionType comparison, object value, string enumMethod)
        {
            return (Expression<Func<T, bool>>)MakePredicate(DeQualifyFieldName(targetField, typeof(T)), comparison, value, enumMethod);
        }
    
        private LambdaExpression MakePredicate(string[] memberNames, ExpressionType comparison, object value, string enumMethod = "Any")
        {
            //create parameter for inner lambda expression
            var parameter = Expression.Parameter(typeof(T), "t");
            Expression left = parameter;
    
            //Get the value against which the property/field will be compared
            var right = Expression.Constant(value);
    
            var currentType = typeof(T);
            for (int x = 0; x < memberNames.Length; x++)
            {
                string memberName = memberNames[x];
                //null if no field/property with this name exists on the current type
                var memberType = SingleLevelFieldType(currentType, memberName);
                if (memberType != null)
                {
                    //assign the current type member type
                    currentType = memberType;
                    left = Expression.PropertyOrField(left, memberName);
    
                    //mini-loop for non collection objects
                    if (!currentType.IsGenericType || (!(currentType.GetGenericTypeDefinition() == typeof(IEnumerable<>) ||
                                                         currentType.GetGenericTypeDefinition() == typeof(ICollection<>))))
                        continue;
    
                    ///Begin loop for collection objects -- this section can only run once
    
                    //get enum method
                    if (enumMethod.Length < 2) throw new Exception("Invalid enum method target.");
                    bool negateEnumMethod = enumMethod[0] == '!';
                    string methodName = negateEnumMethod ? enumMethod.Substring(1) : enumMethod;
    
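                    //get the item type; GetInterfaces() does not return IEnumerable<T> itself, so check the declared type first
                    var enumerableType = currentType.GetGenericTypeDefinition() == typeof(IEnumerable<>)
                        ? currentType
                        : currentType.GetInterfaces()
                                     .Single(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IEnumerable<>));
                    var itemType = enumerableType.GetGenericArguments()[0];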
    
                    //generate lambda for single item
                    var itemPredicate = MakeSimplePredicate(itemType, memberNames[++x], comparison, value);
    
                    //get method call
                    var staticMethod = typeof(Enumerable).GetMember(methodName).OfType<MethodInfo>()
                                                         .Where(m => m.GetParameters().Length == 2)
                                                         .First()
                                                         .MakeGenericMethod(itemType);
    
                    //generate method call, then break loop for return
                    left = Expression.Call(null, staticMethod, left, itemPredicate);
                    right = Expression.Constant(!negateEnumMethod);
                    comparison = ExpressionType.Equal;
                    break;
                }
            }
    
            //build the final expression
            var binaryExpression = Expression.MakeBinary(comparison, left, right);
            return Expression.Lambda<Func<T, bool>>(binaryExpression, parameter);
        }
    
        static LambdaExpression MakeSimplePredicate(Type inputType, string memberName, ExpressionType comparison, object value)
        {
            var parameter = Expression.Parameter(inputType, "t");
            Expression left = Expression.PropertyOrField(parameter, memberName);
            return Expression.Lambda(Expression.MakeBinary(comparison, left, Expression.Constant(value)), parameter);
        }
    
        private static Type SingleLevelFieldType(Type baseType, string fieldName)
        {
            Type currentType = baseType;
            MemberInfo match = (MemberInfo)currentType.GetField(fieldName) ?? currentType.GetProperty(fieldName);
            if (match == null) return null;
            return GetFieldOrPropertyType(match);
        }
    
        public static Type GetFieldOrPropertyType(MemberInfo field)
        {
            return field.MemberType == MemberTypes.Property ? ((PropertyInfo)field).PropertyType : ((FieldInfo)field).FieldType;
        }
    
        /// <summary>
        /// Remove qualifying names from a target field.  For example, if targetField is "Order.Customer.Name" and
        /// targetType is Order, the de-qualified expression will be "Customer.Name" split into constituent parts
        /// </summary>
        /// <param name="targetField"></param>
        /// <param name="targetType"></param>
        /// <returns></returns>
        public static string[] DeQualifyFieldName(string targetField, Type targetType)
        {
            return DeQualifyFieldName(targetField.Split('.'), targetType);
        }
    
        public static string[] DeQualifyFieldName(string[] targetFields, Type targetType)
        {
            var r = targetFields.ToList();
            foreach (var p in targetType.Name.Split('.'))
                if (r.First() == p) r.RemoveAt(0);
            return r.ToArray();
        }
    

    I included related methods in case someone actually needs to sort through this at some point. :)

    Thanks again!
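
The collection branch of MakePredicate comes down to one static call, `Enumerable.Any(source, item => item.Member <op> value)`, compared against `true` (or `false` for a `!`-prefixed method). A stripped-down sketch of just that part, where `AnyMatches` is a hypothetical name and `string` items with their `Length` property stand in for the StatusLogs entries and CreatedDate:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical helper: t => Enumerable.Any(t, item => item.<member> <op> value) == !negate
// (negate mirrors the solution's "!" prefix on the enum method name).
static Expression<Func<IEnumerable<TItem>, bool>> AnyMatches<TItem>(
    string itemMember, ExpressionType comparison, object value, bool negate = false)
{
    // Inner lambda: item => item.<member> <op> value (same idea as MakeSimplePredicate)
    var item = Expression.Parameter(typeof(TItem), "item");
    var itemPredicate = Expression.Lambda<Func<TItem, bool>>(
        Expression.MakeBinary(comparison, Expression.PropertyOrField(item, itemMember), Expression.Constant(value)),
        item);

    // Enumerable.Any<TItem>(IEnumerable<TItem>, Func<TItem, bool>): static, so no instance expression.
    var any = typeof(Enumerable).GetMethods()
        .Single(m => m.Name == nameof(Enumerable.Any) && m.GetParameters().Length == 2)
        .MakeGenericMethod(typeof(TItem));

    var source = Expression.Parameter(typeof(IEnumerable<TItem>), "t");
    var body = Expression.Equal(Expression.Call(any, source, itemPredicate), Expression.Constant(!negate));
    return Expression.Lambda<Func<IEnumerable<TItem>, bool>>(body, source);
}

var hasThreeLetter = AnyMatches<string>("Length", ExpressionType.Equal, 3).Compile();
var noneThreeLetter = AnyMatches<string>("Length", ExpressionType.Equal, 3, negate: true).Compile();
Console.WriteLine(hasThreeLetter(new[] { "Billy", "Bob" }));  // True
Console.WriteLine(noneThreeLetter(new[] { "Billy", "Bob" })); // False
```

Passing the inner lambda straight to `Expression.Call` works here because a lambda expression's `Type` is the delegate type `Func<TItem, bool>` that `Enumerable.Any` expects.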