Search code examples
Tags: c#, concurrency, concurrentdictionary

Update value in ConcurrentDictionary by condition


I have the following class:

/// <summary>
/// A <see cref="ConcurrentDictionary{TKey, TValue}"/> wrapper that stores each value behind a
/// <see cref="Lazy{T}"/>, so that when several threads race on <see cref="GetOrAdd"/> for the
/// same key only one of them runs the value factory.
/// </summary>
public class LazyConcurrentDictionary<TKey, TValue>
{
    private readonly ConcurrentDictionary<TKey, Lazy<TValue>> _concurrentDictionary;

    public LazyConcurrentDictionary()
    {
        _concurrentDictionary = new ConcurrentDictionary<TKey, Lazy<TValue>>();
    }

    /// <summary>
    /// Returns the value stored for <paramref name="key"/>, creating it with
    /// <paramref name="valueFactory"/> if the key is not present yet.
    /// </summary>
    /// <param name="key">The key to look up or add.</param>
    /// <param name="valueFactory">Produces the value for a missing key.</param>
    /// <returns>The existing or newly created value.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="valueFactory"/> is null.</exception>
    /// <remarks>
    /// If the factory throws, the exception propagates to every caller waiting on that entry, and
    /// the entry is evicted. The next call for the key therefore retries the factory instead of
    /// rethrowing the same cached exception forever.
    /// </remarks>
    public TValue GetOrAdd(TKey key, Func<TKey, TValue> valueFactory)
    {
        if (valueFactory == null)
        {
            throw new ArgumentNullException(nameof(valueFactory));
        }

        var lazyResult = _concurrentDictionary.GetOrAdd(key,
            k => new Lazy<TValue>(() => valueFactory(k), LazyThreadSafetyMode.ExecutionAndPublication));
        try
        {
            return lazyResult.Value;
        }
        catch
        {
            // ExecutionAndPublication caches a factory exception inside the Lazy<T>. Remove the
            // faulted entry (only if it is still this exact instance) so a later call can retry.
            _concurrentDictionary.TryRemove(System.Collections.Generic.KeyValuePair.Create(key, lazyResult));
            throw;
        }
    }
}

I want to use this class to manage tokens. A token expires after 2 hours.

// Each key maps to a (token, issued-at) pair. LazyConcurrentDictionary already wraps every value
// in a Lazy<T>, so the factory returns the tuple directly; a second Lazy would only force callers
// to unwrap result.Value again.
var dictionary = new LazyConcurrentDictionary<string, Tuple<string, DateTime>>();

// UtcNow rather than Now: expiry comparisons stay correct across time-zone/DST changes.
var result = dictionary.GetOrAdd("80", k => Tuple.Create(Guid.NewGuid().ToString(), DateTime.UtcNow));

For example, key "80" maps to a tuple containing a token and the current time.
How do I change the code so that a new token is obtained once the current one has expired? Best regards


Solution

  • Getting an access token usually involves a network call, so asynchronous programming is a good fit here. Let's define a delegate that asynchronously returns an access token for a given key:

    /// <summary>An access token together with the instant at which it stops being valid.</summary>
    /// <param name="TokenValue">The token string.</param>
    /// <param name="ExpiresAt">Expiry instant; AccessTokenCache compares it with DateTime.UtcNow, so it should be UTC.</param>
    public record AccessToken(string TokenValue, DateTime ExpiresAt);

    /// <summary>Asynchronously obtains a fresh <see cref="AccessToken"/> for <paramref name="key"/>.</summary>
    public delegate Task<AccessToken> TokenFactoryAsyncFunction(string key, CancellationToken cancellationToken);
    

    Then you can use a concurrent dictionary to cache the tokens. However, you cannot store the tokens themselves in the dictionary, because the factory is asynchronous: you need to store Task<AccessToken> instead. Furthermore, ConcurrentDictionary.GetOrAdd may invoke the value factory several times when threads race on the same key, so to avoid a cache stampede you need Lazy values. In the end, you store Lazy<Task<AccessToken>> values in the dictionary.

    With this in mind, you can use the following AccessTokenCache class, which caches the tokens and periodically removes expired ones:

    /// <summary>
    /// Caches access tokens per key. Concurrent requests for the same key share one in-flight call
    /// to the token factory. A cached token whose <see cref="AccessToken.ExpiresAt"/> has passed is
    /// replaced on the next request, and an optional background loop evicts expired tokens at a
    /// fixed interval.
    /// </summary>
    /// <remarks>
    /// <see cref="AccessToken.ExpiresAt"/> is compared with <see cref="DateTime.UtcNow"/>, so the
    /// token factory should return UTC timestamps.
    /// </remarks>
    public sealed class AccessTokenCache : IAsyncDisposable
    {
        private readonly TokenFactoryAsyncFunction tokenFactory;
        private readonly ConcurrentDictionary<string, Lazy<Task<AccessToken>>> tokenCache = new();
        private readonly TimeSpan expiredTokensRemovalCheckInterval;
        private readonly Task? expiredTokensRemovalTask;

        // Cancelled by DisposeAsync: stops the cleanup loop and any token fetch still in flight.
        private readonly CancellationTokenSource cts;

        /// <summary>Creates the cache and, if requested, starts the expired-token cleanup loop.</summary>
        /// <param name="tokenFactory">Fetches a fresh token for a key.</param>
        /// <param name="expiredTokensRemovalCheckInterval">
        /// Delay between cleanup passes, or <see cref="Timeout.InfiniteTimeSpan"/> to disable the loop.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="tokenFactory"/> is null.</exception>
        /// <exception cref="ArgumentException">
        /// <paramref name="expiredTokensRemovalCheckInterval"/> is neither positive nor infinite.
        /// </exception>
        public AccessTokenCache(TokenFactoryAsyncFunction tokenFactory, TimeSpan expiredTokensRemovalCheckInterval)
        {
            if (tokenFactory == null)
            {
                throw new ArgumentNullException(nameof(tokenFactory));
            }
            if (expiredTokensRemovalCheckInterval != Timeout.InfiniteTimeSpan && expiredTokensRemovalCheckInterval <= TimeSpan.Zero)
            {
                throw new ArgumentException($"invalid value for {nameof(expiredTokensRemovalCheckInterval)}", nameof(expiredTokensRemovalCheckInterval));
            }
            this.tokenFactory = tokenFactory;
            this.cts = new CancellationTokenSource();
            this.expiredTokensRemovalCheckInterval = expiredTokensRemovalCheckInterval;
            if (expiredTokensRemovalCheckInterval > TimeSpan.Zero)
            {
                this.expiredTokensRemovalTask = this.RemoveExpiredTokensAtIntervals(this.cts.Token);
            }
        }

        // Background loop: every interval, evicts completed entries whose token has expired.
        private async Task RemoveExpiredTokensAtIntervals(CancellationToken cancellationToken)
        {
            while (!cancellationToken.IsCancellationRequested)
            {
                try
                {
                    await Task.Delay(expiredTokensRemovalCheckInterval, cancellationToken);
                }
                catch (OperationCanceledException)
                {
                    return;
                }
                foreach (var kv in this.tokenCache)
                {
                    if (cancellationToken.IsCancellationRequested) return;

                    // Only inspect entries whose fetch has already finished. Checking
                    // IsValueCreated first keeps this loop from starting a fetch merely by
                    // reading Lazy.Value, and we never wait on a non-completed task here.
                    if (!kv.Value.IsValueCreated || !kv.Value.Value.IsCompleted) continue;
                    AccessToken accessToken;
                    try
                    {
                        accessToken = await GetAccessTokenRemovingWhenFaultedAsync(kv.Key, kv.Value, CancellationToken.None);
                    }
                    catch
                    {
                        // Faulted fetch: the helper already evicted it; not rethrowing intentionally.
                        continue;
                    }
                    if (accessToken.ExpiresAt <= DateTime.UtcNow)
                    {
                        // Compare-and-remove: leaves the entry alone if it was refreshed meanwhile.
                        this.tokenCache.TryRemove(kv);
                    }
                }
            }
        }

        // Creates a not-yet-started fetch for `key`. The fetch is shared by every caller waiting on
        // this key, so it is tied only to the cache's own token (cancelled on dispose), never to one
        // caller's CancellationToken. The async lambda captures any synchronous exception from the
        // factory into the returned task instead of letting Lazy<T> cache it.
        private Lazy<Task<AccessToken>> CreateLazyAccessTokenTask(string key)
        {
            return new Lazy<Task<AccessToken>>(async () => await this.tokenFactory(key, this.cts.Token));
        }

        private Lazy<Task<AccessToken>> GetOrAdd(string key)
        {
            return this.tokenCache.GetOrAdd(key, k => CreateLazyAccessTokenTask(k));
        }

        // Awaits the (possibly shared) fetch. `cancellationToken` only abandons this caller's wait;
        // it does not cancel the fetch for other callers. The entry is evicted only when the fetch
        // itself faulted or was cancelled. Otherwise every later request for this key would fail too.
        private async Task<AccessToken> GetAccessTokenRemovingWhenFaultedAsync(
            string key, Lazy<Task<AccessToken>> lazyAccessTokenTask, CancellationToken cancellationToken)
        {
            var accessTokenTask = lazyAccessTokenTask.Value;
            try
            {
                return await accessTokenTask.WaitAsync(cancellationToken);
            }
            catch when (accessTokenTask.IsFaulted || accessTokenTask.IsCanceled)
            {
                // Compare-and-remove so that a newer entry for the same key is not evicted.
                this.tokenCache.TryRemove(KeyValuePair.Create(key, lazyAccessTokenTask));
                throw;
            }
        }

        /// <summary>
        /// Returns a cached token for <paramref name="key"/>, fetching a new one if none is cached
        /// or the cached one has expired.
        /// </summary>
        /// <param name="key">Cache key passed to the token factory.</param>
        /// <param name="cancellationToken">
        /// Cancels this caller's wait only; a fetch shared with other callers keeps running.
        /// </param>
        /// <exception cref="ObjectDisposedException">The cache has been disposed.</exception>
        public async Task<AccessToken> GetAccessTokenAsync(string key, CancellationToken cancellationToken)
        {
            if (this.IsDisposed) throw new ObjectDisposedException(nameof(AccessTokenCache));
            var lazyAccessTokenTask = this.GetOrAdd(key);
            var accessToken = await GetAccessTokenRemovingWhenFaultedAsync(key, lazyAccessTokenTask, cancellationToken);
            if (accessToken.ExpiresAt <= DateTime.UtcNow)
            {
                var newLazyAccessTokenTask = CreateLazyAccessTokenTask(key);
                // Several threads can see this token as expired at the same time and try to
                // update it concurrently; TryUpdate compares against the exact Lazy instance we
                // read, so only one succeeds. The others pick up whatever entry is now cached.
                lazyAccessTokenTask = this.tokenCache.TryUpdate(key, newLazyAccessTokenTask, lazyAccessTokenTask)
                    ? newLazyAccessTokenTask
                    : this.GetOrAdd(key);
                accessToken = await GetAccessTokenRemovingWhenFaultedAsync(key, lazyAccessTokenTask, cancellationToken);
            }
            return accessToken;
        }

        /// <summary>True once <see cref="DisposeAsync"/> has been called.</summary>
        public bool IsDisposed { get; private set; }

        /// <summary>
        /// Cancels the cleanup loop and in-flight fetches, waits for the loop to finish, and
        /// releases the cancellation source. Calling it again after the first call does nothing.
        /// </summary>
        public async ValueTask DisposeAsync()
        {
            if (IsDisposed) return;
            IsDisposed = true;
            using (this.cts)
            {
                this.cts.Cancel();
                if (this.expiredTokensRemovalTask != null)
                {
                    await this.expiredTokensRemovalTask;
                }
            }
        }
    }