Search code examples
linqexpression-treesentity-framework-6interceptor

Intercept all EF6 Linq queries


I have a function that I want to run on every LINQ query executed against a DbContext, to modify the expression tree before execution. I've looked at the IDbCommandTreeInterceptor interface, but it doesn't seem to provide an expression tree (which I suppose is understandable, since by the time a query reaches that point it may no longer be a LINQ query).

Is there any way I can intercept and modify all expressions before execution?

Note: this has to be done by modifying the LINQ expression tree, because I have already built a framework for modifying LINQ expression trees (originally written for LINQ to SQL).


Solution

  • Creating a proxy for the LINQ provider that intercepts every LINQ expression execution (as suggested in the comments) is still a good solution. In fact, I'm experimenting with this in this project (a LINQ proxy library; the code below uses NeinLinq), which explicitly supports EF6, including EF6 async queries. You can write a standard .NET ExpressionVisitor and apply it to intercept the query:

    // Rewrite comes from the LINQ proxy library (NeinLinq in the code below); MyInterceptor is
    // presumably the ExpressionVisitor that modifies each query's expression tree — confirm.
    intercepted = query.Rewrite(new MyInterceptor());
    

    But the question also asks "to run on every executed Linq query on a DbContext", and that is the tricky part. One approach is an abstraction over DbContext / DbSet, so that your code never accesses DbSet objects directly; the interception can then happen inside the implementation of that abstraction.

    Another approach (and, I think, the best answer to this question) is a proxy for DbSet that routes queries through the LINQ proxy, which is what enables interception. First, we inherit from DbSet:

    /// <summary>
    /// A DbSet that forwards non-query operations to a wrapped set and serves LINQ queries
    /// from an intercepted (rewritten) copy of the query it represents.
    /// </summary>
    /// <typeparam name="TEntity">The entity type of the wrapped set.</typeparam>
    public class DbSetProxy<TEntity> : DbSet<TEntity>,
                                       IQueryable<TEntity>,
                                       IDbAsyncEnumerable<TEntity>
        where TEntity : class
    {
        // The real set; target of the forwarded, non-query members.
        private readonly DbSet<TEntity> set;

        // The query this proxy represents: the set itself for a root proxy, or a query
        // derived from it (e.g. by AsNoTracking / Include) for proxies built from those.
        private readonly DbQuery<TEntity> query;

        // The query after being passed through the LINQ proxy library with MyInterceptor.
        private readonly IQueryable<TEntity> intercepted;

        /// <summary>Creates a proxy whose query is the wrapped set itself.</summary>
        /// <param name="set">The set to wrap.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="set"/> is null.</exception>
        public DbSetProxy(DbSet<TEntity> set)
            : this(set, set)
        {
        }

        /// <summary>Creates a proxy over <paramref name="set"/> that represents <paramref name="query"/>.</summary>
        /// <param name="set">The set that non-query operations are forwarded to.</param>
        /// <param name="query">The query to intercept; rewritten once, here.</param>
        /// <exception cref="System.ArgumentNullException"><paramref name="set"/> or <paramref name="query"/> is null.</exception>
        public DbSetProxy(DbSet<TEntity> set, DbQuery<TEntity> query)
        {
            // Fail fast instead of leaving a proxy that throws NullReferenceException on first use.
            if (set == null)
            {
                throw new System.ArgumentNullException("set");
            }

            if (query == null)
            {
                throw new System.ArgumentNullException("query");
            }

            this.set = set;
            this.query = query;

            // use NeinLinq or any other LINQ proxy library
            intercepted = query.Rewrite(new MyInterceptor());
        }
    }
    

    Then, we override all members so that non-query operations are forwarded to the actual DbSet:

    (Note: unfortunately, every DbSet member has to be overridden, because deriving from DbSet is only designed for test doubles: members the derived class does not override are not backed by a real set, so simply inheriting from DbSet leaves it broken.)

    /// <summary>
    /// Returns a new proxy over the same set whose query is the no-tracking form of this
    /// proxy's query; the new proxy's constructor rewrites that query for interception.
    /// </summary>
    public override DbQuery<TEntity> AsNoTracking()
    {
        DbQuery<TEntity> noTracking = this.query.AsNoTracking();
        return new DbSetProxy<TEntity>(this.set, noTracking);
    }
    
    /// <summary>
    /// Returns a new proxy over the same set whose query is the streaming form of this
    /// proxy's query.
    /// </summary>
    /// <remarks>NOTE(review): AsStreaming may be marked obsolete in later EF6 releases — confirm for the EF version in use.</remarks>
    public override DbQuery<TEntity> AsStreaming()
    {
        DbQuery<TEntity> streaming = this.query.AsStreaming();
        return new DbSetProxy<TEntity>(this.set, streaming);
    }
    
    /// <summary>
    /// Returns a new proxy over the same set whose query is this proxy's query with
    /// <c>Include(path)</c> applied.
    /// </summary>
    /// <param name="path">The include path, passed unchanged to the wrapped query.</param>
    public override DbQuery<TEntity> Include(string path)
    {
        DbQuery<TEntity> withInclude = this.query.Include(path);
        return new DbSetProxy<TEntity>(this.set, withInclude);
    }
    
    /// <summary>Forwards to the wrapped set's <c>Add</c> and returns its result.</summary>
    public override TEntity Add(TEntity entity)
    {
        TEntity added = this.set.Add(entity);
        return added;
    }
    
    /// <summary>Forwards to the wrapped set's <c>AddRange</c> and returns its result.</summary>
    public override IEnumerable<TEntity> AddRange(IEnumerable<TEntity> entities)
    {
        IEnumerable<TEntity> added = this.set.AddRange(entities);
        return added;
    }
    
    /// <summary>Forwards to the wrapped set's <c>Attach</c> and returns its result.</summary>
    public override TEntity Attach(TEntity entity)
    {
        TEntity attached = this.set.Attach(entity);
        return attached;
    }
    
    /// <summary>Forwards to the wrapped set's <c>Create</c> and returns the new instance.</summary>
    public override TEntity Create()
    {
        TEntity created = this.set.Create();
        return created;
    }
    
    /// <summary>Forwards to the wrapped set's generic <c>Create</c> and returns the new instance.</summary>
    /// <typeparam name="TDerivedEntity">The entity type to create, as accepted by the wrapped set.</typeparam>
    public override TDerivedEntity Create<TDerivedEntity>()
    {
        TDerivedEntity created = this.set.Create<TDerivedEntity>();
        return created;
    }
    
    /// <summary>Forwards to the wrapped set's <c>Find</c> with the same key values.</summary>
    public override TEntity Find(params object[] keyValues)
    {
        TEntity match = this.set.Find(keyValues);
        return match;
    }
    
    /// <summary>Forwards to the wrapped set's <c>FindAsync</c> and returns its task unchanged.</summary>
    public override Task<TEntity> FindAsync(params object[] keyValues)
    {
        Task<TEntity> pending = this.set.FindAsync(keyValues);
        return pending;
    }
    
    /// <summary>
    /// Forwards to the wrapped set's cancellable <c>FindAsync</c> and returns its task unchanged.
    /// </summary>
    public override Task<TEntity> FindAsync(CancellationToken cancellationToken, params object[] keyValues)
    {
        Task<TEntity> pending = this.set.FindAsync(cancellationToken, keyValues);
        return pending;
    }
    
    /// <summary>Forwards to the wrapped set's <c>Remove</c> and returns its result.</summary>
    public override TEntity Remove(TEntity entity)
    {
        TEntity removed = this.set.Remove(entity);
        return removed;
    }
    
    /// <summary>Forwards to the wrapped set's <c>RemoveRange</c> and returns its result.</summary>
    public override IEnumerable<TEntity> RemoveRange(IEnumerable<TEntity> entities)
    {
        IEnumerable<TEntity> removed = this.set.RemoveRange(entities);
        return removed;
    }
    
    /// <summary>
    /// Forwards a raw SQL query to the wrapped set. The result does not go through the
    /// LINQ interception, because it is not built from the intercepted query.
    /// </summary>
    /// <param name="sql">The SQL text; pass user input through <paramref name="parameters"/>, never by concatenation.</param>
    /// <param name="parameters">Values passed unchanged to the wrapped set.</param>
    public override DbSqlQuery<TEntity> SqlQuery(string sql, params object[] parameters)
    {
        DbSqlQuery<TEntity> rawQuery = this.set.SqlQuery(sql, parameters);
        return rawQuery;
    }
    
    /// <summary>Exposes the wrapped set's <c>Local</c> collection.</summary>
    public override ObservableCollection<TEntity> Local
    {
        get
        {
            ObservableCollection<TEntity> local = this.set.Local;
            return local;
        }
    }
    
    /// <summary>
    /// Determines whether <paramref name="obj"/> is this proxy, or another proxy that wraps
    /// the same set and represents the same query.
    /// </summary>
    /// <param name="obj">The object to compare with.</param>
    /// <returns><c>true</c> if the two represent the same set and query; otherwise <c>false</c>.</returns>
    public override bool Equals(object obj)
    {
        // Not delegated to set.Equals(obj): the wrapped set is a different object from this
        // proxy, so (presumably, with reference equality) a proxy would be unequal to itself,
        // and proxy.Equals(set) would disagree with set.Equals(proxy).
        if (ReferenceEquals(this, obj))
        {
            return true;
        }

        DbSetProxy<TEntity> other = obj as DbSetProxy<TEntity>;
        if (other == null)
        {
            return false;
        }

        // Equal proxies share the same set, so this stays consistent with GetHashCode,
        // which hashes the wrapped set.
        return set.Equals(other.set) && query.Equals(other.query);
    }
    
    /// <summary>Returns the hash code of the wrapped set, so proxies over the same set hash alike.</summary>
    public override int GetHashCode()
    {
        int hash = this.set.GetHashCode();
        return hash;
    }
    
    /// <summary>Returns the string representation of the query this proxy represents.</summary>
    public override string ToString()
    {
        // Describe the query, not the set. For proxies created by Include / AsNoTracking /
        // AsStreaming, the query differs from the set, and set.ToString() would describe the
        // unmodified set instead. For a root proxy, the query is the set itself, so the
        // result is unchanged.
        return query.ToString();
    }
    

    Finally, we make the query interfaces use the intercepted query:

    /// <summary>Enumerates the intercepted query instead of the base DbSet.</summary>
    IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
    {
        IEnumerator<TEntity> enumerator = this.intercepted.GetEnumerator();
        return enumerator;
    }
    
    /// <summary>Non-generic enumeration, also served by the intercepted query.</summary>
    IEnumerator IEnumerable.GetEnumerator()
    {
        // Uses the intercepted query's generic enumerator, exposed through the non-generic interface.
        IEnumerator<TEntity> enumerator = this.intercepted.GetEnumerator();
        return enumerator;
    }
    
    /// <summary>The element type reported by the intercepted query.</summary>
    Type IQueryable.ElementType
    {
        get
        {
            Type elementType = this.intercepted.ElementType;
            return elementType;
        }
    }
    
    /// <summary>The expression tree of the intercepted query, so LINQ operators compose on top of it.</summary>
    Expression IQueryable.Expression
    {
        get
        {
            Expression tree = this.intercepted.Expression;
            return tree;
        }
    }
    
    /// <summary>The intercepted query's provider; queries composed on this proxy execute through it.</summary>
    IQueryProvider IQueryable.Provider
    {
        get
        {
            IQueryProvider provider = this.intercepted.Provider;
            return provider;
        }
    }
    
    /// <summary>EF6 async enumeration, served by the intercepted query.</summary>
    IDbAsyncEnumerator<TEntity> IDbAsyncEnumerable<TEntity>.GetAsyncEnumerator()
    {
        // The direct cast throws InvalidCastException if the rewritten query does not
        // implement IDbAsyncEnumerable<TEntity>.
        var asyncSource = (IDbAsyncEnumerable<TEntity>)this.intercepted;
        return asyncSource.GetAsyncEnumerator();
    }
    
    /// <summary>Non-generic EF6 async enumeration, served by the intercepted query's generic enumerator.</summary>
    IDbAsyncEnumerator IDbAsyncEnumerable.GetAsyncEnumerator()
    {
        // Same cast as the generic implementation; InvalidCastException if unsupported.
        var asyncSource = (IDbAsyncEnumerable<TEntity>)this.intercepted;
        return asyncSource.GetAsyncEnumerator();
    }
    

    At last, we can use an ordinary DbContext; we only have to override its Set method to inject our proxy:

    /// <summary>
    /// Example context whose generic Set method hands out intercepting DbSetProxy instances
    /// instead of the sets produced by the base DbContext.
    /// </summary>
    public class MyContext : DbContext
    {
        /// <summary>The set of Entity instances exposed by this context.</summary>
        public DbSet<Entity> Entities { get; set; }

        /// <summary>
        /// Wraps the base context's set in a new DbSetProxy on every call.
        /// </summary>
        /// <remarks>
        /// NOTE(review): only the generic Set is overridden; the non-generic Set(Type) is not
        /// intercepted here — confirm it is not used.
        /// </remarks>
        public override DbSet<TEntity> Set<TEntity>()
        {
            DbSet<TEntity> realSet = base.Set<TEntity>();
            return new DbSetProxy<TEntity>(realSet);
        }
    }