Tags: c#, linq, entity-framework-core, linq-expressions

How to use expressions to build a LINQ query dynamically when using an interface to get the column name?


I'm using Entity Framework Core to store and retrieve some data. I'm trying to write a general-purpose method that will work on any DbSet<T> in order to avoid code duplication. This method runs a LINQ query against the set, for which it needs to know the "key" column (i.e. the primary key of the table).

To help with this I've defined an interface that returns the name of the property that represents the key column. Entities then implement this interface. Hence I have something like this:

interface IEntityWithKey
{
    string KeyPropertyName { get; }
}

class FooEntity : IEntityWithKey
{
    [Key] public string FooId { get; set; }
    [NotMapped] public string KeyPropertyName => nameof(FooId);
}

class BarEntity : IEntityWithKey
{
    [Key] public string BarId { get; set; }
    [NotMapped] public string KeyPropertyName => nameof(BarId);
}
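Because KeyPropertyName is declared as an instance property, generic code can only read it from an object, not from a Type. The sketch below is one hypothetical way to keep the interface: it assumes every entity has a public parameterless constructor and adds a new() constraint (neither is in the question's signature) so it can create a throwaway instance and read the name. It runs in memory without EF Core; GetKeyPropertyName and KeyLookup are illustrative names, not part of any library.

```csharp
using System;
using System.Linq.Expressions;

interface IEntityWithKey
{
    string KeyPropertyName { get; }
}

class FooEntity : IEntityWithKey
{
    public string FooId { get; set; }
    public string KeyPropertyName => nameof(FooId);
}

static class KeyLookup
{
    // Hypothetical helper: the new() constraint lets us create a throwaway
    // instance purely to read the instance-level KeyPropertyName
    internal static string GetKeyPropertyName<TEntity>()
        where TEntity : IEntityWithKey, new()
        => new TEntity().KeyPropertyName;

    // Builds x => x.<KeyPropertyName>, e.g. x => x.FooId for FooEntity
    internal static Expression<Func<TEntity, TKey>> BuildSelectExpression<TEntity, TKey>()
        where TEntity : IEntityWithKey, new()
    {
        ParameterExpression parameter = Expression.Parameter(typeof(TEntity), "x");
        // Expression.Property(Expression, string) looks the property up by name
        MemberExpression key = Expression.Property(parameter, GetKeyPropertyName<TEntity>());
        return Expression.Lambda<Func<TEntity, TKey>>(key, parameter);
    }
}

static class Program
{
    static void Main()
    {
        var select = KeyLookup.BuildSelectExpression<FooEntity, string>();
        Console.WriteLine(select.Compile()(new FooEntity { FooId = "foo2" })); // prints foo2
    }
}
```

The trade-off is the extra constraint and an allocation per call; reading a [Key] attribute via reflection avoids both.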

The method I'm trying to write has this signature:

static List<TKey> GetMatchingKeys<TEntity, TKey>(DbSet<TEntity> dbSet, List<TKey> keysToFind)
    where TEntity : class, IEntityWithKey

Basically, given a DbSet containing entities of type TEntity, and a list of keys of type TKey, the method should return a list of keys that currently exist in the related table in the database.

The query looks like this:

dbSet.Where(BuildWhereExpression()).Select(BuildSelectExpression()).ToList()

In BuildWhereExpression I'm trying to create an appropriate Expression<Func<TEntity, bool>>, and in BuildSelectExpression I'm trying to create an appropriate Expression<Func<TEntity, TKey>>. However, I'm struggling even with the Select() expression, which is the easier of the two. Here's what I have so far:

Expression<Func<TEntity, TKey>> BuildSelectExpression()
{
    // for a FooEntity, would be:  x => x.FooId
    // for a BarEntity, would be:  x => x.BarId

    ParameterExpression parameter = Expression.Parameter(typeof(TEntity));
    MemberExpression property1 = Expression.Property(parameter, nameof(IEntityWithKey.KeyPropertyName));
    MemberExpression property2 = Expression.Property(parameter, property1.Member as PropertyInfo);
    UnaryExpression result = Expression.Convert(property2, typeof(TKey));
    return Expression.Lambda<Func<TEntity, TKey>>(result, parameter);
}

This runs, and the query sent to the database looks correct, but all I get back is the name of the key property, once per row. For example, called like this:

List<string> keys = GetMatchingKeys(context.Foos, new List<string> { "foo3", "foo2" });

It generates this query, which looks good (note: no Where() implementation yet):

SELECT "f"."FooId"
FROM "Foos" AS "f"

But the query just returns a list containing "FooId" rather than the actual IDs stored in the database.
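One way to see what the lambda above actually contains is to inspect it without a database. This sketch reproduces the question's BuildSelectExpression with TEntity fixed to FooEntity and TKey to string (a simplification for illustration) and prints the member the body accesses. It shows that property1.Member is the PropertyInfo of KeyPropertyName, so property2 reads KeyPropertyName a second time instead of FooId.

```csharp
using System;
using System.Linq.Expressions;
using System.Reflection;

interface IEntityWithKey
{
    string KeyPropertyName { get; }
}

class FooEntity : IEntityWithKey
{
    public string FooId { get; set; }
    public string KeyPropertyName => nameof(FooId);
}

static class Program
{
    // Same construction as in the question, specialised to FooEntity/string
    internal static Expression<Func<FooEntity, string>> BuildSelectExpression()
    {
        ParameterExpression parameter = Expression.Parameter(typeof(FooEntity), "x");
        MemberExpression property1 = Expression.Property(parameter, nameof(IEntityWithKey.KeyPropertyName));
        // property1.Member is KeyPropertyName itself, so this is x.KeyPropertyName again
        MemberExpression property2 = Expression.Property(parameter, property1.Member as PropertyInfo);
        UnaryExpression result = Expression.Convert(property2, typeof(string));
        return Expression.Lambda<Func<FooEntity, string>>(result, parameter);
    }

    static void Main()
    {
        var lambda = BuildSelectExpression();
        // Unwrap the Convert node to reach the member access
        var body = (MemberExpression)((UnaryExpression)lambda.Body).Operand;
        Console.WriteLine(body.Member.Name); // prints KeyPropertyName
        Console.WriteLine(lambda.Compile()(new FooEntity { FooId = "foo2" })); // prints FooId
    }
}
```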

I feel like I am close to a solution but I'm just going around in circles a bit with the expression stuff, having not done much of it before. If anyone can help with the Select() expression that would be a start.

Here is the full code:

using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.ComponentModel.DataAnnotations.Schema;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace StackOverflow
{
    interface IEntityWithKey
    {
        string KeyPropertyName { get; }
    }

    class FooEntity : IEntityWithKey
    {
        [Key] public string FooId { get; set; }
        [NotMapped] public string KeyPropertyName => nameof(FooId);
    }

    class BarEntity : IEntityWithKey
    {
        [Key] public string BarId { get; set; }
        [NotMapped] public string KeyPropertyName => nameof(BarId);
    }

    class TestContext : DbContext
    {
        public TestContext(DbContextOptions options) : base(options) { }
        public DbSet<FooEntity> Foos { get; set; }
        public DbSet<BarEntity> Bars { get; set; }
    }

    class Program
    {
        static async Task Main()
        {
            IServiceCollection services = new ServiceCollection();
            services.AddDbContext<TestContext>(
                options => options.UseSqlite("Data Source=./test.db"),
                contextLifetime: ServiceLifetime.Scoped,
                optionsLifetime: ServiceLifetime.Singleton);
            services.AddLogging(
                builder =>
                {
                    builder.AddConsole(c => c.IncludeScopes = true);
                    builder.AddFilter(DbLoggerCategory.Infrastructure.Name, LogLevel.Error);
                });
            IServiceProvider serviceProvider = services.BuildServiceProvider();

            var context = serviceProvider.GetService<TestContext>();
            context.Database.EnsureDeleted();
            context.Database.EnsureCreated();

            context.Foos.AddRange(new FooEntity { FooId = "foo1" }, new FooEntity { FooId = "foo2" });
            context.Bars.Add(new BarEntity { BarId = "bar1" });
            await context.SaveChangesAsync();

            List<string> keys = GetMatchingKeys(context.Foos, new List<string> { "foo3", "foo2" });
            Console.WriteLine(string.Join(", ", keys));

            Console.WriteLine("DONE");
            Console.ReadKey(intercept: true);
        }

        static List<TKey> GetMatchingKeys<TEntity, TKey>(DbSet<TEntity> dbSet, List<TKey> keysToFind)
            where TEntity : class, IEntityWithKey
        {
            return dbSet
                //.Where(BuildWhereExpression())   // commented out because not working yet
                .Select(BuildSelectExpression()).ToList();

            Expression<Func<TEntity, bool>> BuildWhereExpression()
            {
                // for a FooEntity, would be:  x => keysToFind.Contains(x.FooId)
                // for a BarEntity, would be:  x => keysToFind.Contains(x.BarId)

                throw new NotImplementedException();
            }

            Expression<Func<TEntity, TKey>> BuildSelectExpression()
            {
                // for a FooEntity, would be:  x => x.FooId
                // for a BarEntity, would be:  x => x.BarId

                ParameterExpression parameter = Expression.Parameter(typeof(TEntity));
                MemberExpression property1 = Expression.Property(parameter, nameof(IEntityWithKey.KeyPropertyName));
                MemberExpression property2 = Expression.Property(parameter, property1.Member as PropertyInfo);
                UnaryExpression result = Expression.Convert(property2, typeof(TKey));
                return Expression.Lambda<Func<TEntity, TKey>>(result, parameter);
            }
        }
    }
}

This uses the following NuGet packages:

  • Microsoft.EntityFrameworkCore, Version 3.0.0
  • Microsoft.EntityFrameworkCore.Sqlite, Version 3.0.0
  • Microsoft.Extensions.DependencyInjection, Version 3.0.0
  • Microsoft.Extensions.Logging.Console, Version 3.0.0

Solution

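  • In this case the IEntityWithKey interface is redundant. KeyPropertyName is an instance property, so reading its value requires an entity instance, but BuildSelectExpression only has the Type object. Your code never reads that value anyway: property1.Member is the PropertyInfo of KeyPropertyName itself, so property2 is just x.KeyPropertyName again. Because that property is [NotMapped], EF Core loads FooId (to materialize each entity) and then evaluates x.KeyPropertyName on the client, which is why every row comes back as "FooId".

    You can use reflection to find the key property instead: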

    Expression<Func<TEntity, TKey>> BuildSelectExpression()
    {
        // Find the single property marked with [Key]
        // (needs System.ComponentModel.DataAnnotations and System.Reflection)
        PropertyInfo keyProperty = typeof(TEntity).GetProperties()
            .Single(p => p.GetCustomAttribute<KeyAttribute>() != null);

        // Builds x => x.FooId for FooEntity, x => x.BarId for BarEntity
        ParameterExpression parameter = Expression.Parameter(typeof(TEntity), "x");
        MemberExpression result = Expression.Property(parameter, keyProperty);
        // Expression.Convert is unnecessary when TKey matches the key property's type
        // (Expression.Lambda throws if the two are incompatible)
        return Expression.Lambda<Func<TEntity, TKey>>(result, parameter);
    }
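The same reflection lookup can drive the Where() expression that the question left unimplemented. The sketch below is one possible approach, not taken from the answer: it builds x => keysToFind.Contains(x.<KeyProperty>) by calling the static Enumerable.Contains<TKey> through Expression.Call, and it assumes TKey is exactly the key property's type. To keep it runnable without a database it applies both expressions to an in-memory IQueryable; KeyExpressions is an illustrative name.

```csharp
using System;
using System.Collections.Generic;
using System.ComponentModel.DataAnnotations;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

class FooEntity
{
    [Key] public string FooId { get; set; }
}

static class KeyExpressions
{
    // Locate the single property marked with [Key]
    internal static PropertyInfo FindKeyProperty<TEntity>() =>
        typeof(TEntity).GetProperties()
            .Single(p => p.GetCustomAttribute<KeyAttribute>() != null);

    // Builds x => keysToFind.Contains(x.<KeyProperty>)
    internal static Expression<Func<TEntity, bool>> BuildWhereExpression<TEntity, TKey>(List<TKey> keysToFind)
    {
        ParameterExpression parameter = Expression.Parameter(typeof(TEntity), "x");
        MemberExpression key = Expression.Property(parameter, FindKeyProperty<TEntity>());
        // Resolves to Enumerable.Contains<TKey>(IEnumerable<TKey>, TKey)
        MethodCallExpression contains = Expression.Call(
            typeof(Enumerable), nameof(Enumerable.Contains), new[] { typeof(TKey) },
            Expression.Constant(keysToFind), key);
        return Expression.Lambda<Func<TEntity, bool>>(contains, parameter);
    }

    // Builds x => x.<KeyProperty>
    internal static Expression<Func<TEntity, TKey>> BuildSelectExpression<TEntity, TKey>()
    {
        ParameterExpression parameter = Expression.Parameter(typeof(TEntity), "x");
        MemberExpression key = Expression.Property(parameter, FindKeyProperty<TEntity>());
        return Expression.Lambda<Func<TEntity, TKey>>(key, parameter);
    }
}

static class Program
{
    static void Main()
    {
        // In-memory stand-in for the Foos table; GetMatchingKeys would use dbSet here
        IQueryable<FooEntity> foos = new[]
        {
            new FooEntity { FooId = "foo1" },
            new FooEntity { FooId = "foo2" },
        }.AsQueryable();

        List<string> keys = foos
            .Where(KeyExpressions.BuildWhereExpression<FooEntity, string>(new List<string> { "foo3", "foo2" }))
            .Select(KeyExpressions.BuildSelectExpression<FooEntity, string>())
            .ToList();

        Console.WriteLine(string.Join(", ", keys)); // prints foo2
    }
}
```

Inside GetMatchingKeys the two builders would be applied to dbSet instead of the in-memory array. EF Core should translate an Enumerable.Contains call over a list into a SQL IN clause, but check the logged SQL for your provider rather than relying on this sketch.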