Tags: c#, expression-trees, entity-framework-plus, entity-framework-extensions

How to intercept and modify bulk updates?


In several places in my code I am doing bulk updates using the handy dandy Z.EntityFramework.Plus extensions, e.g.

await db.Foos
        .Where(f => f.SomeCondition)
        .UpdateAsync(f => new Foo { Field1 = "bar", Field2 = f.Field2 + 1 });

which will update all Foo records where SomeCondition is true, setting Field1 to "bar" and incrementing Field2 by one.
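The update factory passed to UpdateAsync is an expression tree, not a compiled delegate, and that is what makes intercepting it possible. The minimal sketch below shows that the object initializer arrives as a MemberInitExpression whose Bindings hold one entry per assigned field. The Foo shape and the UpdateFactoryInspector helper are hypothetical and exist only for illustration:

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity shaped like the Foo in the question.
public class Foo
{
    public bool SomeCondition { get; set; }
    public string Field1 { get; set; }
    public int Field2 { get; set; }
}

// Hypothetical helper, for illustration only.
public static class UpdateFactoryInspector
{
    // Lists the members assigned by an update factory of the form `x => new T { A = ..., B = ... }`.
    public static string[] AssignedMembers<T>(Expression<Func<T, T>> updateFactory)
    {
        // The object initializer is the root of the lambda body, as a MemberInitExpression.
        if (!(updateFactory.Body is MemberInitExpression init))
            throw new ArgumentException("Expected a body of the form `x => new T { ... }`.", nameof(updateFactory));

        // Each binding is one `Member = value` entry from the initializer, in source order.
        return init.Bindings.Select(b => b.Member.Name).ToArray();
    }
}
```

For the factory in the question, `f => new Foo { Field1 = "bar", Field2 = f.Field2 + 1 }`, this returns Field1 and Field2. Adding ModifiedDate therefore means building a new initializer with one more binding.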

Now a new requirement has come up: some tables (but not all) must track ModifiedDate, and that includes tables I'm doing bulk updates on.

So my approach is like this. I have an interface:

public interface ITrackModifiedDate
{
    DateTime ModifiedDate { get; set; }
}

So all my classes that track ModifiedDate can implement ITrackModifiedDate. Then, I write a middle-man extension to intercept the .UpdateAsync() calls:

public static async Task<int> UpdateAsync<T>(this IQueryable<T> queryable, Expression<Func<T, T>> updateFactory)
        where T : class
{
    if (typeof(ITrackModifiedDate).IsAssignableFrom(typeof(T)))
    {
        // TODO Now what?
    }

    return await BatchUpdateExtensions.UpdateAsync(queryable, updateFactory);
}

As you can see, I'm not entirely sure how to modify updateFactory to set ModifiedDate to DateTime.UtcNow, on top of whatever other fields are already being updated.

How to do it?

UPDATE: I'm not averse to changing my extension so that it only accepts T of type ITrackModifiedDate, if that helps, i.e.

public static async Task<int> UpdateAsync<T>(this IQueryable<T> queryable, Expression<Func<T, T>> updateFactory)
    where T : class, ITrackModifiedDate
{
    // TODO what now?

    return await BatchUpdateExtensions.UpdateAsync(queryable, updateFactory);
}
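With the constraint in place, one way to fill in the TODO without a visitor is to rebuild the lambda body directly. The sketch below makes several assumptions. ModifiedDateStamper and WithModifiedDate are hypothetical names. The timestamp is passed in as a parameter so the caller can supply DateTime.UtcNow. ModifiedDate is assumed to be implemented as a public property rather than explicitly. Any ModifiedDate the caller already set is replaced, not duplicated.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Same interface as in the question, repeated so the block compiles on its own.
public interface ITrackModifiedDate
{
    DateTime ModifiedDate { get; set; }
}

// Sample tracked entity, for illustration only.
public class Customer : ITrackModifiedDate
{
    public int ID { get; set; }
    public string Email { get; set; }
    public DateTime ModifiedDate { get; set; }
}

// Hypothetical helper, not part of Z.EntityFramework.Plus.
public static class ModifiedDateStamper
{
    // Returns a copy of updateFactory whose initializer also assigns ModifiedDate = stamp.
    public static Expression<Func<T, T>> WithModifiedDate<T>(Expression<Func<T, T>> updateFactory, DateTime stamp)
        where T : class, ITrackModifiedDate
    {
        if (!(updateFactory.Body is MemberInitExpression init))
            throw new ArgumentException("Expected a body of the form `x => new T { ... }`.", nameof(updateFactory));

        // Assumes ModifiedDate is a public property on T (an explicit implementation would return null here).
        var property = typeof(T).GetProperty(nameof(ITrackModifiedDate.ModifiedDate));

        // Keep the caller's assignments except an explicit ModifiedDate, then append ours.
        var bindings = init.Bindings
            .Where(b => b.Member.Name != property.Name)
            .Concat(new MemberBinding[] { Expression.Bind(property, Expression.Constant(stamp)) });

        // Rebuild the lambda over the original parameter so it still reads `x => new T { ... }`.
        return Expression.Lambda<Func<T, T>>(
            Expression.MemberInit(init.NewExpression, bindings),
            updateFactory.Parameters);
    }
}
```

Inside the constrained extension, the TODO could then become `updateFactory = ModifiedDateStamper.WithModifiedDate(updateFactory, DateTime.UtcNow);` before delegating to BatchUpdateExtensions.UpdateAsync.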

Solution

  • I got it working with the following code:

    using System;
    using System.Data.Entity;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Threading.Tasks;
    using Z.EntityFramework.Plus; 
    
    class Program
    {
        static async Task Main(string[] args)
        {
            using (var context = new SomeContext())
            {
                await context
                    .Customers
                    .Where(c => c.Email.Contains("42"))
                    .CustomUpdateAsync((c) => new Customer()
                    {
                        Email = "4242"
                    });
            }
        }
    
    }
    
    public static class Helper
    {
        public static async Task<int> CustomUpdateAsync<T>(this IQueryable<T> queryable, Expression<Func<T, T>> updateFactory)
            where T : class
        {
            var targetType = typeof(T);
            if (typeof(ITrackModifiedDate).IsAssignableFrom(targetType))
            {
                updateFactory = (Expression<Func<T, T>>)new TrackModifiedDateVisitor().Modify(updateFactory);
            }
    
            return await BatchUpdateExtensions.UpdateAsync(queryable, updateFactory);
        }
    }
    
    
    public class TrackModifiedDateVisitor : ExpressionVisitor
    {
        public Expression Modify(Expression expression)
        {
            return Visit(expression);
        }
    
        public override Expression Visit(Expression node)
        {
            if (node is MemberInitExpression initExpression)
            {
                var modifiedProperty = initExpression.NewExpression.Type.GetProperty(nameof(ITrackModifiedDate.ModifiedDate));
    
                // keep the existing assignments, dropping any explicit ModifiedDate so the column is not set twice
                var bindings = initExpression.Bindings
                    .Where(b => b.Member.Name != modifiedProperty.Name)
                    .ToList();
    
                // it will be `some.ModifiedDate = <current UTC time>`
                bindings.Add(Expression.Bind(
                    modifiedProperty,
                    Expression.Constant(DateTime.UtcNow, typeof(DateTime))
                    ));
    
                // generate a new MemberInit expression with the additional property assignment;
                // it is returned as-is, so nested initializers are not stamped as well
                return Expression.MemberInit(initExpression.NewExpression, bindings);
            }
    
            return base.Visit(node);
        }
    }
    
    
    public class SomeContext: DbContext
    {
        public SomeContext()
            : base("Data Source=.;Initial Catalog=TestDb;Integrated Security=SSPI;")
        {
            Database.SetInitializer(new CreateDatabaseIfNotExists<SomeContext>());
        }
    
        public DbSet<Customer> Customers { get; set; }
    }
    
    public class Customer: ITrackModifiedDate
    {
        public int ID { get; set; }
        public string Email { get; set; }
        public DateTime ModifiedDate { get; set; }
    }
    
    public interface ITrackModifiedDate
    {
        DateTime ModifiedDate { get; set; }
    }
    

    The key part is the TrackModifiedDateVisitor class, which traverses the updateFactory expression and, when it finds the MemberInitExpression, updates it. That expression initially holds the list of property assignments; we generate a new assignment for ModifiedDate and create a new MemberInitExpression containing the existing assignments plus the generated one.

    As a result, after the visitor runs, updateFactory becomes:

    c => new Customer() {Email = "4242", ModifiedDate = 5/16/2019 23:19:00}
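Because the rewrite is plain expression-tree work, it can be checked without a database by compiling the rewritten lambda and invoking it in memory. The sketch below is a variant under stated assumptions. StampingVisitor and its constructor parameter are hypothetical names. It overrides VisitMemberInit instead of Visit and only stamps initializers whose type implements ITrackModifiedDate, so a nested initializer of some other type is left alone. The timestamp is injected so the check is deterministic.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Same interface and sample entity as above, repeated so the block compiles on its own.
public interface ITrackModifiedDate
{
    DateTime ModifiedDate { get; set; }
}

public class Customer : ITrackModifiedDate
{
    public int ID { get; set; }
    public string Email { get; set; }
    public DateTime ModifiedDate { get; set; }
}

// Hypothetical variant of TrackModifiedDateVisitor.
public class StampingVisitor : ExpressionVisitor
{
    private readonly DateTime _stamp;

    public StampingVisitor(DateTime stamp) => _stamp = stamp;

    protected override Expression VisitMemberInit(MemberInitExpression node)
    {
        // Visit children first (any nested initializers), then decide about this node.
        var visited = (MemberInitExpression)base.VisitMemberInit(node);
        if (!typeof(ITrackModifiedDate).IsAssignableFrom(visited.Type))
            return visited;

        // Replace any explicit ModifiedDate assignment with the injected timestamp.
        var property = visited.Type.GetProperty(nameof(ITrackModifiedDate.ModifiedDate));
        var bindings = visited.Bindings
            .Where(b => b.Member.Name != property.Name)
            .Concat(new MemberBinding[] { Expression.Bind(property, Expression.Constant(_stamp)) });

        return Expression.MemberInit(visited.NewExpression, bindings);
    }
}
```

For example, `((Expression<Func<Customer, Customer>>)new StampingVisitor(stamp).Visit(factory)).Compile()(new Customer())` yields a Customer with both Email and ModifiedDate set. This only validates the expression itself. The SQL that is actually issued still depends on how Z.EntityFramework.Plus translates it.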