c# functional-programming memoization

Implementing memoization in C#


I know memoization has been discussed quite a bit already, but no answer I could find satisfies the DRY principle as much as I would like, so please read the whole question and the three points I want to address.

I have a simple support class like this:

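using System;
using System.Collections.Generic;

public class Memoized<T1, TResult>
{
    // The wrapped (uncached) function.
    private readonly Func<T1, TResult> _f;

    // Results computed so far, keyed by argument. Not thread-safe.
    private readonly Dictionary<T1, TResult> _cache = new Dictionary<T1, TResult>();

    public Memoized(Func<T1, TResult> f)
    {
        _f = f;
    }

    public TResult Invoke(T1 p1)
    {
        if (p1 == null) throw new ArgumentNullException(nameof(p1));

        // Return the cached result if this argument was seen before.
        if (_cache.TryGetValue(p1, out var res)) return res;

        // Otherwise compute, store, and return it.
        return _cache[p1] = _f(p1);
    }

    // Convenience factory: wraps f and returns a plain delegate.
    public static Func<T1, TResult> Of(Func<T1, TResult> f)
    {
        var memo = new Memoized<T1, TResult>(f);

        return x => memo.Invoke(x);
    }
}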

Nothing especially fancy, but it allows me to do this:

public class MyClass
{
    public Func<int, bool> MemoizedMethod { get; }

    private bool UncachedMethod(int v)
    {
        return v > 0;
    }

    public MyClass()
    {
        MemoizedMethod = Memoized<int, bool>.Of(UncachedMethod);
    }
}

Now, even if the resulting code is not terribly noisy, I'm trying to figure out if the implementation could be more elegant, because currently I need:

  1. an invocable property that acts as a method.
  2. a true method that should not be called directly.
  3. a line in the constructor (it cannot be an inline initialization) that links the two (with a third repetition of the function signature!).

Any suggestion for a strategy that would remove one (or two!) of the above points would be great.


Solution

  • In my struggle for elegance, I finally found what I think is the best syntax I have seen anywhere:

    private class MemoizedTest
    {
        private int _counter = 0;
    
        public int Method(int p) => this.Memoized(p, x =>
        {
            return _counter += x;
        });
    }
    

    Implementation (one pretty small extension class):

    using System.Collections.Concurrent;
    using System.Runtime.CompilerServices;

    namespace System
    {
        public static class MemoizerExtension
        {
            // One entry per object that uses memoization; the table holds its keys weakly.
            internal static readonly ConditionalWeakTable<object, ConcurrentDictionary<string, object>> _weakCache =
                new ConditionalWeakTable<object, ConcurrentDictionary<string, object>>();

            public static TResult Memoized<T1, TResult>(
                this object context,
                T1 arg,
                Func<T1, TResult> f,
                [CallerMemberName] string? cacheKey = null)
                where T1 : notnull
            {
                if (context == null) throw new ArgumentNullException(nameof(context));
                if (cacheKey == null) throw new ArgumentNullException(nameof(cacheKey));

                // First lookup: the per-object cache.
                var objCache = _weakCache.GetOrCreateValue(context);

                // Second lookup: the per-method cache (keyed by caller name by default).
                var methodCache = (ConcurrentDictionary<T1, TResult>) objCache
                    .GetOrAdd(cacheKey, _ => new ConcurrentDictionary<T1, TResult>());

                // Third lookup: the argument. Under contention GetOrAdd may call f more than once.
                return methodCache.GetOrAdd(arg, f);
            }
        }
    }
    

    Explanation: the implementation uses a ConditionalWeakTable for caching, which effectively extends the internal state of the object that invokes the memoization. The CallerMemberName acts as a second key; this allows several memoized methods per instance and, if the cacheKey parameter is passed explicitly, several memoizations per method. The third key is the argument of the invocation.

    So we have three dictionary-like lookups at runtime instead of one, but a syntax that is a lot cleaner, IMO.

    Is it worth it? I dunno, but my desire for elegance is satiated.
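
To illustrate the explicit cacheKey option mentioned in the explanation, here is a minimal sketch. The Geometry class and the key strings "Describe.area" and "Describe.perimeter" are hypothetical names chosen for illustration. The extension is repeated (outside the System namespace, with a private field) only so that the sketch compiles on its own.

```csharp
#nullable enable
using System;
using System.Collections.Concurrent;
using System.Runtime.CompilerServices;

public static class MemoizerExtension
{
    private static readonly ConditionalWeakTable<object, ConcurrentDictionary<string, object>> WeakCache =
        new ConditionalWeakTable<object, ConcurrentDictionary<string, object>>();

    public static TResult Memoized<T1, TResult>(
        this object context,
        T1 arg,
        Func<T1, TResult> f,
        [CallerMemberName] string? cacheKey = null)
        where T1 : notnull
    {
        if (context == null) throw new ArgumentNullException(nameof(context));
        if (cacheKey == null) throw new ArgumentNullException(nameof(cacheKey));

        var objCache = WeakCache.GetOrCreateValue(context);
        var methodCache = (ConcurrentDictionary<T1, TResult>)objCache
            .GetOrAdd(cacheKey, _ => new ConcurrentDictionary<T1, TResult>());
        return methodCache.GetOrAdd(arg, f);
    }
}

// Hypothetical class, for illustration only.
public class Geometry
{
    // Counts how many times the uncached lambdas actually ran.
    public int Evaluations { get; private set; }

    public double Describe(int side)
    {
        // Two memoizations inside one method. With the default key both calls
        // would share the "Describe" cache, so explicit keys keep them apart.
        var area = this.Memoized(
            side, s => { Evaluations++; return (double)s * s; }, "Describe.area");
        var perimeter = this.Memoized(
            side, s => { Evaluations++; return 4.0 * s; }, "Describe.perimeter");
        return area + perimeter;
    }
}
```

Calling Describe(3) twice on the same Geometry instance should run each lambda only once. Without the explicit keys, the second Memoized call would hit the entry stored by the first and return the area again. If the two lambdas had different result types, the cast to ConcurrentDictionary<T1, TResult> would fail with an InvalidCastException, so distinct keys matter whenever a method memoizes more than one function.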

    If someone else is interested, I'm including the tests for reference:

    using System;
    using NUnit.Framework;

    [TestFixture]
    public class MemoizerTest
    {
        [Test]
        public void MemoizationWorksOnFuncs()
        {
            int counter = 0;
    
            Func<int, int> f = x => counter += x;
    
            Assert.That(this.Memoized(1, f), Is.EqualTo(1));
    
            Assert.That(this.Memoized(2, f), Is.EqualTo(3));
    
            Assert.That(this.Memoized(2, f), Is.EqualTo(3));
    
            Assert.That(this.Memoized(1, f), Is.EqualTo(1));
        }
    
        private class MemoizedTest
        {
            private int _counter = 0;
    
            public int Method(int p)
                => this.Memoized(p, x => { return _counter += x; });
        }
    
        [Test]
        public void MemoizationWorksOnInstances()
        {
            var obj1 = new MemoizedTest();
    
            Assert.That(obj1.Method(5), Is.EqualTo(5));
            Assert.That(obj1.Method(4), Is.EqualTo(9));
            Assert.That(obj1.Method(5), Is.EqualTo(5));
            Assert.That(obj1.Method(1), Is.EqualTo(10));
            Assert.That(obj1.Method(4), Is.EqualTo(9));
    
            obj1 = new MemoizedTest();
    
            Assert.That(obj1.Method(5), Is.EqualTo(5));
            Assert.That(obj1.Method(4), Is.EqualTo(9));
            Assert.That(obj1.Method(5), Is.EqualTo(5));
            Assert.That(obj1.Method(1), Is.EqualTo(10));
            Assert.That(obj1.Method(4), Is.EqualTo(9));
        }
    
        [Test]
        [Ignore("This test passes only when compiled in Release mode")]
        public void WeakMemoizationCacheIsCleared()
        {
            var obj1 = new MemoizedTest();
    
            var r1 = obj1.Method(5);
    
            MemoizerExtension._weakCache.TryGetValue(obj1, out var cache);
    
            var weakRefToCache = new WeakReference(cache);
    
            cache = null;
            GC.Collect(2);
    
            obj1 = null;
    
            GC.Collect();
            GC.Collect();
    
            Assert.That(weakRefToCache.IsAlive, Is.False, "The weak reference should be dead.");
    
            Assert.That(r1, Is.EqualTo(5));
        }
    }