Search code examples
c#.nettaskpollyretry-logic

Retry pattern swallows the exceptions in a Task.Run


I've got two Transient Fault Handling/Retry pattern implementations.

The issue is that theTask.Run swallows the exception and it doesn't rethrow it out of the Task.Run scope.

If I await the Task.Run it would work, but I cannot do that in my real use case.

public static class PollyRetry
{
    public static T Do<T>(Func<T> action, TimeSpan retryWait, int retryCount = 0)
    {
        var policyResult = Policy
            .Handle<Exception>()
            .WaitAndRetry(retryCount, retryAttempt => retryWait)
            .ExecuteAndCapture(action);

        if (policyResult.Outcome == OutcomeType.Failure)
        {
            throw policyResult.FinalException;
        }

        return policyResult.Result;
    }

    public static async Task<T> DoAsync<T>(Func<Task<T>> action, TimeSpan retryWait, int retryCount = 0)
    {
        var policyResult = await Policy
            .Handle<Exception>()
            .WaitAndRetryAsync(retryCount, retryAttempt => retryWait)
            .ExecuteAndCaptureAsync(action);

        if (policyResult.Outcome == OutcomeType.Failure)
        {
            throw policyResult.FinalException;
        }

        return policyResult.Result;
    }
}

public static class Retry
{
    public static void Do(Action action, TimeSpan retryInterval, int retryCount = 3)
    {
        Do<object?>(() =>
        {
            action();
            return null;
        }, retryInterval, retryCount);
    }

    public static async Task DoAsync(Func<Task> action, TimeSpan retryInterval, int retryCount = 3)
    {
        await DoAsync<object?>(async () =>
        {
            await action();
            return null;
        }, retryInterval, retryCount);
    }

    public static T Do<T>(Func<T> action, TimeSpan retryInterval, int retryCount = 3)
    {
        var exceptions = new List<Exception>();

        for (var count = 1; count <= retryCount; count++)
        {
            try
            {
                return action();
            }
            catch (Exception ex)
            {
                exceptions.Add(ex);
                if (count < retryCount)
                {
                    Thread.Sleep(retryInterval);
                }
            }
        }

        throw new AggregateException(exceptions);
    }

    public static async Task<T> DoAsync<T>(Func<Task<T>> func, TimeSpan retryInterval, int retryCount = 3)
    {
        var exceptions = new List<Exception>();

        for (var count = 1; count <= retryCount; count++)
        {
            try
            {
                return await func();
            }
            catch (Exception ex)
            {
                exceptions.Add(ex);
                if (count < retryCount)
                {
                    await Task.Delay(retryInterval);
                }
            }
        }

        throw new AggregateException(exceptions);
    }
}

public sealed class WebSocketClient
{
    private readonly Channel<string> _receiveChannel;
    private readonly Channel<string> _sendChannel;

    public WebSocketClient()
    {
        _receiveChannel = Channel.CreateBounded<string>(new BoundedChannelOptions(10)
        {
            SingleWriter = true,
            SingleReader = false,
            FullMode = BoundedChannelFullMode.DropOldest
        });

        _sendChannel = Channel.CreateBounded<string>(new BoundedChannelOptions(10)
        {
            SingleReader = true,
            SingleWriter = false,
            FullMode = BoundedChannelFullMode.Wait
        });
    }

    public async Task StartWithRetry(Uri uri)
    {
        await Retry.DoAsync(() => Task.FromResult(StartAsync(uri)), TimeSpan.FromSeconds(5), 5);
    }

    public async Task StartAsync(Uri uri)
    {
        using var ws = new ClientWebSocket();
        await ws.ConnectAsync(uri, default);

        if (ws.State == WebSocketState.Open)
        {
            const string message = "{\"op\": \"subscribe\", \"args\": [\"orderBookL2_25:XBTUSD\"]}";
            var buffer = new ArraySegment<byte>(Encoding.UTF8.GetBytes(message));
            await ws.SendAsync(buffer, WebSocketMessageType.Text, true, default);
        }

        _ = Task.Run(async () =>
        {
            while (await _receiveChannel.Reader.WaitToReadAsync())
            {
                while (_receiveChannel.Reader.TryRead(out var message))
                {
                    Console.WriteLine($"Message: {message}");
                }
            }
        });

        _ = Task.Run(async () =>
        {
            // This throws WebSocketException with ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely
            while (true)
            {
                ValueWebSocketReceiveResult receiveResult;

                using var buffer = MemoryPool<byte>.Shared.Rent(4096);
                await using var ms = new MemoryStream(buffer.Memory.Length);
                do
                {
                    receiveResult = await ws.ReceiveAsync(buffer.Memory, default);

                    if (receiveResult.MessageType == WebSocketMessageType.Close)
                    {
                        break;
                    }

                    await ms.WriteAsync(buffer.Memory[..receiveResult.Count]);
                } while (!receiveResult.EndOfMessage);

                ms.Seek(0, SeekOrigin.Begin);

                if (receiveResult.MessageType == WebSocketMessageType.Text)
                {
                    using var reader = new StreamReader(ms, Encoding.UTF8);
                    var message = await reader.ReadToEndAsync();

                    await _receiveChannel.Writer.WriteAsync(message);
                }
                else if (receiveResult.MessageType == WebSocketMessageType.Close)
                {
                    break;
                }
            }
        });
    }
}

Minimal Reproducible Example

var code = new MinimalReproducibleCode();
await code.StartWithRetry();

public sealed class MinimalReproducibleCode
{
    public async Task StartWithRetry()
    {
        await Retry.DoAsync(() => Task.FromResult(StartAsync()), TimeSpan.FromSeconds(5), 5);
    }
    
    public Task StartAsync()
    {
        Console.WriteLine("This has just started");
        
        _ = Task.Run(() =>
        {
            while (true)
            {
                Console.WriteLine("Code is working");

                throw new DivideByZeroException();
            }
        });
        
        return Task.CompletedTask;
    }
}

public static class Retry
{
    public static void Do(Action action, TimeSpan retryInterval, int retryCount = 3)
    {
        _ = Do<object?>(() =>
        {
            action();
            return null;
        }, retryInterval, retryCount);
    }

    public static async Task DoAsync(Func<Task> action, TimeSpan retryInterval, int retryCount = 3)
    {
        _ = await DoAsync<object?>(async () =>
        {
            await action();
            return null;
        }, retryInterval, retryCount);
    }

    public static async Task DoAsync<TException>(
        Func<Task> action,
        Func<TException, bool> exceptionFilter,
        TimeSpan retryInterval,
        int retryCount = 3) where TException : Exception
    {
        _ = await DoAsync<object?>(async () =>
        {
            await action();
            return null;
        }, retryInterval, retryCount);
    }

    public static T Do<T>(Func<T> action, TimeSpan retryWait, int retryCount = 3)
    {
        var policyResult = Policy
            .Handle<Exception>()
            .WaitAndRetry(retryCount, retryAttempt => retryWait)
            .ExecuteAndCapture(action);

        if (policyResult.Outcome == OutcomeType.Failure)
        {
            throw policyResult.FinalException;
        }

        return policyResult.Result;
    }

    public static async Task<T> DoAsync<T>(Func<Task<T>> action, TimeSpan retryWait, int retryCount = 3)
    {
        var policyResult = await Policy
            .Handle<Exception>()
            .WaitAndRetryAsync(retryCount, retryAttempt => retryWait)
            .ExecuteAndCaptureAsync(action);

        if (policyResult.Outcome == OutcomeType.Failure)
        {
            throw policyResult.FinalException;
        }

        return policyResult.Result;
    }

    public static async Task<T> DoAsync<T, TException>(
        Func<Task<T>> action,
        Func<TException, bool> exceptionFilter,
        TimeSpan retryWait,
        int retryCount = 0) where TException : Exception
    {
        var policyResult = await Policy
            .Handle(exceptionFilter)
            .WaitAndRetryAsync(retryCount, retryAttempt => retryWait)
            .ExecuteAndCaptureAsync(action);

        if (policyResult.Outcome == OutcomeType.Failure)
        {
            throw policyResult.FinalException;
        }

        return policyResult.Result;
    }
}


Solution

  • OK, based on your code, here's how to make it work:

    public async Task StartWithRetry()
    {
        await Retry.DoAsync(() => StartAsync(), TimeSpan.FromSeconds(5), 5);
    }
    
    public async Task StartAsync()
    {
        Console.WriteLine("This has just started");
    
        await Task.Run(() =>
        {
            while (true)
            {
                Console.WriteLine("Code is working");
    
                throw new DivideByZeroException();
            }
        });
    }
    

    You need to await the Task.Run and not fire-and-forget it.

    When I run the above code I get:

    This has just started
    Code is working
    This has just started
    Code is working
    This has just started
    Code is working
    This has just started
    Code is working
    This has just started
    Code is working
    This has just started
    Code is working
    DivideByZeroException
    Attempted to divide by zero.