Tags: c#, linq, c#-4.0, expressionvisitor

Visiting IEnumerable<T> children


Here is what we want to do.

We have data from the database that we need to format into a report, including some calculations (sums, averages, and field-to-field calculations, e.g. x.a / x.b).

One constraint is that if one of the values in a calculation (a sum, for example) is null, -1 or -2, we have to stop the calculation and display '-'. Since we have many reports to produce, all with the same logic and many calculations each, we want to centralise this logic. For now, our code lets us check field-to-field calculations (x.a / x.b, for example), but not calculations that involve a group total (e.g. x.b / SUM(x.a)).

Test cases

Rules

  • The calculation must not be performed if one of the values used in it is -1, -2 or null. In that case, return "-" if you find -1 or null, and "C" if you find -2.
  • If the calculation contains several "bad values", respect this priority: null -> -1 -> -2. The priority does not depend on the level at which the value appears in the calculation.
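The two rules can be written as a small helper, shown here as a minimal sketch. `ReportCodes.Resolve` is a hypothetical name, and the sketch assumes the values read by a calculation have already been collected as `int?`. Since null and -1 both report "-", the null -> -1 -> -2 priority boils down to checking the "-" cases before the "C" case.

```csharp
using System.Collections.Generic;
using System.Linq;

public static class ReportCodes
{
    // Hypothetical helper: returns the report code for the values read by a
    // calculation, or null when none of them is a "bad value".
    public static string Resolve(IEnumerable<int?> values)
    {
        List<int?> read = values.ToList();

        // Priority 1: null -> "-"
        if (read.Any(v => v == null))
        {
            return "-";
        }

        // Priority 2: -1 -> "-"
        if (read.Any(v => v == -1))
        {
            return "-";
        }

        // Priority 3: -2 -> "C"
        if (read.Any(v => v == -2))
        {
            return "C";
        }

        // No bad value: the calculation can be performed
        return null;
    }
}
```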

Tests

Simple calculation

object: new DataInfo { A = 10, B = 2, C = 4 }
calculation: x => x.A / x.B + x.C
result: 9
object: new DataInfo { A = 10, B = 2, C = -2 }
calculation: x => x.A / x.B + x.C
result: C (because there is a '-2' value in the calculation)
object: new DataInfo { A = 10, B = -2, C = null }
calculation: x => x.A / x.B + x.C
result: - (because there is a 'null' value in the calculation, and null takes priority over -2)

Complex calculation

object: var list = new List<DataInfo>();
        list.Add(new DataInfo { A = 10, B = 2, C = 4 });
        list.Add(new DataInfo { A = 6, B = 3, C = 2 });
calculation: list.Sum(x => x.A / x.B + list.Max(y => y.C))
result: 15
object: var list = new List<DataInfo>();
        list.Add(new DataInfo { A = 10, B = 2, C = 4 });
        list.Add(new DataInfo { A = 6, B = 3, C = -2 });
calculation: list.Sum(x => x.A / x.B + list.Max(y => y.C))
result: C (because there is a '-2' value in the calculation)
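To make the expected results concrete, the complex case can be computed by hand with plain LINQ. This sketch assumes `DataInfo` has `int?` properties (the simple tests use `C = null`, which an `int` property could not hold), treats every A, B and C in the list as "used" because `Max` reads every C, and uses the hypothetical name `ComplexCalculByHand`.

```csharp
using System.Collections.Generic;
using System.Linq;

public class DataInfo
{
    public int? A { get; set; }
    public int? B { get; set; }
    public int? C { get; set; }
}

public static class ComplexCalculByHand
{
    // Computes list.Sum(x => x.A / x.B + list.Max(y => y.C)) after applying
    // the rules to every value the calculation reads.
    public static string Evaluate(List<DataInfo> list)
    {
        List<int?> read = list.SelectMany(d => new[] { d.A, d.B, d.C }).ToList();

        if (read.Any(v => v == null || v == -1))
        {
            return "-"; // null and -1 both report "-"
        }

        if (read.Any(v => v == -2))
        {
            return "C";
        }

        int maxC = list.Max(y => y.C.Value);
        // integer division: (10 / 2 + 4) + (6 / 3 + 4) = 15 for the first test list
        int total = list.Sum(x => x.A.Value / x.B.Value + maxC);
        return total.ToString();
    }
}
```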

What we have done so far

Here is the code we have for simple calculations, based on this thread:
How to extract properties used in a Expression<Func<T, TResult>> query and test their value?

We have created a strongly typed class that performs a calculation and returns the result as a string. But if any part of the expression is equal to a special value, the calculator has to return a special character instead.

It works well for a simple case, like this one:

var data = new Rapport1Data() { UnitesDisponibles = 5, ... };
var q = new Calculator<Rapport1Data>()
    .Calcul(data, y => y.UnitesDisponibles, "N0");

But I need to be able to perform something more complicated like:

IEnumerable<Rapport1Data> data = ...;
var q = new Calculator<IEnumerable<Rapport1Data>>()
    .Calcul(data, x => x.Sum(y => y.UnitesDisponibles), "N0");

When we start wrapping our data in IEnumerable<>, we get this error:

Object does not match target type

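As we understand it, this is because the sub-expression y => y.UnitesDisponibles is evaluated against the IEnumerable instead of a Rapport1Data: VisitMember reads every property with propertyInfo.GetValue(this.data, null), and this.data is the whole collection.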
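The mismatch can be reproduced in isolation: `PropertyInfo.GetValue` throws a `TargetException` when the target object is not an instance of the property's declaring type. `Rapport1Data` below is a stand-in for the question's class, the `decimal?` type of `UnitesDisponibles` is an assumption, and `TargetMismatchDemo` is a hypothetical name.

```csharp
using System.Collections.Generic;
using System.Reflection;

// Stand-in for the question's class; the property type is an assumption.
public class Rapport1Data
{
    public decimal? UnitesDisponibles { get; set; }
}

public static class TargetMismatchDemo
{
    // Reads UnitesDisponibles the way RulesChecker<T>.VisitMember does, but with
    // the whole collection as the target, which is what happens when T is
    // IEnumerable<Rapport1Data>. Returns true if the TargetException is raised.
    public static bool Reproduce()
    {
        PropertyInfo property = typeof(Rapport1Data).GetProperty("UnitesDisponibles");
        object data = new List<Rapport1Data> { new Rapport1Data { UnitesDisponibles = 5 } };

        try
        {
            property.GetValue(data, null);
            return false;
        }
        catch (TargetException)
        {
            // English message: "Object does not match target type."
            return true;
        }
    }
}
```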

How can we fix it so that it is fully recursive, in case we someday have a complex expression like:

IEnumerable<IEnumerable<Rapport1Data>> data = ...;
var q = new Calculator<IEnumerable<IEnumerable<Rapport1Data>>>()
    .Calcul(data,x => x.Sum(y => y.Sum(z => z.UnitesDisponibles)), "N0");
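One possible direction, sketched under assumptions, uses a different technique from the reflection-based RulesChecker: instead of reading property values against a fixed `data` object, the expression is rewritten so that every numeric property read on a lambda parameter goes through a recording call, and then it is simply compiled and run. LINQ executes the nested lambdas itself, so `y.UnitesDisponibles` or `z.UnitesDisponibles` is always read from the right item, at any depth. `ValueRecorder`, `RecordingRewriter` and `RecordingCalculator` are hypothetical names, `DataInfo` is assumed to have `int?` properties, and ranking division by zero last with the "-1" code is an assumption.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

// Stand-in for the question's data class (int? so that null can be stored).
public class DataInfo
{
    public int? A { get; set; }
    public int? B { get; set; }
    public int? C { get; set; }
}

// Collects every value the instrumented expression reads while it runs.
public sealed class ValueRecorder
{
    public readonly List<object> Values = new List<object>();

    public T Record<T>(T value)
    {
        Values.Add(value);
        return value;
    }
}

// Wraps each numeric property read on a lambda parameter (x.A, y.C, ...)
// in a call to ValueRecorder.Record, at any nesting depth.
public sealed class RecordingRewriter : ExpressionVisitor
{
    private static readonly MethodInfo RecordMethod = typeof(ValueRecorder).GetMethod("Record");
    private readonly ConstantExpression recorder;

    public RecordingRewriter(ValueRecorder recorder)
    {
        this.recorder = Expression.Constant(recorder);
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        var visited = (MemberExpression)base.VisitMember(node);
        Type type = Nullable.GetUnderlyingType(visited.Type) ?? visited.Type;
        bool numeric = type == typeof(int) || type == typeof(long) || type == typeof(decimal);

        if (numeric && visited.Member is PropertyInfo && visited.Expression is ParameterExpression)
        {
            return Expression.Call(recorder, RecordMethod.MakeGenericMethod(visited.Type), visited);
        }

        return visited;
    }
}

public static class RecordingCalculator
{
    public static string Calcul<T>(T data, Expression<Func<T, decimal?>> query, string format)
    {
        var recorder = new ValueRecorder();
        var rewritten = (Expression<Func<T, decimal?>>)new RecordingRewriter(recorder).Visit(query);

        decimal? result = null;
        bool divisionParZero = false;
        try
        {
            // LINQ itself runs the nested lambdas, so each y.C is read from the right item.
            result = rewritten.Compile()(data);
        }
        catch (DivideByZeroException)
        {
            divisionParZero = true; // values after the failing division are not recorded
        }

        // Priority from the question's rules: null, then -1, then -2.
        if (recorder.Values.Any(v => v == null)) return "-";
        if (recorder.Values.Any(v => Convert.ToDecimal(v) == -1m)) return "-";
        if (recorder.Values.Any(v => Convert.ToDecimal(v) == -2m)) return "C";

        // Assumption: division by zero ranks last and reuses the answer's "-1" code.
        if (divisionParZero) return "-1";

        return result.HasValue ? result.Value.ToString(format) : "-";
    }
}
```

One caveat of this design: because the values are recorded while the calculation runs, a division by zero stops recording at the failing division, so a bad value read after it is not seen.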

Classes we've built

public class Calculator<T>
{
    public string Calcul(
        T data,
        Expression<Func<T, decimal?>> query,
        string format)
    {
        var rulesCheckerResult = RulesChecker<T>.Check(data, query);

        // The order of these checks matters because the codes to return
        // have priorities!
        if (rulesCheckerResult.HasManquante)
        {
            return TypeDonnee.Manquante.ReportValue;
        }

        if (rulesCheckerResult.HasDivisionParZero)
        {
            return TypeDonnee.DivisionParZero.ReportValue;
        }

        if (rulesCheckerResult.HasNonDiffusable)
        {
            return TypeDonnee.NonDiffusable.ReportValue;
        }

        if (rulesCheckerResult.HasConfidentielle)
        {
            return TypeDonnee.Confidentielle.ReportValue;
        }

        // if the query respects the rules, apply the query and return the
        // value
        var result = query.Compile().Invoke(data);

        return result != null
            ? result.Value.ToString(format)
            : TypeDonnee.Manquante.ReportValue;
    }
}

and the Custom ExpressionVisitor

class RulesChecker<T> : ExpressionVisitor
{
    private readonly T data;
    private bool hasConfidentielle = false;
    private bool hasNonDiffusable = false;
    private bool hasDivisionParZero = false;
    private bool hasManquante = false;

    public RulesChecker(T data)
    {
        this.data = data;
    }

    public static RulesCheckerResult Check(T data, Expression expression)
    {
        var visitor = new RulesChecker<T>(data);
        visitor.Visit(expression);

        return new RulesCheckerResult(
            visitor.hasConfidentielle,
            visitor.hasNonDiffusable,
            visitor.hasDivisionParZero,
            visitor.hasManquante);
    }

    protected override Expression VisitBinary(BinaryExpression node)
    {
        if (!this.hasDivisionParZero &&
            node.NodeType == ExpressionType.Divide &&
            node.Right.NodeType == ExpressionType.MemberAccess)
        {
            var rightMemeberExpression = (MemberExpression)node.Right;
            var propertyInfo = (PropertyInfo)rightMemeberExpression.Member;
            var value = Convert.ToInt32(propertyInfo.GetValue(this.data, null));

            this.hasDivisionParZero = value == 0;
        }

        return base.VisitBinary(node);
    }

    protected override Expression VisitMember(MemberExpression node)
    {
        // If any of the flags is still false, keep testing
        if (!this.hasConfidentielle ||
            !this.hasNonDiffusable ||
            !this.hasManquante)
        {
            var propertyInfo = (PropertyInfo)node.Member;
            object value = propertyInfo.GetValue(this.data, null);
            int? valueNumber = MTO.Framework.Common.Convert.To<int?>(value);

            // If the flag is already true, there is no need to test further
            if (!this.hasManquante)
            {
                this.hasManquante =
                    valueNumber == TypeDonnee.Manquante.BdValue;
            }

            // If the flag is already true, there is no need to test further
            if (!this.hasConfidentielle)
            {
                this.hasConfidentielle =
                    valueNumber == TypeDonnee.Confidentielle.BdValue;
            }

            // If the flag is already true, there is no need to test further
            if (!this.hasNonDiffusable)
            {
                this.hasNonDiffusable =
                    valueNumber == TypeDonnee.NonDiffusable.BdValue;
            }
        }

        return base.VisitMember(node);
    }
}

[UPDATE] Added more detail on what we want to do


Solution

  • There are a few things that you need to change to get this to work:

    • Create a new ExpressionVisitor that will preprocess your expression to execute the aggregates.
    • Use the new ExpressionVisitor in the Calculator.Calcul method.
    • Modify the RulesChecker to include an override for the VisitConstant method. This new method needs to include the same logic that is in the VisitMember method.
    • Modify the RulesChecker VisitBinary method to check the divide by zero condition for ConstantExpressions.

    Here is a rough example of what I think needs to be done.

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Linq;
    using System.Linq.Expressions;
    using System.Reflection;
    
    namespace WindowsFormsApplication1 {
        internal static class Program {
            [STAThread]
            private static void Main() {
                var calculator = new Calculator();
    
                //// DivideByZero - the result should be -1
                var data1 = new DataInfo { A = 10, B = 0, C = 1 };
                Expression<Func<DataInfo, decimal?>> expression1 = x => x.A / x.B + x.C;
                var result1 = calculator.Calcul(data1, expression1, "N0");
    
                //// Negative 1 - the result should be -
                var data2 = new DataInfo { A = 10, B = 5, C = -1 };
                Expression<Func<DataInfo, decimal?>> expression2 = x => x.A / x.B + x.C;
                var result2 = calculator.Calcul(data2, expression2, "N0");
    
                //// Negative 2 - the result should be C
                var data3 = new DataInfo { A = 10, B = 5, C = -2 };
                Expression<Func<DataInfo, decimal?>> expression3 = x => x.A / x.B + x.C;
                var result3 = calculator.Calcul(data3, expression3, "N0");
    
                //// the result should be 3
                var data4 = new DataInfo { A = 10, B = 5, C = 1 };
                Expression<Func<DataInfo, decimal?>> expression4 = x => x.A / x.B + x.C;
                var result4 = calculator.Calcul(data4, expression4, "N0");
    
                //// DivideByZero - the result should be -1
                var data5 = new List<DataInfo> {
                                        new DataInfo {A = 10, B = 0, C = 1},
                                        new DataInfo {A = 10, B = 0, C = 1}
                            };
    
                Expression<Func<IEnumerable<DataInfo>, decimal?>> expression5 = x => x.Sum(y => y.A) / x.Sum(y => y.B) + x.Sum(y => y.C);
                var result5 = calculator.Calcul(data5, expression5, "N0");
    
                //// the result should be 4
                var data6 = new List<DataInfo> {
                                        new DataInfo {A = 10, B = 5, C = 1},
                                        new DataInfo {A = 10, B = 5, C = 1}
                            };
    
                Expression<Func<IEnumerable<DataInfo>, decimal?>> expression6 = x => x.Sum(y => y.A) / x.Sum(y => y.B) + x.Sum(y => y.C);
                var result6 = calculator.Calcul(data6, expression6, "N0");
    
                //// the result should be -
                var data7 = new List<DataInfo> {
                                        new DataInfo {A = 10, B = 5, C = -1},
                                        new DataInfo {A = 10, B = 5, C = 1}
                            };
    
                Expression<Func<IEnumerable<DataInfo>, decimal?>> expression7 = x => x.Sum(y => y.A) / x.Sum(y => y.B) + x.Sum(y => y.C);
                var result7 = calculator.Calcul(data7, expression7, "N0");
    
                //// the result should be 14
                var c1 = 1;
                var c2 = 2;
    
                var data8 = new DataInfo { A = 10, B = 1, C = 1 };
                Expression<Func<DataInfo, decimal?>> expression8 = x => x.A / x.B + x.C + c1 + c2;
                var result8 = calculator.Calcul(data8, expression8, "N0");
            }
        }
    
        public class Calculator {
            public string Calcul<T>(T data, LambdaExpression query, string format) {
                string reportValue;
    
                if (HasIssue(data, query, out reportValue)) {
                    return reportValue;
                }
    
                // executes the aggregates
                query = (LambdaExpression)ExpressionPreProcessor.PreProcessor(data, query);
    
                // checks the rules against the results of the aggregates
                if (HasIssue(data, query, out reportValue)) {
                    return reportValue;
                }
    
                Delegate lambda = query.Compile();
                decimal? result = (decimal?)lambda.DynamicInvoke(data);
    
                return result != null
                    ? result.Value.ToString(format)
                    : TypeDonnee.Manquante.ReportValue;
            }
    
            private bool HasIssue(object data, LambdaExpression query, out string reportValue) {
                var rulesCheckerResult = RulesChecker.Check(data, query);
    
                // the first match wins, so the order of these checks is the priority
                // of the codes (same order as in the original Calculator<T>)
                if (rulesCheckerResult.HasManquante) {
                    reportValue = TypeDonnee.Manquante.ReportValue;
                }
                else if (rulesCheckerResult.HasDivisionParZero) {
                    reportValue = TypeDonnee.DivisionParZero.ReportValue;
                }
                else if (rulesCheckerResult.HasNonDiffusable) {
                    reportValue = TypeDonnee.NonDiffusable.ReportValue;
                }
                else if (rulesCheckerResult.HasConfidentielle) {
                    reportValue = TypeDonnee.Confidentielle.ReportValue;
                }
                else {
                    reportValue = null;
                }
    
                return reportValue != null;
            }
        }
    
        internal class ExpressionPreProcessor : ExpressionVisitor {
            private readonly object _source;
    
            public static Expression PreProcessor(object source, Expression expression) {
                if (!IsValidSource(source)) {
                    return expression;
                }
    
                var visitor = new ExpressionPreProcessor(source);
    
                return visitor.Visit(expression);
            }
    
            private static bool IsValidSource(object source) {
                if (source == null) {
                    return false;
                }
    
                var type = source.GetType();
    
                return type.IsGenericType && type.GetInterface("IEnumerable") != null;
            }
    
            public ExpressionPreProcessor(object source) {
                this._source = source;
            }
    
            protected override Expression VisitMethodCall(MethodCallExpression node) {
                if (node.Method.DeclaringType == typeof(Enumerable) && node.Arguments.Count == 2) {
    
                    switch (node.Method.Name) {
                        case "Count":
                        case "Min":
                        case "Max":
                        case "Sum":
                        case "Average":
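                            var lambda = node.Arguments[1] as LambdaExpression;
    
                            if (lambda == null) {
                                // the selector is not an inline lambda: leave the call untouched
                                break;
                            }
    
                            var lambdaDelegate = lambda.Compile();
                            var value = node.Method.Invoke(null, new object[] { this._source, lambdaDelegate });
    
                            // passing node.Type keeps the aggregate's declared type: without it, a Sum
                            // over int? becomes an int constant and a null Max result an object constant
                            return Expression.Constant(value, node.Type);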
                    }
                }
    
                return base.VisitMethodCall(node);
            }
        }
    
        internal class RulesChecker : ExpressionVisitor {
            private readonly object data;
            private bool hasConfidentielle = false;
            private bool hasNonDiffusable = false;
            private bool hasDivisionParZero = false;
            private bool hasManquante = false;
    
            public RulesChecker(object data) {
                this.data = data;
            }
    
            public static RulesCheckerResult Check(object data, Expression expression) {
                if (IsIEnumerable(data)) {
                    var result = new RulesCheckerResult(false, false, false, false);
    
                    IEnumerable dataItems = (IEnumerable)data;
    
                    foreach (object dataItem in dataItems) {
                        result = MergeResults(result, GetResults(dataItem, expression));
                    }
    
                    return result;
    
                }
                else {
                    return GetResults(data, expression);
                }
            }
    
            private static RulesCheckerResult MergeResults(RulesCheckerResult results1, RulesCheckerResult results2) {
                var hasConfidentielle = results1.HasConfidentielle || results2.HasConfidentielle;
                var hasDivisionParZero = results1.HasDivisionParZero || results2.HasDivisionParZero;
                var hasManquante = results1.HasManquante || results2.HasManquante;
                var hasNonDiffusable = results1.HasNonDiffusable || results2.HasNonDiffusable;
    
                return new RulesCheckerResult(hasConfidentielle, hasNonDiffusable, hasDivisionParZero, hasManquante);
            }
    
            private static RulesCheckerResult GetResults(object data, Expression expression) {
                var visitor = new RulesChecker(data);
                visitor.Visit(expression);
    
                return new RulesCheckerResult(
                    visitor.hasConfidentielle,
                    visitor.hasNonDiffusable,
                    visitor.hasDivisionParZero,
                    visitor.hasManquante);
            }
    
            private static bool IsIEnumerable(object source) {
                if (source == null) {
                    return false;
                }
    
                var type = source.GetType();
    
                return type.IsGenericType && type.GetInterface("IEnumerable") != null;
            }
    
            protected override Expression VisitBinary(BinaryExpression node) {
                if (!this.hasDivisionParZero && node.NodeType == ExpressionType.Divide) {
                    if (node.Right.NodeType == ExpressionType.MemberAccess) {
                        var rightMemberExpression = (MemberExpression)node.Right;
                        // "as" instead of a cast: captured variables are fields, not properties
                        var propertyInfo = rightMemberExpression.Member as PropertyInfo;
    
                        if (propertyInfo != null) {
                            var value = propertyInfo.GetValue(this.data, null);
    
                            // compare as decimal so that a divisor such as 0.4 is not rounded to 0
                            this.hasDivisionParZero = value != null && Convert.ToDecimal(value) == 0m;
                        }
                    }
    
                    if (node.Right.NodeType == ExpressionType.Constant) {
                        var rightConstantExpression = (ConstantExpression)node.Right;
                        var value = rightConstantExpression.Value;
    
                        this.hasDivisionParZero = value != null && Convert.ToDecimal(value) == 0m;
                    }
    
                }
    
                return base.VisitBinary(node);
            }
    
            protected override Expression VisitConstant(ConstantExpression node) {
                this.CheckValue(this.ConvertToNullableInt(node.Value));
    
                return base.VisitConstant(node);
            }
    
            protected override Expression VisitMember(MemberExpression node) {
                if (!this.hasConfidentielle || !this.hasNonDiffusable || !this.hasManquante) {
                    var propertyInfo = node.Member as PropertyInfo;
    
                    if (propertyInfo != null) {
                        var value = propertyInfo.GetValue(this.data, null);
    
                        this.CheckValue(this.ConvertToNullableInt(value));
                    }
                }
    
                return base.VisitMember(node);
            }
    
            private void CheckValue(int? value) {
                if (!this.hasManquante) {
                    this.hasManquante = value == TypeDonnee.Manquante.BdValue;
                }
    
                if (!this.hasConfidentielle) {
                    this.hasConfidentielle = value == TypeDonnee.Confidentielle.BdValue;
                }
    
                if (!this.hasNonDiffusable) {
                    this.hasNonDiffusable = value == TypeDonnee.NonDiffusable.BdValue;
                }
            }
    
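            private int? ConvertToNullableInt(object value) {
                // null is meaningful: it matches TypeDonnee.Manquante
                if (value == null) {
                    return null;
                }
    
                // only int values can carry the -1 / -2 codes; unboxing any other type
                // (e.g. the double returned by Average) to int? would throw
                if (!(value is int)) {
                    return int.MinValue; // never matches a TypeDonnee code
                }
    
                // MTO.Framework.Common.Convert.To<int?>(value);
                return (int)value;
            }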
        }
    
        class RulesCheckerResult {
            public bool HasConfidentielle { get; private set; }
            public bool HasNonDiffusable { get; private set; }
            public bool HasDivisionParZero { get; private set; }
            public bool HasManquante { get; private set; }
    
            public RulesCheckerResult(bool hasConfidentielle, bool hasNonDiffusable, bool hasDivisionParZero, bool hasManquante) {
                this.HasConfidentielle = hasConfidentielle;
                this.HasNonDiffusable = hasNonDiffusable;
                this.HasDivisionParZero = hasDivisionParZero;
                this.HasManquante = hasManquante;
            }
        }
    
        class TypeDonnee {
            public static readonly TypeValues Manquante = new TypeValues(null, "-");
            public static readonly TypeValues Confidentielle = new TypeValues(-1, "-");
            public static readonly TypeValues NonDiffusable = new TypeValues(-2, "C");
            public static readonly TypeValues DivisionParZero = new TypeValues(0, "-1");
        }
    
        class TypeValues {
            public int? BdValue { get; set; }
            public string ReportValue { get; set; }
    
            public TypeValues(int? bdValue, string reportValue) {
                this.BdValue = bdValue;
                this.ReportValue = reportValue;
            }
        }
    
        class DataInfo {
            public int A { get; set; }
            public int B { get; set; }
            public int C { get; set; }
        }
    }
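One limitation of `ExpressionPreProcessor` is that it compiles each aggregate's selector on its own and always invokes the aggregate on `_source`. A selector that refers to the outer lambda parameter, as in `l => l.Sum(x => x.A / x.B + l.Max(y => y.C))`, cannot be compiled by itself, and an aggregate over some other captured collection would run against `_source` rather than the collection it names. A possible building block for a more careful preprocessor is to evaluate only sub-expressions that have no free parameters. `FreeParameterFinder` and `ClosedExpression.TryEvaluate` are hypothetical names; this is a sketch, not part of the answer's code.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical helper: reports whether an expression uses a lambda parameter
// that is not declared inside the expression itself (a "free" parameter).
public sealed class FreeParameterFinder : ExpressionVisitor
{
    private readonly HashSet<ParameterExpression> declared = new HashSet<ParameterExpression>();

    public bool HasFreeParameter { get; private set; }

    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        // Parameters of nested lambdas, such as y in list.Max(y => y.C), are bound locally.
        foreach (ParameterExpression parameter in node.Parameters)
        {
            declared.Add(parameter);
        }

        return base.VisitLambda(node);
    }

    protected override Expression VisitParameter(ParameterExpression node)
    {
        if (!declared.Contains(node))
        {
            HasFreeParameter = true;
        }

        return node;
    }
}

public static class ClosedExpression
{
    // Evaluates the expression only when it does not depend on a free parameter,
    // for example an aggregate over a captured list.
    public static bool TryEvaluate(Expression expression, out object value)
    {
        var finder = new FreeParameterFinder();
        finder.Visit(expression);

        if (finder.HasFreeParameter)
        {
            value = null;
            return false;
        }

        value = Expression.Lambda<Func<object>>(Expression.Convert(expression, typeof(object))).Compile()();
        return true;
    }
}
```

A preprocessor could call `TryEvaluate` on each aggregate call and replace it with `Expression.Constant(value, node.Type)` only when it succeeds, leaving aggregates that depend on an outer parameter to run when the compiled lambda executes.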