I have a model that looks like this:
Product
-DeleteProduct
└─PreviousProduct (of type SubProduct, not DeleteProduct)
-SubProduct of type SubProduct1, SubProduct2
In other words, a product is either a DeleteProduct or a SubProduct; a DeleteProduct additionally has a PreviousProduct property (of type SubProduct).
Now I have the following EF Core LINQ query:
// Product ids for one customer that are either a SubProduct1, or a DeleteProduct
// whose PreviousProduct is an AccessProduct.
var queryable = context
    .Set<Product>()
    .OfTypes(new[] { typeof(SubProduct1), typeof(DeleteProduct) })
    .Where(p => p.CustomerId == customerId && p.CustomerId != null)
    // The lambda parameter is typed as Product, while PreviousProduct lives on DeleteProduct,
    // so cast explicitly: expression trees cannot contain a declaration pattern such as
    // `p is DeleteProduct d`, only a plain `is Type` test.
    .Where(p => p is SubProduct1 || (p is DeleteProduct && ((DeleteProduct)p).PreviousProduct is AccessProduct))
    .Select(p => p.ProductId);
It uses these extension methods (with thanks to Drew):
/// <summary>
/// Drops every entity whose runtime type is (or derives from) one of <paramref name="typesEnumerable"/>.
/// </summary>
/// <param name="query">Query to filter.</param>
/// <param name="typesEnumerable">Types to exclude; <c>null</c> or empty leaves the query unchanged.</param>
/// <returns>The filtered query.</returns>
public static IQueryable<TEntity> NotOfTypes<TEntity>(this IQueryable<TEntity> query, IEnumerable<Type>? typesEnumerable) where TEntity : class
    => AddWhere(query, typesEnumerable, excludedTypes => GetNotTypesPredicate(typeof(TEntity), excludedTypes));
/// <summary>
/// Keeps only entities whose runtime type is (or derives from) one of <paramref name="typesEnumerable"/>.
/// </summary>
/// <param name="query">Query to filter.</param>
/// <param name="typesEnumerable">Types to keep; <c>null</c> or empty leaves the query unchanged.</param>
/// <returns>The filtered query.</returns>
public static IQueryable<TEntity> OfTypes<TEntity>(this IQueryable<TEntity> query, IEnumerable<Type>? typesEnumerable) where TEntity : class
    => AddWhere(query, typesEnumerable, allowedTypes => GetOfOnlyTypesPredicate(typeof(TEntity), allowedTypes));
/// <summary>
/// Shared plumbing for the type filters: materializes the requested types, asks the factory
/// for a predicate over them and appends it to the query as a Where clause.
/// </summary>
/// <param name="query">Query to filter.</param>
/// <param name="typesEnumerable">Types to build the predicate from; <c>null</c> or empty returns <paramref name="query"/> as-is.</param>
/// <param name="getNotTypesPredicate">
/// Builds the predicate lambda from the materialized types (used for both the include and the exclude filter).
/// </param>
/// <returns>The filtered query, or <paramref name="query"/> itself when there are no types.</returns>
/// <exception cref="InvalidOperationException">The factory did not return an <c>Expression&lt;Func&lt;TEntity, bool&gt;&gt;</c>.</exception>
private static IQueryable<TEntity> AddWhere<TEntity>(
    this IQueryable<TEntity> query,
    IEnumerable<Type>? typesEnumerable,
    Func<IReadOnlyList<Type>, LambdaExpression> getNotTypesPredicate) where TEntity : class
{
    var requested = typesEnumerable?.ToArray();
    if (requested == null || requested.Length == 0)
    {
        // Nothing to filter on.
        return query;
    }

    var built = getNotTypesPredicate(requested);
    var source = query.OfType<TEntity>();
    var predicate = built as Expression<Func<TEntity, bool>>
        ?? throw new InvalidOperationException("Could not translate to types");

    return source.Where(predicate);
}
/// <summary>
/// Builds <c>notOfTypeParam =&gt; !(notOfTypeParam is T0) &amp;&amp; !(notOfTypeParam is T1) &amp;&amp; ...</c>.
/// </summary>
/// <param name="baseType">Type of the lambda parameter.</param>
/// <param name="excluded">Types to reject; must contain at least one element.</param>
/// <returns>An untyped lambda whose body is the left-nested AndAlso chain.</returns>
private static LambdaExpression GetNotTypesPredicate(Type baseType, IReadOnlyList<Type> excluded)
{
    var entity = Expression.Parameter(baseType, "notOfTypeParam");
    Expression IsNotOf(Type candidate) => Expression.Not(Expression.TypeIs(entity, candidate));

    // Seed with the first type, then fold the rest in left to right: ((a && b) && c) ...
    var body = excluded
        .Skip(1)
        .Aggregate<Type, Expression>(IsNotOf(excluded[0]), (acc, next) => Expression.AndAlso(acc, IsNotOf(next)));

    return Expression.Lambda(body, entity);
}
/// <summary>
/// Builds <c>typeonlyParam =&gt; typeonlyParam is T0 || typeonlyParam is T1 || ...</c>.
/// </summary>
/// <param name="baseType">Type of the lambda parameter.</param>
/// <param name="allowed">Types to accept; must contain at least one element.</param>
/// <returns>An untyped lambda whose body is the left-nested OrElse chain.</returns>
private static LambdaExpression GetOfOnlyTypesPredicate(Type baseType, IReadOnlyList<Type> allowed)
{
    var entity = Expression.Parameter(baseType, "typeonlyParam");
    Expression IsOf(Type candidate) => Expression.TypeIs(entity, candidate);

    // Seed with the first type, then fold the rest in left to right: ((a || b) || c) ...
    var body = allowed
        .Skip(1)
        .Aggregate<Type, Expression>(IsOf(allowed[0]), (acc, next) => Expression.OrElse(acc, IsOf(next)));

    return Expression.Lambda(body, entity);
}
Entity Framework Core generates the following SQL (I simplified it a bit by removing unneeded parentheses and casts):
-- SQL generated by EF Core for the query above (simplified).
DECLARE @__customerId_0 int = 1;
SELECT [p].[ProductId]
FROM [dbo].[Product] AS [p]
-- Join used to evaluate PreviousProduct for DeleteProduct rows.
LEFT JOIN (
SELECT [p0].[ProductId], [p0].[ProductTypeId]
FROM [dbo].[Product] AS [p0]
-- The filter to be removed: it restricts the joined rows to a list of discriminator values
-- (presumably those of the PreviousProduct target hierarchy), although the WHERE clause below
-- only ever tests [t].[ProductTypeId] = 1.
WHERE [p0].[ProductTypeId] IN (1, 2, 3, 9, 10, 20, 21, 22, 30, 31)
) AS [t] ON [p].[PreviousProductId] = [t].[ProductId]
WHERE
(
-- Discriminator filter produced by OfTypes(SubProduct1, DeleteProduct).
[p].[ProductTypeId] IN (1, 0)
AND [p].[CustomerId] = @__customerId_0
)
AND
(
-- Type tests from the second Where: SubProduct1, or DeleteProduct whose previous product has type id 1.
[p].[ProductTypeId] = 1
OR ([p].[ProductTypeId] = 0 AND [t].[ProductTypeId] = 1)
)
As you can see, OfTypes already produces [ProductTypeId] IN (1, 0).
I would like to get rid of the unneeded [ProductTypeId] IN (1, 2, 3, 9, 10, 20, 21, 22, 30, 31) and have it changed into ProductTypeId = 1
(or ProductTypeId IN (1)).
How can I do this? Can LinqKit do it, perhaps with a nested expression?
This is possible if you pass the DbContext as an additional parameter: the DbContext
is needed to read the model
metadata (the discriminator values).
Usage is almost the same:
// Product ids for one customer that are either a SubProduct1, or a DeleteProduct
// whose PreviousProduct is an AccessProduct; OfTypes now takes the context for model metadata.
var queryable = context
    .Set<Product>()
    .OfTypes(context, new[] { typeof(SubProduct1), typeof(DeleteProduct) })
    .Where(p => p.CustomerId == customerId && p.CustomerId != null)
    // The lambda parameter is typed as Product, while PreviousProduct lives on DeleteProduct,
    // so cast explicitly: expression trees cannot contain a declaration pattern such as
    // `p is DeleteProduct d`, only a plain `is Type` test.
    .Where(p => p is SubProduct1 || (p is DeleteProduct && ((DeleteProduct)p).PreviousProduct is AccessProduct))
    .Select(p => p.ProductId);
And the implementation:
/// <summary>
/// Discriminator-based type filters for EF Core TPH (table-per-hierarchy) queries.
/// Instead of one type test per requested type, these helpers look up the discriminator
/// values in the model and filter with a single
/// <c>discriminatorValues.Contains(e.Discriminator)</c> predicate.
/// </summary>
public static class TPHExtemsions
{
    /// <summary>
    /// Keeps only entities whose type is one of <paramref name="typesEnumerable"/> or a concrete
    /// type derived from one of them (the same set that <c>e is T</c> matches).
    /// </summary>
    /// <param name="query">Query over an entity type that has a discriminator.</param>
    /// <param name="context">Context whose model supplies the discriminator values.</param>
    /// <param name="typesEnumerable">Types to keep; <c>null</c> or empty returns <paramref name="query"/> unchanged.</param>
    /// <returns>The filtered query.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="context"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">A requested type is not an entity type in the model.</exception>
    /// <exception cref="InvalidOperationException"><typeparamref name="TEntity"/> is not an entity type or has no discriminator.</exception>
    public static IQueryable<TEntity> OfTypes<TEntity>(this IQueryable<TEntity> query, DbContext context,
        IEnumerable<Type>? typesEnumerable)
    {
        return ApplyTypeFilter(query, context, typesEnumerable, isNot: false);
    }

    /// <summary>
    /// Drops entities whose type is one of <paramref name="typesEnumerable"/> or a concrete
    /// type derived from one of them (the same set that <c>e is T</c> matches).
    /// </summary>
    /// <param name="query">Query over an entity type that has a discriminator.</param>
    /// <param name="context">Context whose model supplies the discriminator values.</param>
    /// <param name="typesEnumerable">Types to exclude; <c>null</c> or empty returns <paramref name="query"/> unchanged.</param>
    /// <returns>The filtered query.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="context"/> is <c>null</c>.</exception>
    /// <exception cref="ArgumentException">A requested type is not an entity type in the model.</exception>
    /// <exception cref="InvalidOperationException"><typeparamref name="TEntity"/> is not an entity type or has no discriminator.</exception>
    public static IQueryable<TEntity> NotOfTypes<TEntity>(this IQueryable<TEntity> query, DbContext context,
        IEnumerable<Type>? typesEnumerable)
    {
        return ApplyTypeFilter(query, context, typesEnumerable, isNot: true);
    }

    // Shared body of OfTypes/NotOfTypes: validates arguments, then appends the discriminator
    // predicate, or returns the query untouched when there is nothing to filter on.
    private static IQueryable<TEntity> ApplyTypeFilter<TEntity>(IQueryable<TEntity> query, DbContext context,
        IEnumerable<Type>? typesEnumerable, bool isNot)
    {
        if (query == null)
        {
            throw new ArgumentNullException(nameof(query));
        }

        if (context == null)
        {
            throw new ArgumentNullException(nameof(context));
        }

        var predicate = BuildPredicate<TEntity>(context.Model, typesEnumerable, isNot);
        return predicate == null ? query : query.Where(predicate);
    }

    /// <summary>
    /// Builds <c>e =&gt; values.Contains(discriminator(e))</c>, negated when <paramref name="isNot"/> is true.
    /// Returns <c>null</c> when <paramref name="typesEnumerable"/> is <c>null</c> or empty.
    /// </summary>
    private static Expression<Func<TEntity, bool>>? BuildPredicate<TEntity>(IModel model,
        IEnumerable<Type>? typesEnumerable, bool isNot)
    {
        if (typesEnumerable == null)
        {
            return null;
        }

        // Materialize once so emptiness can be checked without enumerating the source twice.
        var requestedTypes = typesEnumerable.ToList();
        if (requestedTypes.Count == 0)
        {
            return null;
        }

        var entityType = model.FindEntityType(typeof(TEntity))
            ?? throw new InvalidOperationException(
                $"'{typeof(TEntity).Name}' is not an entity type in the model.");
        var discriminator = entityType.GetDiscriminatorProperty()
            ?? throw new InvalidOperationException(
                $"'{typeof(TEntity).Name}' has no discriminator property; OfTypes/NotOfTypes require a TPH-mapped hierarchy.");

        // `e is T` is also true for subclasses of T, so collect the discriminator of T and of every
        // concrete type derived from it. Abstract types have no rows of their own and may have no
        // discriminator value; a null entry would also make Enumerable.Cast fail for value-type
        // discriminators, so nulls are filtered out. An empty list is kept on purpose: "is one of
        // no concrete types" is simply false (or true when negated).
        var discriminatorValues = requestedTypes
            .SelectMany(t => (model.FindEntityType(t)
                    ?? throw new ArgumentException(
                        $"'{t.Name}' is not an entity type in the model.", nameof(typesEnumerable)))
                .GetConcreteDerivedTypesInclusive())
            .Select(concreteType => concreteType.GetDiscriminatorValue())
            .Where(value => value != null)
            .Distinct()
            .ToList();

        // Cast the boxed values to IEnumerable<TDiscriminator> so Contains<TDiscriminator>
        // type-checks against the discriminator property.
        var itemsExpression = Expression.Call(typeof(Enumerable), nameof(Enumerable.Cast),
            new[] { discriminator.ClrType }, Expression.Constant(discriminatorValues));

        var param = Expression.Parameter(typeof(TEntity), "e");
        Expression propExpression;
        if (discriminator.PropertyInfo == null)
        {
            // No CLR property (e.g. a shadow discriminator): read it via EF.Property(e, "Name").
            propExpression = Expression.Call(typeof(EF), nameof(EF.Property), new[] { discriminator.ClrType },
                param, Expression.Constant(discriminator.Name));
        }
        else
        {
            propExpression = Expression.MakeMemberAccess(param, discriminator.PropertyInfo);
        }

        Expression predicate = Expression.Call(typeof(Enumerable), nameof(Enumerable.Contains),
            new[] { discriminator.ClrType }, itemsExpression, propExpression);
        if (isNot)
        {
            predicate = Expression.Not(predicate);
        }

        return Expression.Lambda<Func<TEntity, bool>>(predicate, param);
    }
}