c# linq tuples morelinq

Find min and max of cumulative sum in Linq


I have the following function, which I am using to find the terminal values (the largest and smallest) reached by a cumulative sum. It is working:

public class CumulativeTotal
{
    [Test]
    public void CalculatesTerminalValue()
    {
        IEnumerable<decimal> sequence = new decimal[] { 10, 20, 20, -20, -50, 10 };

        var values = FindTerminalValues(sequence);
        Assert.That(values.Item1, Is.EqualTo(-20));
        Assert.That(values.Item2, Is.EqualTo(50));

        Assert.Pass();
    }

    public static Tuple<decimal,decimal> FindTerminalValues(IEnumerable<decimal> values)
    {
        decimal largest = 0;
        decimal smallest = 0;
        decimal current = 0;

        foreach (var value in values)
        {
            current += value;
            if (current > largest)
                largest = current;
            else if (current < smallest)
                smallest = current;
        }

        return new Tuple<decimal, decimal>(smallest,largest);
    }
}

However, in the interests of learning, how could I implement this with LINQ?

I can see there is a package called MoreLinq, but I'm not sure where to start!


Solution

  • The major flaw in the code you've presented is that if the running sum of the sequence stays above zero or below zero the whole time, the algorithm incorrectly returns zero as one of the terminal values.

    Take this:

    IEnumerable<decimal> sequence = new decimal[] { 10, 20, };
    

    Your current algorithm returns (0, 30) when it should be (10, 30): the running sum takes the values 10 and 30 and never touches zero.

    To correct that you must start with the first value of the sequence as the default minimum and maximum.

    Here's an implementation that does that:

    public static (decimal min, decimal max) FindTerminalValues(IEnumerable<decimal> values)
    {
        if (!values.Any())
            throw new System.ArgumentException("no values");
            
        decimal first = values.First();
        
        IEnumerable<decimal> scan = values.Scan((x, y) => x + y);
    
        return scan.Aggregate(
            (min: first, max: first),
            (a, x) =>
            (
                min: x < a.min ? x : a.min,
                max: x > a.max ? x : a.max
            ));
    }
    

    It uses System.Interactive to get the Scan operator (but you could use MoreLinq instead).

    However, the one downside to this approach is that an IEnumerable<decimal> is not guaranteed to return the same values every time it is enumerated, and the code above enumerates it three times (Any, First, and Scan). You either need to (1) pass in a decimal[], List<decimal>, or other structure that will always return the same sequence, or (2) ensure you only iterate the IEnumerable<decimal> once.
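To make the problem concrete, here is a small sketch of a sequence that yields different values on each enumeration, along with option (1) in its simplest form: taking a snapshot with ToList() before doing anything else. The `RepeatableDemo` class and its `Unstable` and `Snapshot` members are made up for this illustration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class RepeatableDemo
{
    private static int enumerations = 0;

    // Hypothetical source: the iterator body re-runs on every enumeration,
    // so each pass shifts the values by one more than the previous pass.
    public static IEnumerable<decimal> Unstable()
    {
        enumerations++;
        yield return 10 + enumerations;
        yield return 20 + enumerations;
    }

    // Option (1): enumerate once into a list, then reuse the list freely.
    public static List<decimal> Snapshot(IEnumerable<decimal> values)
    {
        return values.ToList();
    }
}
```

Calling First() twice on the result of Unstable() gives two different numbers, whereas calling it twice on the list returned by Snapshot gives the same number both times.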

    Here's how to do (2):

    public static (decimal min, decimal max) FindTerminalValues(IEnumerable<decimal> values)
    {
        // Dispose the enumerator when done; the sequence is enumerated exactly once.
        using (var e = values.GetEnumerator())
        {
            if (!e.MoveNext())
                throw new System.ArgumentException("no values");

            // Seed the running sum, minimum, and maximum with the first element.
            decimal value = e.Current;
            var terminal = (min: value, max: value);

            while (e.MoveNext())
            {
                value += e.Current;
                terminal = (System.Math.Min(value, terminal.min), System.Math.Max(value, terminal.max));
            }

            return terminal;
        }
    }
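Since the question asked about LINQ, the same single-pass idea can also be sketched with nothing but System.Linq's Aggregate, by carrying the running sum in the accumulator alongside the minimum and maximum. This is one possible shape under a few assumptions: the class name `CumulativeExtremes` and the `any` flag (used to detect an empty sequence) are made up for the example.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class CumulativeExtremes
{
    // Folds the sequence once, carrying (running sum, min so far, max so far, seen anything?).
    public static (decimal min, decimal max) FindTerminalValues(IEnumerable<decimal> values)
    {
        var acc = values.Aggregate(
            // Seeding with MaxValue/MinValue means the first running sum
            // always replaces both the minimum and the maximum.
            (sum: 0m, min: decimal.MaxValue, max: decimal.MinValue, any: false),
            (a, x) =>
            {
                decimal sum = a.sum + x;
                return (sum, Math.Min(a.min, sum), Math.Max(a.max, sum), true);
            });

        if (!acc.any)
            throw new ArgumentException("no values", nameof(values));

        return (acc.min, acc.max);
    }
}
```

With the sequence from the question, { 10, 20, 20, -20, -50, 10 }, this gives (-20, 50), and for { 10, 20 } it gives (10, 30). The trade-off against the Scan version is that the accumulator does two jobs at once, which is a little harder to read but only touches the sequence once.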