Tags: c#, entity-framework, linq, entity-framework-core, expression-trees

Complex edit of the body of an Expression<Func<T, bool>>


Summary: I want to know how I can detect specific fragments in an expression's body and then change them the way I want, for example rewriting

e.Entity.ListA.Union(e.Entity.ListB).Any(...)...

to

e.Entity != null && 
((e.Entity.ListA != null && e.Entity.ListA.Any(...)) 
|| (e.Entity.ListB != null && e.Entity.ListB.Any(...)))

I would like to do this using only LINQ expression techniques, since I see that as the ideal solution.

To keep my C# code clean, I have written a set of predefined expressions, and with the LinqKit extensions I can combine them, which makes it easy to build complex expressions dynamically. So far everything works. In addition, I want to use the same expressions to filter both IQueryable and IEnumerable sources.
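Reusing one predefined expression for both cases works because of how the two Where overloads are chosen: Queryable.Where takes the Expression<Func<T, bool>> as a tree for the provider to translate, while Enumerable.Where needs a delegate, obtained with Compile(). A minimal sketch, with a hypothetical Person type (LinqKit is left out to keep it self-contained):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity used only for this illustration.
public class Person
{
    public string Name { get; set; } = "";
    public int Age { get; set; }
}

public static class PredicateReuse
{
    // One predefined expression, written once.
    public static readonly Expression<Func<Person, bool>> IsAdult = p => p.Age >= 18;

    // IQueryable: pass the tree itself so a provider (e.g. EF Core) can translate it.
    public static List<Person> FilterQueryable(IQueryable<Person> source) =>
        source.Where(IsAdult).ToList();

    // IEnumerable: compile the tree into a delegate and run it in memory.
    public static List<Person> FilterEnumerable(IEnumerable<Person> source) =>
        source.Where(IsAdult.Compile()).ToList();
}
```

In real code the compiled delegate is worth caching, since Compile() is relatively expensive.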

However, as you know, there are cases where a given expression works for one of them but not for the other. I have managed to avoid many such problems, until I came to a case where I found a solution that I still feel is not ideal.

I will first show the problem, then explain the solution I am hoping for, and finally share my current attempt.

//---
public class AssignmentsEx : BaseEx
{ 
    //.........

    /// <summary>
    /// (e.FreeRoles AND e.RoleClass.Roles) ⊆ ass.AllRoles
    /// </summary>
    public static Expression<Func<T, bool>> RolesInclosedBy<T>(IAssignedInstitution assignedInstitution) where T : class, IAssignedInstitution
    {
        var allStaticRoles = AppRolesStaticData.AdminRolesStr.GetAll();
        var assAllRoles = assignedInstitution.AllRoles.Select(s => s.Name).ToList();
        var hasAllRoles = allStaticRoles.All(assR => assAllRoles.Any(sR => sR == assR));

        if (hasAllRoles)
            return e => true;

        // When the expression is translated to SQL (IQueryable case),
        // it works perfectly.
        // In the IEnumerable case, accessing e.RoleClass.Roles throws a
        // NullReferenceException when RoleClass is null (a valid state in this code),
        // and Union throws an ArgumentNullException when FreeRoles is null.
        return e => e.FreeRoles.Union(e.RoleClass.Roles).All(eR => assAllRoles.Any(assR => assR == eR.Name));
    }

//.........

}

As you can see, if I use this method to filter an in-memory list of objects where RoleClass or FreeRoles is null, it throws (a NullReferenceException when RoleClass is null, an ArgumentNullException when FreeRoles is null).
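To see the failure concretely, the sketch below compiles an expression of the same shape as whenToEntity and runs it in memory. Role, RoleClass and Assignment are hypothetical, simplified stand-ins for the question's types. The two null cases fail differently: member access on a null RoleClass raises NullReferenceException, while passing a null FreeRoles to Enumerable.Union raises ArgumentNullException.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical, simplified stand-ins for the question's types.
public class Role { public string Name { get; set; } = ""; }
public class RoleClass { public List<Role> Roles { get; set; } }
public class Assignment
{
    public List<Role> FreeRoles { get; set; }
    public RoleClass RoleClass { get; set; }
}

public static class UnionNullDemo
{
    // Same shape as the question's whenToEntity expression.
    public static Expression<Func<Assignment, bool>> WhenToEntity(List<string> allowed) =>
        e => e.FreeRoles.Union(e.RoleClass.Roles).All(r => allowed.Any(n => n == r.Name));

    // Runs the compiled predicate in memory and reports the result or the exception type.
    public static string Run(Assignment a, List<string> allowed)
    {
        try
        {
            return WhenToEntity(allowed).Compile()(a).ToString();
        }
        catch (Exception ex)
        {
            return ex.GetType().Name;
        }
    }
}
```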

The best solution, I think, would involve three steps:

  • detecting the desired fragment in the expression body

  • modifying the fragment as needed for the IEnumerable case (or vice versa)

  • rebuilding and returning the new expression
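These three steps map naturally onto an ExpressionVisitor: VisitMethodCall detects the fragment, the visitor builds the replacement, and Visit returns a new lambda. Below is a minimal sketch, assuming the question's FreeRoles / RoleClass.Roles shape (Role, RoleClass and Assignment are simplified hypothetical stand-ins) and plain List properties, so that Union, All and Any bind to System.Linq.Enumerable. WithSplittedUnion is the hypothetical name from the question, not an existing library method.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Hypothetical, simplified stand-ins for the question's entity types.
public class Role { public string Name { get; set; } = ""; }
public class RoleClass { public List<Role> Roles { get; set; } }
public class Assignment
{
    public List<Role> FreeRoles { get; set; }
    public RoleClass RoleClass { get; set; }
}

public static class SplitUnionExtensions
{
    // Hypothetical extension named after the question's wish; returns a new, rewritten lambda.
    public static Expression<Func<T, bool>> WithSplittedUnion<T>(this Expression<Func<T, bool>> predicate) =>
        (Expression<Func<T, bool>>)new SplitUnionVisitor().Visit(predicate);
}

// Rewrites a.Union(b).All(p) / a.Union(b).Any(p) (Enumerable overloads only)
// into null-safe checks applied to a and b separately.
public class SplitUnionVisitor : ExpressionVisitor
{
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        // Step 1: detect the fragment.
        bool isAll = node.Method.Name == nameof(Enumerable.All);
        if (node.Method.DeclaringType == typeof(Enumerable)
            && (isAll || node.Method.Name == nameof(Enumerable.Any))
            && node.Arguments.Count == 2
            && node.Arguments[0] is MethodCallExpression union
            && union.Method.DeclaringType == typeof(Enumerable)
            && union.Method.Name == nameof(Enumerable.Union)
            && union.Arguments.Count == 2)
        {
            // Step 2: apply All/Any to each source on its own. Union only removes duplicates, so
            // All(Union(a, b)) == All(a) && All(b) and Any(Union(a, b)) == Any(a) || Any(b).
            Expression predicate = Visit(node.Arguments[1]);
            Expression left = PerSource(node.Method, Visit(union.Arguments[0]), predicate, isAll);
            Expression right = PerSource(node.Method, Visit(union.Arguments[1]), predicate, isAll);
            return isAll ? Expression.AndAlso(left, right) : Expression.OrElse(left, right);
        }
        return base.VisitMethodCall(node);
    }

    // Treats a null source as empty: All -> true, Any -> false.
    private static Expression PerSource(MethodInfo op, Expression source, Expression predicate, bool isAll)
    {
        Expression call = Expression.Call(op, source, predicate);
        Expression guard = NullGuard(source);
        if (guard == null) return call;
        return isAll
            ? Expression.OrElse(guard, call)
            : Expression.AndAlso(Expression.Not(guard), call);
    }

    // For e.RoleClass.Roles builds: e.RoleClass == null || e.RoleClass.Roles == null
    private static Expression NullGuard(Expression expr)
    {
        Expression guard = null;
        while (true)
        {
            if (expr is UnaryExpression u && u.NodeType == ExpressionType.Convert)
            {
                expr = u.Operand; // look through compiler-inserted conversions
                continue;
            }
            if (!(expr is MemberExpression member)) break;
            if (!member.Type.IsValueType)
            {
                Expression isNull = Expression.Equal(member, Expression.Constant(null, member.Type));
                guard = guard == null ? isNull : Expression.OrElse(isNull, guard);
            }
            expr = member.Expression;
        }
        return guard;
    }
}
```

With this, e => e.FreeRoles.Union(e.RoleClass.Roles).All(...) becomes roughly the question's whenToObject, so one static method could serve the in-memory case. Whether the rewritten tree still translates under a particular EF Core version has not been checked here.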

This would let me keep the method static and adapt it through an extension method, e.g. ex.WithSplittedUnion(), rather than the traditional approach I am using now:

public class AssignmentsEx : BaseEx
{
    public LinqExpressionPurpose Purpose { get; }

    public AssignmentsEx(LinqExpressionPurpose purpose) : base(purpose)
    {
        Purpose = purpose;
    }

    public Expression<Func<T, bool>> RolesInclosedBy<T>(IAssignedInstitution assignedInstitution) where T : class, IAssignedInstitution
    {
        var allStaticRoles = AppRolesStaticData.AdminRolesStr.GetAll();
        var assAllRoles = assignedInstitution.AllRoles.Select(s => s.Name).ToList();
        var hasAllRoles = allStaticRoles.All(assR => assAllRoles.Any(sR => sR == assR));

        if (hasAllRoles)
            return e => true;

        Expression<Func<T, bool>> whenToObject = e => (e.FreeRoles == null || e.FreeRoles.All(eR => assAllRoles.Any(assR => assR == eR.Name)))
        && (e.RoleClass == null || e.RoleClass.Roles == null || e.RoleClass.Roles.All(eR => assAllRoles.Any(assR => assR == eR.Name)));

        Expression<Func<T, bool>> whenToEntity = e => e.FreeRoles.Union(e.RoleClass.Roles).All(eR => assAllRoles.Any(assR => assR == eR.Name));

        return Purpose switch
        {
            LinqExpressionPurpose.ToEntity => whenToEntity,
            LinqExpressionPurpose.ToObject => whenToObject,
            _ => null,
        };
    }
}

I hope the explanation is clear, thanks in advance


Solution

  • As I see it, what you need is an ExpressionVisitor to traverse and modify the expression tree. One thing I would change is how you call Any. Instead of

    e.Entity != null && 
    ((e.Entity.ListA != null && e.Entity.ListA.Any(...)) 
    || (e.Entity.ListB != null && e.Entity.ListB.Any(...)))
    

    I'd go for

    (
        e.Entity != null && e.Entity.ListA != null && e.Entity.ListB != null
            ? e.Entity.ListA.Union(e.Entity.ListB)
            : e.Entity != null && e.Entity.ListA != null
                ? e.Entity.ListA
                : e.Entity != null && e.Entity.ListB != null
                    ? e.Entity.ListB
                    : new Entity[0]
    ).Any(...)
    

    I find this form easier to build as an expression tree, and the outcome is the same.
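That both forms give the same result can be checked in memory. This sketch uses hypothetical Row, Owner and Item types standing in for e, e.Entity and the list elements, and swaps in Enumerable.Empty<Item>() for new Entity[0], because C# cannot infer a common type for a List<Item> and an array in a conditional expression; with Enumerable.Empty the nested conditionals type-check without explicit casts.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical types standing in for e, e.Entity and the list elements.
public class Item { public string Name { get; set; } = ""; }
public class Owner
{
    public List<Item> ListA { get; set; }
    public List<Item> ListB { get; set; }
}
public class Row { public Owner Entity { get; set; } }

public static class UnionForms
{
    // The answer's form: pick a non-null source first, then call Any once.
    public static Expression<Func<Row, bool>> Conditional(string name) => e =>
        (e.Entity != null && e.Entity.ListA != null && e.Entity.ListB != null
            ? e.Entity.ListA.Union(e.Entity.ListB)
            : e.Entity != null && e.Entity.ListA != null
                ? e.Entity.ListA
                : e.Entity != null && e.Entity.ListB != null
                    ? e.Entity.ListB
                    : Enumerable.Empty<Item>())
        .Any(i => i.Name == name);

    // The question's form: split the Union and guard each list separately.
    public static Expression<Func<Row, bool>> Split(string name) => e =>
        e.Entity != null &&
        ((e.Entity.ListA != null && e.Entity.ListA.Any(i => i.Name == name))
         || (e.Entity.ListB != null && e.Entity.ListB.Any(i => i.Name == name)));
}
```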

    Example code:

    public class OptionalCallFix : ExpressionVisitor
    {
        private readonly List<Expression> _conditionalExpressions = new List<Expression>();
        private readonly Type _contextType;
        private readonly Type _entityType;
    
        private OptionalCallFix(Type contextType, Type entityType)
        {
            this._contextType = contextType;
            this._entityType = entityType;
        }
    
        protected override Expression VisitMethodCall(MethodCallExpression node)
        {
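            // Replace Queryable.Union(left, right) call with:
            //     left == null && right == null ? new Entity[0] : (left == null ? right : (right == null ? left : Queryable.Union(left, right)))
            // Note: only Queryable.Union is matched; Union over List/ICollection properties binds to Enumerable.Union.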
            if (node.Method.DeclaringType == typeof(Queryable) && node.Method.Name == nameof(Queryable.Union))
            {
                Expression left = this.Visit(node.Arguments[0]);
                Expression right = this.Visit(node.Arguments[1]);
    
                // left == null
                Expression leftIsNull = Expression.Equal(left, Expression.Constant(null, left.Type));
    
                // right == null
                Expression rightIsNull = Expression.Equal(right, Expression.Constant(null, right.Type));
    
                // new Entity[0].AsQueryable()
                Expression emptyArray = Expression.Call
                (
                    typeof(Queryable),
                    nameof(Queryable.AsQueryable),
                    new [] { this._entityType },
                    Expression.NewArrayInit(this._entityType, new Expression[0])
                );
    
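                // left == null && right == null ? new Entity[0] : (left == null ? right : (right == null ? left : Queryable.Union(left, right)))
                // Expression.Condition requires both branches to have the same type (here IQueryable<TEntity>), otherwise it throws.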
                return Expression.Condition
                (
                    Expression.AndAlso(leftIsNull, rightIsNull),
                    emptyArray,
                    Expression.Condition
                    (
                        leftIsNull,
                        right,
                        Expression.Condition
                        (
                            rightIsNull,
                            left,
                            Expression.Call
                            (
                                typeof(Queryable), 
                                nameof(Queryable.Union), 
                                new [] { this._entityType }, 
                                left, 
                                Expression.Convert(right, typeof(IEnumerable<>).MakeGenericType(this._entityType))
                            )
                        )
                    )
                );
            }
    
            return base.VisitMethodCall(node);
        }
    
        protected override Expression VisitMember(MemberExpression node)
        {
            Expression expression = this.Visit(node.Expression);
    
            // Check if expression should be fixed
            if (this._conditionalExpressions.Contains(expression))
            {
                // replace e.XXX with e == null ? null : e.XXX (node.Type must be a reference type for the null constant)
                ConditionalExpression condition = Expression.Condition
                (
                    Expression.Equal(expression, Expression.Constant(null, expression.Type)),
                    Expression.Constant(null, node.Type),
                    Expression.MakeMemberAccess(expression, node.Member)
                );
    
                // Add fixed expression to the _conditionalExpressions list
                this._conditionalExpressions.Add(condition);
    
                return condition;
            }
    
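            // node.Expression was already visited above; rebuild the node from that result
            // instead of letting base.VisitMember visit it a second time.
            return node.Update(expression);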
        }
    
        protected override Expression VisitParameter(ParameterExpression node)
        {
            if (node.Type == this._contextType)
            {
                // Add ParameterExpression to the _conditionalExpressions list
                // It is used in VisitMember method to check if expression should be fixed this way
                this._conditionalExpressions.Add(node);
            }
    
            return base.VisitParameter(node);
        }
    
        public static IQueryable<TEntity> Fix<TContext, TEntity>(TContext context, in Expression<Func<TContext, IQueryable<TEntity>>> method)
        {
            return ((Expression<Func<TContext, IQueryable<TEntity>>>)new OptionalCallFix(typeof(TContext), typeof(TEntity)).Visit(method)).Compile().Invoke(context);
        }
    }
    

    You can call it like this:

    OptionalCallFix.Fix(context, ctx => ctx.Entity.ListA.Union(ctx.ListB));
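The VisitMember idea in the answer (rewriting e.XXX into e == null ? null : e.XXX) can also be tried on its own. The following is a standalone sketch, not the answer's class: it null-guards every reference-typed member access in a lambda, and Outer, Inner and NullPropagatingVisitor are hypothetical names.

```csharp
using System;
using System.Linq.Expressions;

// Hypothetical types for the demo.
public class Inner { public string Value { get; set; } }
public class Outer { public Inner Child { get; set; } }

// Rewrites each reference-typed member access x.M into (x == null ? null : x.M).
// Value-typed members are left unchanged, since they cannot hold null.
public class NullPropagatingVisitor : ExpressionVisitor
{
    protected override Expression VisitMember(MemberExpression node)
    {
        Expression target = Visit(node.Expression);   // null for static members
        Expression access = node.Update(target);
        if (target == null || target.Type.IsValueType || node.Type.IsValueType)
            return access;
        return Expression.Condition(
            Expression.Equal(target, Expression.Constant(null, target.Type)),
            Expression.Constant(null, node.Type),
            access);
    }

    public static Expression<Func<T, TResult>> Apply<T, TResult>(Expression<Func<T, TResult>> lambda) =>
        (Expression<Func<T, TResult>>)new NullPropagatingVisitor().Visit(lambda);
}
```

Each level repeats the inner chain in both the null test and the access, so the tree grows quickly with chain depth; that is acceptable for short navigation paths such as e.RoleClass.Roles.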