
Refactoring IEnumerable<T> yield


I have the following code that is used to create simple graphs containing links between nodes

public class Node(string name)
{
    public string Name { get; } = name;
}

public class Link(Node from, Node to)
{
    public Node From { get; } = from;
    public Node To { get; } = to;
}

public class Graph
{
    public List<Node> Nodes = new();
    public List<Link> Links = new();

    public IEnumerable<Node> GetParents(Node node)
    {
        foreach (var link in Links)
        {
            if (link.To == node)
                yield return link.From;
        }
    }

    public IEnumerable<Node> GetChildren(Node node)
    {
        foreach (var link in Links)
        {
            if (link.From == node)
                yield return link.To;
        }
    }
}

Because there are many nodes and links and the graph is highly dynamic, I chose not to store each node's parents and children on the Node instance itself (a node can have many parents and many children). Instead, the Graph class has GetParents and GetChildren helper methods.

However, I have since learnt that every call to either of these methods allocates, because of the backing code the compiler generates for IEnumerable<T>/yield.

As shown in the following test:

[MemoryDiagnoser]
public class Test
{
    Graph _graph;

    [IterationSetup]
    public void Setup()
    {
        _graph = new Graph();

        Node a = new("A"), b = new("B"), c = new("C");

        _graph.Nodes = [a, b, c];
        _graph.Links = [new Link(a, b), new Link(a, c)];
    }

    [Benchmark]
    public void Run1()
    {
        _graph.GetParents(_graph.Nodes[1]);
    }

    [Benchmark]
    public void Run2()
    {
        _graph.GetParents(_graph.Nodes[1]);
        _graph.GetParents(_graph.Nodes[1]);
    }

    [Benchmark]
    public void RunLoop()
    {
        for (int i = 0; i < 1000; ++i)
            _graph.GetParents(_graph.Nodes[1]);
    }

    [Benchmark]
    public void GetChildren_Raw()
    {
        int count = 0;

        for (int i = 0; i < 1000; ++i)
        {
            foreach (var link in _graph.Links)
            {
                if (link.From == _graph.Nodes[1])
                {
                    count++;
                }
            }
        }
    }
}

| Method          | Mean        | Error       | StdDev       | Median      | Allocated |
|---------------- |------------:|------------:|-------------:|------------:|----------:|
| Run1            |    673.6 ns |    24.84 ns |     69.66 ns |    700.0 ns |     480 B |
| Run2            |    740.0 ns |    28.49 ns |     81.74 ns |    700.0 ns |     560 B |
| RunLoop         | 31,390.9 ns | 4,295.45 ns | 12,597.80 ns | 35,600.0 ns |   80400 B |
| GetChildren_Raw | 14,674.2 ns |   455.90 ns |  1,322.66 ns | 14,100.0 ns |     400 B |
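
The per-call allocation can also be observed without BenchmarkDotNet. The following is only a minimal sketch: `IteratorAllocationDemo` and `Evens` are hypothetical stand-ins with the same shape as GetParents/GetChildren (a filtering iterator method), and the measured byte count will vary by runtime. It shows that merely calling an iterator method, without enumerating it, returns a new instance of a compiler-generated class.

```csharp
using System;
using System.Collections.Generic;

public static class IteratorAllocationDemo
{
    // Hypothetical stand-in for GetParents/GetChildren: a filtering iterator method.
    public static IEnumerable<int> Evens(List<int> source)
    {
        foreach (var x in source)
            if (x % 2 == 0)
                yield return x;
    }

    public static void Main()
    {
        var source = new List<int> { 1, 2, 3, 4 };
        var keep = new IEnumerable<int>[1000]; // allocated before measuring

        Evens(source); // warm up so one-time initialization is not counted

        long before = GC.GetAllocatedBytesForCurrentThread();
        for (int i = 0; i < keep.Length; i++)
            keep[i] = Evens(source); // called but never enumerated
        long allocated = GC.GetAllocatedBytesForCurrentThread() - before;

        // The returned object is an instance of a compiler-generated class.
        Console.WriteLine($"Is class: {keep[0].GetType().IsClass}");
        Console.WriteLine($"Approx. {allocated / keep.Length} B per call");
    }
}
```

Storing the results in a pre-allocated array keeps each returned object reachable, so the figure reflects the iterator objects themselves rather than anything the loop allocates.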

Other than storing the parents and children on each node (which I do not want to do because it's a very dynamic graph), how could I implement these methods without allocations?


Solution

  • List<T> avoids allocations when used in a foreach statement by having its GetEnumerator() method return a struct List<T>.Enumerator that implements IEnumerator<T>. From the reference source for List<T>:

    public class List<T> : IList<T>, IList, IReadOnlyList<T>
    {
        public Enumerator GetEnumerator()
            => new Enumerator(this);
    
        IEnumerator<T> IEnumerable<T>.GetEnumerator()
            => new Enumerator(this);
    
        IEnumerator IEnumerable.GetEnumerator()
            => new Enumerator(this);
    
        public struct Enumerator : IEnumerator<T>, IEnumerator
        {
            // Implementation omitted
        }

        // Remaining members omitted
    }

    Because the IEnumerable<T>.GetEnumerator() and IEnumerable.GetEnumerator() methods are implemented explicitly, the public GetEnumerator() method is the one visible on a List<T> variable. The C# compiler binds foreach to that public method and works with the returned struct directly, so the enumerator is never boxed.
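
To make the difference concrete, here is a small sketch that enumerates the same list once through its concrete type and once through IEnumerable<int>, and reports the bytes allocated by each. The names `EnumeratorBoxingDemo`, `SumDirect`, `SumViaInterface` and `MeasureAllocated` are made up for illustration; the exact figures depend on the runtime version and on which JIT optimizations apply, so no output is claimed here.

```csharp
using System;
using System.Collections.Generic;

public static class EnumeratorBoxingDemo
{
    // foreach binds to the public List<int>.GetEnumerator() and uses the
    // struct List<int>.Enumerator directly.
    public static int SumDirect(List<int> list)
    {
        int sum = 0;
        foreach (var x in list) sum += x;
        return sum;
    }

    // foreach calls IEnumerable<int>.GetEnumerator(), which returns
    // IEnumerator<int>, so the struct enumerator has to be boxed.
    public static int SumViaInterface(IEnumerable<int> items)
    {
        int sum = 0;
        foreach (var x in items) sum += x;
        return sum;
    }

    // Returns the bytes allocated on the current thread while running action.
    public static long MeasureAllocated(Action action)
    {
        long before = GC.GetAllocatedBytesForCurrentThread();
        action();
        return GC.GetAllocatedBytesForCurrentThread() - before;
    }

    public static void Main()
    {
        var list = new List<int> { 1, 2, 3 };
        Action direct = () => SumDirect(list);
        Action viaInterface = () => SumViaInterface(list);

        // Warm up so one-time initialization is not counted.
        direct();
        viaInterface();

        Console.WriteLine($"direct:        {MeasureAllocated(direct)} B");
        Console.WriteLine($"via interface: {MeasureAllocated(viaInterface)} B");
    }
}
```

Both methods compute the same sum; only the static type of the variable being enumerated differs, which is what decides whether foreach sees the struct or the interface.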

    You can adopt the same approach by wrapping List<Link>.Enumerator in your own enumerator struct with the necessary filtering logic. First introduce the following enumerable and enumerator inside Graph:

    // Requires: using System; using System.Collections; using System.Collections.Generic;
    public partial class Graph
    {
        public struct NodeEnumerable : IEnumerable<Node>
        {
            readonly Graph graph;
            readonly Node node;
            readonly Func<Link, Node> getFilterNode;
            readonly Func<Link, Node> getReturnNode;
            
            internal NodeEnumerable(Graph graph, Node node, Func<Link, Node> getFilterNode, Func<Link, Node> getReturnNode) =>
                (this.graph, this.node, this.getFilterNode, this.getReturnNode) = (graph, node, getFilterNode, getReturnNode);
            
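            // Pass on the delegates this enumerable was constructed with, so that parents and children are filtered differently
            public NodeEnumerator GetEnumerator() => new NodeEnumerator(graph.Links.GetEnumerator(), node, getFilterNode, getReturnNode);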
            IEnumerator<Node> IEnumerable<Node>.GetEnumerator() => GetEnumerator();
            IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
    
            public struct NodeEnumerator : IEnumerator<Node>, IEnumerator
            {
                List<Link>.Enumerator enumerator;
                readonly Node node;
                readonly Func<Link, Node> getFilterNode;
                readonly Func<Link, Node> getReturnNode;
    
                internal NodeEnumerator(List<Link>.Enumerator enumerator, Node node, Func<Link, Node> getFilterNode, Func<Link, Node> getReturnNode) => 
                    (this.enumerator, this.node, this.getFilterNode, this.getReturnNode) = (enumerator, node, getFilterNode, getReturnNode);
    
                public void Dispose() => enumerator.Dispose();
                public Node Current => getReturnNode(enumerator.Current);
                void IEnumerator.Reset() => enumerator.Reset();
    
                object? IEnumerator.Current => Current;
    
                public bool MoveNext()
                {
                    while (enumerator.MoveNext())
                        if (getFilterNode(enumerator.Current) == node)
                            return true;
                    return false;
                }
            }
        }
    }
    
    public static class EnumeratorExtensions
    {
        // Use to avoid boxing enumerator structs that implement Reset explicitly
        public static void Reset<TEnumerator>(ref this TEnumerator enumerator) where TEnumerator : struct, IEnumerator => 
            enumerator.Reset();
    }
    

    Then modify GetParents() and GetChildren() as follows:

    public partial class Graph
    {
        public List<Node> Nodes = new();
        public List<Link> Links = new();
    
        public NodeEnumerable GetParents(Node node) => new NodeEnumerable(this, node, GetTo, GetFrom);
    
        // Return the concrete struct type; returning IEnumerable<Node> would box it
        public NodeEnumerable GetChildren(Node node) => new NodeEnumerable(this, node, GetFrom, GetTo);
        
        // Make these delegates static to avoid allocations
        static readonly Func<Link, Node> GetFrom = static link => link.From;
        static readonly Func<Link, Node> GetTo = static link => link.To;
    }
    

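    With these changes, foreach loops over the parents and children no longer allocate memory (as run in .NET 8). Every method in the results below reports the same 400 B, including GetChildren_Raw, which never calls GetParents() or GetChildren(), so that baseline does not come from the enumerables: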

    Environment version: .NET 8.0.8 (8.0.8), Microsoft Windows NT 10.0.19045.0
    
    | Method          | Mean        | Error     | StdDev      | Median      | Allocated |
    |---------------- |------------:|----------:|------------:|------------:|----------:|
    | Run1            |    200.0 ns |   0.00 ns |     0.00 ns |    200.0 ns |     400 B |
    | Run2            |    282.3 ns |  13.30 ns |    38.37 ns |    300.0 ns |     400 B |
    | RunLoop         | 25,766.3 ns | 710.75 ns | 2,004.68 ns | 25,100.0 ns |     400 B |
    | GetChildren_Raw | 18,692.0 ns | 606.99 ns | 1,661.62 ns | 18,500.0 ns |     400 B |
    

    Notes:

    • The optimization of having List<T>.GetEnumerator() return a mutable struct, which the C# compiler binds to directly in foreach, has been around since .NET Framework 2.0.

    • If a List<T> is upcast to an interface such as IList<T>, the optimization is lost and the list enumerator is boxed.

    • LINQ, however, avoids this boxing: it checks whether an incoming enumerable is actually a list and, if so, wraps it in a specialized iterator that holds the list enumerator in a field of its concrete struct type; see e.g. Enumerable.Where(). If you frequently pass the enumerables returned by GetParents() and GetChildren() into LINQ expressions, you may want to do something similar.

    • I used a single struct for both GetParents() and GetChildren() by passing in static delegates for the From and To logic. For absolutely maximal performance, if you wanted to avoid the delegate callback, you could write a separate enumerator for each.

    • For another performance tweak, you may want to investigate whether avoiding use of the readonly keyword for fields results in a speedup. For why this might be true see Micro-optimization: the surprising inefficiency of readonly fields by Jon Skeet.
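
The note about separate enumerators could look roughly like the sketch below. It is a sketch under assumptions, not a drop-in replacement: `ParentEnumerable` and `ParentEnumerator` are hypothetical names, only the parents direction is shown, and the structs support foreach by pattern only (they do not implement IEnumerable<Node>, so they cannot be passed to LINQ). The filter is written inline instead of going through a delegate.

```csharp
using System;
using System.Collections.Generic;

public sealed class Node(string name) { public string Name { get; } = name; }

public sealed class Link(Node from, Node to)
{
    public Node From { get; } = from;
    public Node To { get; } = to;
}

public sealed class Graph
{
    public List<Link> Links = new();

    public ParentEnumerable GetParents(Node node) => new ParentEnumerable(Links, node);

    // Hypothetical foreach-only enumerable dedicated to the "parents" direction.
    public readonly struct ParentEnumerable
    {
        readonly List<Link> links;
        readonly Node node;

        internal ParentEnumerable(List<Link> links, Node node)
        {
            this.links = links;
            this.node = node;
        }

        public ParentEnumerator GetEnumerator() => new ParentEnumerator(links.GetEnumerator(), node);
    }

    public struct ParentEnumerator
    {
        List<Link>.Enumerator inner; // mutable struct, so the field must not be readonly
        readonly Node node;

        internal ParentEnumerator(List<Link>.Enumerator inner, Node node)
        {
            this.inner = inner;
            this.node = node;
        }

        public Node Current => inner.Current.From;

        public bool MoveNext()
        {
            while (inner.MoveNext())
                if (inner.Current.To == node) // filter written inline, no delegate call
                    return true;
            return false;
        }
    }
}

public static class Program
{
    static int CountParents(Graph graph, Node node)
    {
        int count = 0;
        foreach (var parent in graph.GetParents(node)) count++;
        return count;
    }

    public static void Main()
    {
        Node a = new("A"), b = new("B"), c = new("C");
        var graph = new Graph();
        graph.Links.Add(new Link(a, b));
        graph.Links.Add(new Link(a, c));

        CountParents(graph, b); // warm up
        long before = GC.GetAllocatedBytesForCurrentThread();
        int parents = CountParents(graph, b);
        long allocated = GC.GetAllocatedBytesForCurrentThread() - before;
        Console.WriteLine($"{parents} parent(s), {allocated} B allocated");
    }
}
```

A matching children enumerator would swap From and To. Whether dropping the delegate gives a measurable gain is something to benchmark on your own data rather than assume.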