c# linq entity-framework linq-to-entities

Pass linq to entities query as parameter and execute against different objects


I'm looking to implement a cache so that I can check whether specific queries have already been executed and, if they have, return the data from an in-memory cache. Rather than implement the cache myself, I can use the in-memory view that Entity Framework exposes through the Local property of the DbSet.

So I want to pass the entire query to a method that will run it against the database and/or the in-memory cache. However, I'm getting my Expression/Func/IQueryable mixed up and can't work out how to pass the query/expression.

The base class would look something like this (it obviously doesn't work yet):

public abstract class ServiceBase<TEntity> where TEntity : class
{
    Entities _context; //instantiated in the constructor

    protected List<TEntity> GetData<TEntity>(Expression<Func<TEntity, TEntity>> query)
    {
        if (!HasQueryExecuted(query))
        {
            _context.Set<TEntity>().Select(query).Load();
            AddQueryToExecutedList(query); 
        }
        return _context.Set<TEntity>().Local.AsQueryable().Select(query).ToList();
    }
}

And I would have multiple classes that define their queries something like this:

public class CustomersService : ServiceBase<Customer>
{
    public List<Customer> GetCustomersWithOrders()
    {
        return GetData(c => c.Where(x => x.OrderCount > 0));
    }

    public List<Customer> GetLargestCustomersByOrder(int topCount)
    {
        return GetData(c => c.OrderByDescending(x => x.OrderCount).Take(topCount));
    }
}
public class ProductsService : ServiceBase<Product>
{
    public List<Product> GetAllProducts()
    {
        return GetData(p => p);
    }

    public List<Product> GetMostOrderedProducts(int minimumOrders)
    {
        return GetData(p => p.Where(x => x.OrderCount > minimumOrders)
                   .OrderByDescending(x => x.OrderCount));
    }
}

These are just contrived examples, but the point is that the queries could vary: they are not limited to a Where clause and could use any of the standard IQueryable extension methods to query the underlying data.

First of all, is this possible? Can the GetData method define a parameter that accepts any query? In the example above, I declared the query parameter with the same type that the IQueryable Select extension method takes, but when I try to call it with a simple Where or OrderBy clause, I get various compilation errors, such as:

Cannot implicitly convert type 'bool' to 'Entities.Customer'

or

Cannot implicitly convert type 'System.Linq.IOrderedEnumerable' to 'Entities.Customer'

And yet, it compiles just fine if I run the same query directly against the context. So what am I doing wrong?


Solution

  • The problem is the parameter type: Expression<Func<TEntity, TEntity>> query

    Simply put, you are saying that your expression is a function that takes a TEntity and returns a TEntity. Since CustomersService derives from ServiceBase<Customer>, it expects a function that takes an Entities.Customer and returns an Entities.Customer.

    If you look at your service classes, p => p works fine, but p => p.Where(...) does not: p is a single entity, and Where returns a sequence (IEnumerable<T> or IQueryable<T>), not a TEntity.

    I think you should take a different approach and take advantage of deferred execution:

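    // In ServiceBase<TEntity>: reuse the class's type parameter instead of redeclaring it.
    protected IQueryable<TEntity> GetData()
    {
        // Returns the DbSet as an IQueryable; nothing is sent to the database yet.
        return _context.Set<TEntity>();
    }
    
    // In ProductsService : ServiceBase<Product>
    public IQueryable<Product> GetAllProducts()
    {
        return GetData();
    }
    
    public IQueryable<Product> GetMostOrderedProducts(int minimumOrders)
    {
        // Still deferred: Where and OrderByDescending become one SQL query when enumerated.
        return GetData().Where(p => p.OrderCount > minimumOrders)
            .OrderByDescending(p => p.OrderCount);
    }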
    

    You don't need to build the query into GetData, because the query can be composed at any point until it is enumerated (for example by ToList() or a foreach loop). If you really wanted to, you could return List<T> from here, but you shouldn't need to.

    Otherwise, this should work with your current design:

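    protected List<TEntity> GetData(Func<IQueryable<TEntity>, IQueryable<TEntity>> query)
    {
        if (!HasQueryExecuted(query))
        {
            // First run: the query is translated to SQL and executed against the database.
            // EF tracks the returned entities, so they also appear in Local afterwards.
            var result = query(_context.Set<TEntity>()).ToList();
            AddQueryToExecutedList(query);
            return result;
        }
        // Later runs: replay the same query in memory over Local (LINQ to Objects).
        return query(_context.Set<TEntity>().Local.AsQueryable()).ToList();
    }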