Tags: c#, performance, linq, linq-to-objects

Improve performance of a Moving Average calculation: IEnumerable vs List, foreach vs for, ElementAt


I created a class to calculate an Exponential Moving Average (EMA):

public class ExponentialMovingAverage {

  public Int32 Period { get; set; }

  public ExponentialMovingAverage(Int32 period = 20) {

    ArgumentOutOfRangeException.ThrowIfNegativeOrZero(period);

    Period = period;
    
  } 

  public override IEnumerable<(DateTimeOffset Stamp, Decimal? ExponentialMovingAverage)> Compute(IEnumerable<(DateTimeOffset Stamp, Decimal? Value)> inputs) {
    
    ArgumentNullException.ThrowIfNull(inputs);

    inputs = inputs.OrderBy(x => x.Stamp);

    Decimal? previous = null;

    Decimal factor = (Decimal)(2d / (Period + 1));

    Decimal? sum = 0;

    Int32 notNulls = 0;

    for (Int32 index = 0; index < inputs.Count(); index++) {

      (DateTimeOffset stamp, Decimal? value) = inputs.ElementAt(index);

      if (value == null) {
        notNulls++;
        yield return (stamp, null);   
        continue;  
      }

      if (index < notNulls + Period - 1) {
        sum += value;
        yield return (stamp, null);   
        continue;
      }
        
      if (index == notNulls + Period - 1) {
        sum += value;
        Decimal? sma = sum / Period;
        previous = sma;
        yield return (stamp, sma);  
        continue;  
      }

      Decimal? ema = previous + (factor * (value - previous));
      previous = ema;
      yield return (stamp, ema);

    } 
    
  } 

} 

And I have the following tests which are passing:

[Fact]
public void Test_AllNonNullInputs() {
    var ema = new ExponentialMovingAverage(3);

    var inputs = new List<(DateTimeOffset Stamp, decimal? Value)> {
        (new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero), 10),
        (new DateTimeOffset(2024, 1, 2, 0, 0, 0, TimeSpan.Zero), 15),
        (new DateTimeOffset(2024, 1, 3, 0, 0, 0, TimeSpan.Zero), 20),
        (new DateTimeOffset(2024, 1, 4, 0, 0, 0, TimeSpan.Zero), 25),
        (new DateTimeOffset(2024, 1, 5, 0, 0, 0, TimeSpan.Zero), 30),
        (new DateTimeOffset(2024, 1, 6, 0, 0, 0, TimeSpan.Zero), 35)
    };
    var output = ema.Compute(inputs).ToList();

    Assert.Equal(6, output.Count);
    Assert.Null(output[0].ExponentialMovingAverage);
    Assert.Null(output[1].ExponentialMovingAverage);
    Assert.Equal(15m, output[2].ExponentialMovingAverage);
    Assert.Equal(20m, output[3].ExponentialMovingAverage);
    Assert.Equal(25m, output[4].ExponentialMovingAverage);
    Assert.Equal(30m, output[5].ExponentialMovingAverage);
}

[Fact]
public void Test_FirstTwoInputsAreNull() {
    var ema = new ExponentialMovingAverage(3);

    var inputs = new List<(DateTimeOffset Stamp, decimal? Value)> {
        (new DateTimeOffset(2024, 1, 1, 0, 0, 0, TimeSpan.Zero), null),
        (new DateTimeOffset(2024, 1, 2, 0, 0, 0, TimeSpan.Zero), null),
        (new DateTimeOffset(2024, 1, 3, 0, 0, 0, TimeSpan.Zero), 20),
        (new DateTimeOffset(2024, 1, 4, 0, 0, 0, TimeSpan.Zero), 25),
        (new DateTimeOffset(2024, 1, 5, 0, 0, 0, TimeSpan.Zero), 30),
        (new DateTimeOffset(2024, 1, 6, 0, 0, 0, TimeSpan.Zero), 35)
    };
    var output = ema.Compute(inputs).ToList();

    Assert.Equal(6, output.Count);
    Assert.Null(output[0].ExponentialMovingAverage);
    Assert.Null(output[1].ExponentialMovingAverage);
    Assert.Null(output[2].ExponentialMovingAverage);
    Assert.Null(output[3].ExponentialMovingAverage);
    Assert.Equal(25m, output[4].ExponentialMovingAverage);
    Assert.Equal(30m, output[5].ExponentialMovingAverage);
}

The calculation seems slow, for example, when calculating an EMA of an EMA with many inputs:

var ema1 = new ExponentialMovingAverage(3);
var outputs1 = ema1.Compute(inputs);
var ema2 = new ExponentialMovingAverage(3);
var outputs2 = ema2.Compute(outputs1);

I have been looking at the use of inputs.ElementAt(index), at foreach vs. for, and at List vs. IEnumerable.

How can I improve the code including its performance?


Solution

  • Your basic problem is that, since OrderBy() returns a lazily evaluated IOrderedEnumerable<TSource> rather than an IList<TSource>, every call to ElementAt(index) must stream through the enumerable from the start to reach the specified element. Each of those calls (and each Count() call in your loop condition) re-runs the sort -- a substantial O(n²) performance penalty.

    To prevent this, replace your for loop with a foreach loop, and otherwise make sure never to enumerate the sorted sequence more than once:

    public class ExponentialMovingAverage {
    
        public int Period { get; }
    
        public ExponentialMovingAverage(int period = 20) {
            ArgumentOutOfRangeException.ThrowIfNegativeOrZero(period);
            Period = period;
        } 
    
        public /*override*/ IEnumerable<(DateTimeOffset Stamp, decimal? ExponentialMovingAverage)> Compute(IEnumerable<(DateTimeOffset Stamp, Decimal? Value)> inputs) {
            ArgumentNullException.ThrowIfNull(inputs);
    
            inputs = inputs.OrderBy(x => x.Stamp);
    
            decimal? previous = null;
            decimal factor = (2M / (Period + 1)); // M is the decimal literal suffix
            decimal? sum = 0;
            int notNulls = 0;
    
            int index = 0;
            foreach ((var stamp, var value) in inputs) {
                if (value == null) {
                    notNulls++;
                    yield return (stamp, null);   
                }
                else if (index < notNulls + Period - 1) {
                    sum += value;
                    yield return (stamp, null);   
                }
                else if (index == notNulls + Period - 1) {
                    sum += value;
                    decimal? sma = sum / Period;
                    previous = sma;
                    yield return (stamp, sma);  
                }
                else {
                    decimal? ema = previous + (factor * (value - previous));
                    previous = ema;
                    yield return (stamp, ema);
                }
                index++;
            } 
        } 
    } 
    

    As an aside, I recommend making the Period property read-only (a get-only auto-property), as shown above.

    Demo fiddle #1 here.

    Alternatively, if you need to randomly access the sorted list of inputs for some reason, materialize it as a List<T> and use that:

    public /*override*/ IEnumerable<(DateTimeOffset Stamp, decimal? ExponentialMovingAverage)> Compute(IEnumerable<(DateTimeOffset Stamp, Decimal? Value)> inputs) {
        ArgumentNullException.ThrowIfNull(inputs);
    
        var inputList = inputs.OrderBy(x => x.Stamp).ToList();
    
        decimal? previous = null;
        decimal factor = (2M / (Period + 1)); // M is the decimal literal suffix
        decimal? sum = 0;
        int notNulls = 0;
    
        for (int index = 0; index < inputList.Count; index++) {
            (var stamp, var value) = inputList[index];
    
            if (value == null) {
                notNulls++;
                yield return (stamp, null);   
            }
            else if (index < notNulls + Period - 1) {
                sum += value;
                yield return (stamp, null);   
            }
            else if (index == notNulls + Period - 1) {
                sum += value;
                decimal? sma = sum / Period;
                previous = sma;
                yield return (stamp, sma);  
            }
            else {
                decimal? ema = previous + (factor * (value - previous));
                previous = ema;
                yield return (stamp, ema);
            }
        } 
    } 
    

    Demo fiddle #2 here.

    Either way, avoid calling Count() or ElementAt() repeatedly on the same enumerable: each call enumerates the sequence again unless the source actually implements IList<T> -- and if you do have a list, use its Count and Item[int index] (indexer) properties directly instead.
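    If you need an up-front count without forcing a full enumeration, one option (a sketch, assuming .NET 6 or later; the class and variable names are illustrative) is Enumerable.TryGetNonEnumeratedCount(), which succeeds only when the source already knows its count, e.g. because it implements ICollection<T>:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

class CountDemo {
    static void Main() {
        IEnumerable<int> list = new List<int> { 1, 2, 3 };
        IEnumerable<int> query = list.Where(x => x > 1); // deferred; counting it would enumerate

        // Succeeds: List<int> implements ICollection<int>, so the count is already known.
        Console.WriteLine(list.TryGetNonEnumeratedCount(out int n1) ? $"known: {n1}" : "unknown");  // known: 3

        // Fails: a Where() iterator cannot report a count without enumerating its source.
        Console.WriteLine(query.TryGetNonEnumeratedCount(out int n2) ? $"known: {n2}" : "unknown"); // unknown
    }
}
```

    This lets you take a fast path when a cheap count exists, and fall back to a single enumeration when it does not.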

    Incidentally, while the docs for OrderBy() do not explicitly state that the sort is re-executed on every enumeration, this follows from Classification of standard query operators by manner of execution: Deferred:

    Deferred execution means that the operation isn't performed at the point in the code where the query is declared. The operation is performed only when the query variable is enumerated, for example by using a foreach statement. The results of executing the query depend on the contents of the data source when the query is executed rather than when the query is defined. If the query variable is enumerated multiple times, the results might differ every time. Almost all the standard query operators whose return type is IEnumerable<T> or IOrderedEnumerable<TElement> execute in a deferred manner.

    OrderBy() uses deferred nonstreaming execution as stated in the Classification table.
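    A minimal, self-contained illustration of that deferred behavior (nothing here is specific to the question's code; the names are made up for the demo):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

class DeferredDemo {
    static void Main() {
        var data = new List<int> { 3, 1 };
        var sorted = data.OrderBy(x => x); // query declared, but nothing sorted yet

        data.Add(2); // mutate the source AFTER declaring the query

        // The sort runs now, and it sees the element added above.
        Console.WriteLine(string.Join(",", sorted)); // prints "1,2,3"

        data.Add(0);
        // Enumerating again re-runs the sort against the current contents.
        Console.WriteLine(string.Join(",", sorted)); // prints "0,1,2,3"
    }
}
```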

    You can also confirm with some debugging that the ordered enumerable is recomputed for each ElementAt() by tracking how often OrderBy() invokes your key selector:

    public void DemonstrateResorting(IEnumerable<(DateTimeOffset Stamp, Decimal? Value)> inputs) {
    
        ArgumentNullException.ThrowIfNull(inputs);
    
        int callCount = 0;
        var sorted = inputs.OrderBy(x => { callCount++; return x.Stamp; });
    
        sorted.ElementAt(0);
    
        var firstCallCount = callCount;
    
        sorted.ElementAt(0);
    
        // FAILS with Xunit.Sdk.EqualException: Assert.Equal() Failure: Values differ
        Assert.Equal(firstCallCount, callCount);
    }
    

    The above code throws because ElementAt(0) ends up re-sorting the inputs for each call.

    Demo fiddle #3 here.