Tags: c#, linq, entity-framework, linq-to-entities

Override LINQ .Include(), possible?


I've been searching the internet for this issue for quite some time, and I find very few results about overriding LINQ methods. I'm not sure it can be done, so I'm hoping someone can either confirm whether it works or suggest an alternative.

The situation is as follows (simplified for this question, of course):

We are using EF6 Code First to build our database. We have added a custom (abstract) base class from which ALL entities derive. This base class defines some fields we use for auditing (creation date, created by, modified date, ...), and it also implements a soft-delete system through an IsDeleted (bool) property.

As far as our application is aware, items with IsDeleted == true must never be returned.

The data model is as follows (again, simplified):

Company
    ---> 1 Company has many Departments
    Department
        ---> 1 Department has many Addresses
        Address
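As a rough sketch, the model above might look like this as EF6 Code First classes. The base class name `AuditableEntity`, the audit property names and the key/foreign-key properties are hypothetical; only `IsDeleted` and the three entities come from the question.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical shape of the abstract base class; the audit property names
// are illustrative, only IsDeleted is taken from the question.
public abstract class AuditableEntity
{
    public DateTime CreatedOn { get; set; }
    public string CreatedBy { get; set; }
    public DateTime? ModifiedOn { get; set; }
    public bool IsDeleted { get; set; }
}

public class Company : AuditableEntity
{
    public int CompanyId { get; set; }

    // 1 Company has many Departments
    public virtual ICollection<Department> Departments { get; set; } = new List<Department>();
}

public class Department : AuditableEntity
{
    public int DepartmentId { get; set; }
    public int CompanyId { get; set; }
    public virtual Company Company { get; set; }

    // 1 Department has many Addresses
    public virtual ICollection<Address> Addresses { get; set; } = new List<Address>();
}

public class Address : AuditableEntity
{
    public int AddressId { get; set; }
    public int DepartmentId { get; set; }
    public virtual Department Department { get; set; }
}
```

Note that nothing in these classes stops a deleted `Department` from sitting in `Company.Departments`; the soft-delete flag is just another column until a query filters on it.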

In the past, I tried to build a general retrieval method that filters out the IsDeleted objects by adding an "override" for each table in the DataContext (which is also a custom class, because it handles the audit fields automatically).

For every table in the DataContext, for example:

public DbSet<Company> Companies { get; set; }

we have added a second property that returns only non-deleted items:

public IQueryable<Company> _Companies 
{
    get { return this.Companies.Where(co => !co.IsDeleted); }
}

So we call MyDataContext._Companies instead of MyDataContext.Companies. This works as intended: it nicely filters out the deleted items.
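The per-table property can be generalised so that each `_Xxx` property does not repeat the lambda. A minimal sketch, assuming a hypothetical `AuditableEntity` base class that carries `IsDeleted`; `WhereNotDeleted` is an illustrative name, and it has only been exercised here against in-memory data through `AsQueryable()`, not against EF6 itself.

```csharp
using System.Linq;

// Hypothetical stand-in for the question's abstract base class;
// only the IsDeleted flag is modelled.
public abstract class AuditableEntity
{
    public bool IsDeleted { get; set; }
}

public class Company : AuditableEntity
{
    public string Name { get; set; }
}

public static class SoftDeleteQueryExtensions
{
    // Hypothetical helper: appends the soft-delete filter to any query whose
    // element type derives from AuditableEntity.
    public static IQueryable<T> WhereNotDeleted<T>(this IQueryable<T> source)
        where T : AuditableEntity
    {
        return source.Where(e => !e.IsDeleted);
    }
}
```

The property from the question would then read `public IQueryable<Company> _Companies { get { return Companies.WhereNotDeleted(); } }`. Like the original property, this filters only the root set; it does nothing for collections loaded through Include.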

However, we noticed that the same is not true for a subsequent .Include() call. If I call:

var companies = MyDataContext._Companies.Include(x => x.Departments);

//...

The deleted Departments from a Company are returned as well.

In our current situation, most of the core business logic has already been implemented, and these Include statements are all over the place. They mostly relate to security. I could change every statement, but I'd rather first look for a way to do this without impacting the existing code too much.
This is the first application where the size of the queries does not allow us to load every set of entities separately (that is, by querying only the direct tables instead of using Include statements).

So my question is twofold:

  • Can I override the .Include(Expression<Func<T, TProperty>>) method to automatically add a check on the IsDeleted flag for the selected entity?
  • If overriding is possible, how do I combine the passed lambda expression with the additional check I want to execute?

So by calling

someTable.Include(x => x.MyRelatedEntity);

It would actually execute:

/* Not sure if this is proper syntax. I hope it explains what I'm trying to accomplish. */
someTable.Include(x => x.MyRelatedEntity.Where(y => !y.IsDeleted));
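For the second part of the question, the passed lambda can be combined with the extra check through the `System.Linq.Expressions` API: take the body of `x => x.Children` and wrap it in a call to `Enumerable.Where`. This is only a sketch of the expression plumbing. EF6's `Include` will not accept the result, because it only takes navigation-property paths (filtered includes arrived later, in EF Core 5). `NavigationFilter`, `ExcludeDeleted` and the entity classes are hypothetical names.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical base class and entities, only for illustration.
public abstract class AuditableEntity
{
    public bool IsDeleted { get; set; }
}

public class Department : AuditableEntity
{
    public string Name { get; set; }
}

public class Company : AuditableEntity
{
    public List<Department> Departments { get; set; } = new List<Department>();
}

public static class NavigationFilter
{
    // Turns x => x.Children into x => x.Children.Where(y => !y.IsDeleted)
    // by wrapping the selector's body in a call to Enumerable.Where.
    public static Expression<Func<TSource, IEnumerable<TChild>>> ExcludeDeleted<TSource, TChild>(
        Expression<Func<TSource, IEnumerable<TChild>>> selector)
        where TChild : AuditableEntity
    {
        // Build y => !y.IsDeleted
        ParameterExpression y = Expression.Parameter(typeof(TChild), "y");
        Expression<Func<TChild, bool>> notDeleted = Expression.Lambda<Func<TChild, bool>>(
            Expression.Not(Expression.Property(y, nameof(AuditableEntity.IsDeleted))), y);

        // Build Enumerable.Where<TChild>(<selector body>, y => !y.IsDeleted)
        MethodCallExpression filtered = Expression.Call(
            typeof(Enumerable), nameof(Enumerable.Where), new[] { typeof(TChild) },
            selector.Body, notDeleted);

        // Reuse the original parameter (x) so the new body still refers to it.
        return Expression.Lambda<Func<TSource, IEnumerable<TChild>>>(filtered, selector.Parameters);
    }
}
```

The composed expression can be compiled for in-memory use, or (with care) placed inside a `Select` projection, which EF6 can translate; it cannot be passed to EF6's `Include`.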

Could anyone point me in the right direction? Much appreciated!

Note: I know there's not much code in my question, but I'm not even sure at what level I can implement this. If Include can't be overridden, is there another way?

Update

I implemented the suggested solution, but I am running into an issue with every database call. The error is as follows:

Problem in mapping fragments starting at line 245:Condition member 'Company.IsDeleted' with a condition other than 'IsNull=False' is mapped. Either remove the condition on Company.IsDeleted or remove it from the mapping.

From what I have read, once IsDeleted is used as a mapping condition (as in the suggested solution), it can no longer also be mapped as a property.

The problem then becomes: how do I delete something? Once an item is deleted, it should never be returned, but it must still be possible to delete a non-deleted item.

Is there any way to filter the returned items by IsDeleted while still being allowed to set it to true and save it?


Solution

  • The solution you're looking for is to require that entities have an IsDeleted value of false:

    modelBuilder.Entity<Company>()
        .Map( emc => emc.Requires( "IsDeleted" ).HasValue( false ) );
    

    With this mapping, only companies with IsDeleted == false will be retrieved from the DB.

    Update from comment:

    modelBuilder.Entity<Company>()
        .Map( emc => 
        {
            emc.MapInheritedProperties();
            emc.Requires( "IsDeleted" ).HasValue( false );
        } )
        .Ignore( c => c.IsDeleted );
    

    Update: test code which was successful (helper methods found here):

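    // Usings for this file (they also cover Program below if it lives in the same file)
    using System;
    using System.Collections.Generic;
    using System.ComponentModel.DataAnnotations.Schema;
    using System.Data.Entity;
    using System.Data.Entity.Core.Metadata.Edm;
    using System.Data.Entity.Core.Objects;
    using System.Data.Entity.Infrastructure;
    using System.Data.SqlClient;
    using System.Linq;
    
    [Table("EntityA")]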
    public partial class EntityA
    {
        public int EntityAId { get; set; }
        public string Description { get; set; }
    
    
        public virtual EntityB PrimaryEntityB { get; set; }
    
        public virtual EntityB AlternativeEntityB { get; set; }
    
        public bool IsDeleted { get; set; }
    }
    
    [Table("EntityB")]
    public partial class EntityB
    {
        public int EntityBId { get; set; }
        public string Description { get; set; }
    
        [InverseProperty("PrimaryEntityB")]
        public virtual ICollection<EntityA> EntityAsViaPrimary { get; set; }
        [InverseProperty( "AlternativeEntityB" )]
        public virtual ICollection<EntityA> EntityAsViaAlternative { get; set; }
    }
    
    public partial class TestEntities : DbContext
    {
        public TestEntities()
            : base("TestEntities")
        {
            // DatabaseInitializer is an initializer class that is not shown in this answer.
            Database.SetInitializer( new DatabaseInitializer() );
        }
    
        protected override void OnModelCreating(DbModelBuilder modelBuilder)
        {
            modelBuilder.Entity<EntityA>()
                .Map( emc =>
                {
                    emc.Requires( "IsDeleted" ).HasValue( false );
                } )
                .Ignore( a => a.IsDeleted );
        }
    
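        // Turn hard deletes of EntityA into soft deletes before saving.
        public override int SaveChanges()
        {
            // ToList() takes a snapshot, because SoftDelete detaches entries.
            foreach( var entry in this.ChangeTracker.Entries<EntityA>().ToList() )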
            {
                if( entry.State == EntityState.Deleted )
                {
                    SoftDelete( entry );
                }
            }
    
            return base.SaveChanges();
        }
    
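        // Soft-deletes the row with a raw UPDATE (IsDeleted is not mapped as a
        // property, so EF cannot set it) and then stops tracking the entity.
        private void SoftDelete( DbEntityEntry entry )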
        {
            var entityType = entry.Entity.GetType();
    
            var tableName = GetTableName( entityType );
            var pkName = GetPrimaryKeyName( entityType );
    
            var deleteSql = string.Format( "update {0} set IsDeleted = 1 where {1} = @id",
                tableName,
                pkName );
    
            Database.ExecuteSqlCommand( deleteSql, new SqlParameter( "@id", entry.OriginalValues[ pkName ] ) );
    
            entry.State = EntityState.Detached;
        }
    
        private string GetPrimaryKeyName( Type type )
        {
            return GetEntitySet( type ).ElementType.KeyMembers[ 0 ].Name;
        }
    
        private string GetTableName( Type type )
        {
            EntitySetBase es = GetEntitySet( type );
    
            return string.Format( "[{0}].[{1}]",
                es.MetadataProperties[ "Schema" ].Value,
                es.MetadataProperties[ "Table" ].Value );
        }
        private EntitySetBase GetEntitySet( Type type )
        {
            ObjectContext octx = ( ( IObjectContextAdapter )this ).ObjectContext;
    
            string typeName = ObjectContext.GetObjectType( type ).Name;
    
            var es = octx.MetadataWorkspace
                            .GetItemCollection( DataSpace.SSpace )
                            .GetItems<EntityContainer>()
                            .SelectMany( c => c.BaseEntitySets
                                            .Where( e => e.Name == typeName ) )
                            .FirstOrDefault();
    
            if( es == null )
                throw new ArgumentException( "Entity type '" + typeName + "' not found in GetEntitySet", "type" );
    
            return es;
        }
    
        public DbSet<EntityA> EntityAs { get; set; }
        public DbSet<EntityB> EntityBs { get; set; }
    }
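The SQL that `SoftDelete` sends can be pulled into a small pure function so it can be checked without a database. A sketch, assuming SQL Server bracket quoting; `SoftDeleteSql`, `Build` and `QuoteIdentifier` are hypothetical names. Unlike the original, it also escapes `]` inside names and brackets the key column; the key value is still sent separately as the `@id` parameter rather than concatenated into the string.

```csharp
using System;

public static class SoftDeleteSql
{
    // Wraps one SQL Server identifier in brackets, doubling any "]" inside it.
    public static string QuoteIdentifier(string name)
    {
        if (string.IsNullOrEmpty(name))
            throw new ArgumentException("Identifier must not be empty.", nameof(name));
        return "[" + name.Replace("]", "]]") + "]";
    }

    // Builds the same kind of UPDATE that SoftDelete executes; the caller
    // supplies the key value as the @id parameter.
    public static string Build(string schema, string table, string keyColumn)
    {
        return string.Format("update {0}.{1} set IsDeleted = 1 where {2} = @id",
            QuoteIdentifier(schema), QuoteIdentifier(table), QuoteIdentifier(keyColumn));
    }
}
```

In `SoftDelete`, the schema and table would come from the same metadata that `GetTableName` reads, and the key column from `GetPrimaryKeyName`.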
    

    Application code:

    class Program
    {
        static void Main(string[] args)
        {
            using( var db = new TestEntities() )
            {
                var a0 = new EntityA()
                    {
                        EntityAId = 1,
                        Description = "hi"
                    };
    
                var a1 = new EntityA()
                    {
                        EntityAId = 2,
                        Description = "bye"
                    };
    
                db.EntityAs.Add( a0 );
                db.EntityAs.Add( a1 );
    
                var b = new EntityB()
                {
                    EntityBId = 1,
                    Description = "Big B"
                };
    
                a1.PrimaryEntityB = b;
    
                db.SaveChanges();
    
                // this prints "1"
                Console.WriteLine( b.EntityAsViaPrimary.Count() );
    
                db.EntityAs.Remove( a1 );
    
                db.SaveChanges();
    
                // this prints "0"
                Console.WriteLine( b.EntityAsViaPrimary.Count() );
            }
    
            var input = Console.ReadLine();
        }
    }