Tags: c#, unit-testing, async-await, mstest

Unit testing async void (fire-and-forget) methods


I am writing tests in MSTest for a WPF application that contains calls to async void methods (fire-and-forget pattern).

My goal is to run this code synchronously during tests.

Is this possible?

I have tried with a "synchronous" SynchronizationContext:

public class SynchronousSynchronizationContext : SynchronizationContext
{
    private readonly ConcurrentQueue<(SendOrPostCallback, object)> _workItems = new ConcurrentQueue<(SendOrPostCallback, object)>();
    private readonly AutoResetEvent _workItemsAvailable = new AutoResetEvent(false);


    public override void Post(SendOrPostCallback d, object state)
    {
        Trace.WriteLine("SynchronousSynchronizationContext.Post(), start");

        // Execute the callback immediately
        d(state);

        Trace.WriteLine("SynchronousSynchronizationContext.Post(), end");
    }

    /* enqueues the work item to be completed later:
    public override void Post(SendOrPostCallback d, object state)
    {
        Trace.WriteLine("SynchronousSynchronizationContext.Post()");

        _workItems.Enqueue((d, state));
        _workItemsAvailable.Set();
    }
    */

    public override void Send(SendOrPostCallback d, object state)
    {
        Trace.WriteLine("SynchronousSynchronizationContext.Send()");

        d(state);
    }

    public void Run()
    {
        Trace.WriteLine("SynchronousSynchronizationContext.Run()");

        while (_workItems.TryDequeue(out var workItem))
        {
            workItem.Item1(workItem.Item2);
        }
    }

    public void Complete()
    {
        Trace.WriteLine("SynchronousSynchronizationContext.Complete()");

        while (_workItems.TryDequeue(out var workItem))
        {
            Trace.WriteLine("SynchronousSynchronizationContext, execute callback");

            // execute callback.
            // The workItem is a tuple where Item1 is the callback and Item2 is the state object.
            workItem.Item1(workItem.Item2);
        }
    }
}

But it does not work. The Post() method does not run the callback synchronously like I want.

Instead I have to call Complete() explicitly at various points in my test method:

private SynchronousSynchronizationContext _syncContext;

[TestInitialize]
public void TestInitialize()
{    
    _syncContext = new SynchronousSynchronizationContext();
    SynchronizationContext.SetSynchronizationContext(_syncContext);
}

[TestMethod]
public void MyTest_ReturnsData()
{
    // ACT
    MyTest();

    // Finish all fire-and-forget tasks:
    _syncContext.Complete();

    // ASSERT
    ...

}

Is there a way to run the code synchronously without having to call Complete()?


Solution

  • We can rely on the fact that the async void state machine calls SynchronizationContext.OperationStarted and SynchronizationContext.OperationCompleted on the current context; this lets us track the async operation and catch its exceptions. As Servy commented, this is somewhat tricky because an async void method can itself call other async void methods in its body, but I think it works if we simply reuse the same SynchronizationContext for all of them and maintain a count of active operations.
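
    To make the counting idea concrete before the full implementation, here is a minimal sketch. CountingSyncContext and CountingDemo are hypothetical names used only for this illustration; the context just counts the OperationStarted and OperationCompleted notifications and runs posted continuations inline. It ignores exceptions entirely, which the real implementation has to handle.

    ```csharp
    using System;
    using System.Threading;
    using System.Threading.Tasks;

    // Hypothetical helper (not part of the implementation in this answer):
    // counts the notifications sent by async void methods to the current context.
    public sealed class CountingSyncContext : SynchronizationContext
    {
        private int _started, _completed, _active;
        private readonly ManualResetEventSlim _idle = new ManualResetEventSlim(false);

        public int Started => Volatile.Read(ref _started);
        public int Completed => Volatile.Read(ref _completed);

        // Called when an async void method begins executing.
        public override void OperationStarted()
        {
            Interlocked.Increment(ref _started);
            Interlocked.Increment(ref _active);
        }

        // Called when an async void method has run to its end.
        public override void OperationCompleted()
        {
            Interlocked.Increment(ref _completed);
            if (Interlocked.Decrement(ref _active) == 0) _idle.Set();
        }

        // Run continuations inline with this context installed, so that nested
        // async void methods started from a continuation are counted as well.
        public override void Post(SendOrPostCallback d, object state)
        {
            var previous = SynchronizationContext.Current;
            SynchronizationContext.SetSynchronizationContext(this);
            try { d(state); }
            finally { SynchronizationContext.SetSynchronizationContext(previous); }
        }

        public bool WaitUntilIdle(int timeoutMs) => _idle.Wait(timeoutMs);
    }

    public static class CountingDemo
    {
        static async void Inner() { await Task.Delay(20); }

        static async void Outer()
        {
            Inner();               // nested async void, started before the first await
            await Task.Delay(50);
            Inner();               // nested async void, started from a continuation
        }

        public static (int Started, int Completed, bool Idle) Run()
        {
            var previous = SynchronizationContext.Current;
            var ctx = new CountingSyncContext();
            SynchronizationContext.SetSynchronizationContext(ctx);
            try { Outer(); }
            finally { SynchronizationContext.SetSynchronizationContext(previous); }

            // Blocks until the count of active operations drops back to zero.
            bool idle = ctx.WaitUntilIdle(5000);
            return (ctx.Started, ctx.Completed, idle);
        }
    }
    ```

    Calling CountingDemo.Run() should report three started and three completed operations: Outer plus the two Inner calls. The active count cannot reach zero early here, because Outer stays active until after it has started the second Inner.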

    To simplify usage, we wrap everything in a static, Task-returning method that is called like this:

    await AsyncVoidRunner.RunAsync(AsyncVoidMethod);
    

    This waits for the main async void method (and all nested ones) to complete, captures all exceptions, and rethrows either the single exception with its original stack trace or an AggregateException containing all of them.
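
    The rethrow behavior can be sketched in isolation. RethrowSketch and its members are hypothetical names for this illustration only; the sketch uses ExceptionDispatchInfo to keep the original stack trace when a single exception is rethrown, and wraps several exceptions in an AggregateException.

    ```csharp
    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Runtime.CompilerServices;
    using System.Runtime.ExceptionServices;

    public static class RethrowSketch
    {
        // Captures whatever the action throws, together with its stack trace.
        // Returns null if the action does not throw.
        public static ExceptionDispatchInfo CaptureFrom(Action action)
        {
            try { action(); return null; }
            catch (Exception ex) { return ExceptionDispatchInfo.Capture(ex); }
        }

        // Rethrows a single captured exception as-is, or wraps several
        // captured exceptions in one AggregateException.
        public static void Rethrow(IReadOnlyList<ExceptionDispatchInfo> captured)
        {
            if (captured.Count == 0) return;
            if (captured.Count == 1) captured[0].Throw(); // keeps the original stack trace
            throw new AggregateException(captured.Select(e => e.SourceException));
        }

        // Sample failing method; NoInlining keeps its frame visible in the stack trace.
        [MethodImpl(MethodImplOptions.NoInlining)]
        public static void Fail(string message) => throw new InvalidOperationException(message);
    }
    ```

    With a plain `throw ex;` the stack trace would restart at the rethrow site; going through ExceptionDispatchInfo.Throw() keeps the frames of the method that originally failed.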

    It is also desirable to be able to specify a SynchronizationContext, such as the default WPF one or a custom one: we wrap it but otherwise let it do its job.

    This is a possible implementation:

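    using System;
    using System.Collections.Generic;
    using System.Linq;
    using System.Runtime.ExceptionServices;
    using System.Threading;
    using System.Threading.Tasks;
    
    public static class AsyncVoidRunner {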
    
        public static async Task RunAsync(
           Action asyncVoidMethod,
           bool continuationsOnCurrentContext = true) {
    
            var capturedContext = SynchronizationContext.Current;
            var asyncSyncContext = new AsyncVoidSyncContext(continuationsOnCurrentContext ?
            capturedContext : null);
    
            // we have to set it before we call the asyncVoidMethod
            SynchronizationContext.SetSynchronizationContext(asyncSyncContext);
    
            try {
                asyncVoidMethod();
    
                // block a thread-pool thread instead of the caller
                // to avoid possible deadlocks
                await Task.Run(() => asyncSyncContext.WaitForCompletion());
            } finally {
                SynchronizationContext.SetSynchronizationContext(capturedContext);
            }
        }
    
        public static async Task RunAsync(
        Action asyncVoidMethod,
        SynchronizationContext syncContext) {
    
            ArgumentNullException.ThrowIfNull(syncContext);
    
            var capturedContext = SynchronizationContext.Current;
            var asyncSyncContext = new AsyncVoidSyncContext(syncContext);
    
            SynchronizationContext.SetSynchronizationContext(asyncSyncContext);
    
            try {
                asyncVoidMethod();
                // fake async to avoid possible dead-locks
                // TODO: possible rewrite of WaitForCompletion to be async
                await Task.Run(() => asyncSyncContext.WaitForCompletion());
            } finally {
                SynchronizationContext.SetSynchronizationContext(capturedContext);
            }
        }
    
        private sealed class AsyncVoidSyncContext : SynchronizationContext {
    
    
            List<ExceptionDispatchInfo> _exceptionsInfos = new List<ExceptionDispatchInfo>();
            SynchronizationContext _mainContext;
            
            ManualResetEventSlim _operationStartedCalled = new ManualResetEventSlim(false);
            int _activeOps;
            ManualResetEventSlim _lastOperationCompleted = new ManualResetEventSlim(false);
            int _pendingPosts;
            ManualResetEventSlim _noPendingPosts = new ManualResetEventSlim(true);
    
            public AsyncVoidSyncContext(SynchronizationContext syncContext) {
                _mainContext = syncContext;
            }
    
            // async void state machine is calling this when it starts execution
            public override void OperationStarted() {
                Console.WriteLine("Operation Started on thread: " + Thread.CurrentThread.ManagedThreadId);
    
                Interlocked.Increment(ref _activeOps);
    
                // Remember that an async void method has started, so that
                // WaitForCompletion() can detect a non-async-void delegate
                // instead of blocking forever
                if (!_operationStartedCalled.IsSet) {
                    _operationStartedCalled.Set();
                }
            }
    
            // async void state machine is calling this when finished
            public override void OperationCompleted() {
                Console.WriteLine("OperationCompleted on thread: " + Thread.CurrentThread.ManagedThreadId);
    
                if (Interlocked.Decrement(ref _activeOps) == 0) {
                
                    // At this point we could still have an outstanding Post
                    // that throws an exception: SetException calls ThrowAsync,
                    // which is a Post. Because that Post returns immediately,
                    // OperationCompleted is called even though we haven't had
                    // the chance to process everything.
                    // -> we maintain a counter and an event for outstanding posts,
                    // which we check in WaitForCompletion
                    _lastOperationCompleted.Set();
                }
            }
    
            public void WaitForCompletion() {
                // Short wait for initiation: if OperationStarted has not been
                // called by now, the delegate is not an async void method
                if (!_operationStartedCalled.Wait(10)) {
                    throw new InvalidOperationException("Not an async void method");
                    // TODO: or just return?
                }
    
                _lastOperationCompleted.Wait();
                // The state machine posts the exceptions asynchronously
                // and marks the operation as completed before we have a chance
                // to process all posts (relevant if we use another SyncContext)
                _noPendingPosts.Wait();
                
                switch (_exceptionsInfos.Count) {
                    case 0: return;
                    case 1:
                        // TODO: just throw ??
                        _exceptionsInfos[0].Throw();
                        break;
                    case > 1:
                        var aggregateException = new AggregateException(
                        _exceptionsInfos.Select(e => e.SourceException));
                        throw aggregateException;
                }
            }
    
            // just dispatching to PostInternal (see comment there)
            public override void Post(SendOrPostCallback d, object state) {
                Interlocked.Increment(ref _pendingPosts);
                _noPendingPosts.Reset();
                
                if (_mainContext == null) {
                    // we execute main logic in-line
                    PostInternal(state, d, this);
                } else {
                    // more transparent capture (instead of an inline lambda):
                    // we let the main context execute the main logic,
                    // possibly on another thread (e.g. the UI thread)
                    _mainContext.Post(this.PostCaptureHelper, (state, d));
                }
            }
    
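            void PostCaptureHelper(object valueTuple) {
                // unpack the (state, callback) tuple created in Post;
                // `var s` also matches a null state, which `object s` would not
                switch (valueTuple) {
                    case (var s, SendOrPostCallback d):
                        PostInternal(s, d, this);
                        return;
                    default: throw new InvalidOperationException();
                }
            }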
    
            // We may be called on a thread-pool thread when coming back
            // from an async operation (e.g. Task.Delay). That thread has no
            // SynchronizationContext, so we must set it to this context
            // before running the continuation. Otherwise async methods
            // called from this point forward would have nothing to capture:
            /*  await Task.Delay(100);
                // we are here with no SyncContext to capture
                // if we don't do that explicitly
                asyncMethod();
            */
            static void PostInternal(object state, SendOrPostCallback d,
              AsyncVoidSyncContext asyncVoidContext) {
    
                var captureContext = SynchronizationContext.Current;
                try {
                    SynchronizationContext.SetSynchronizationContext(asyncVoidContext);
                    d(state);
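                } catch (Exception ex) {
                    // posts can run concurrently on different threads,
                    // and List<T> is not thread-safe
                    lock (asyncVoidContext._exceptionsInfos) {
                        asyncVoidContext._exceptionsInfos.Add(ExceptionDispatchInfo.Capture(ex));
                    }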
                } finally {
                    SynchronizationContext.SetSynchronizationContext(captureContext);
                    if (Interlocked.Decrement(ref asyncVoidContext._pendingPosts) == 0) {
                        asyncVoidContext._noPendingPosts.Set();
                    }
                }
            }
        }
    }
    

    To test it, we can use the following methods:

    async void AsyncVoidMethod() {
        NestedAsyncVoid();
        Console.WriteLine("AsyncVoidMethod on thread: " + Thread.CurrentThread.ManagedThreadId);
        await Task.Delay(250);
        Console.WriteLine("AsyncVoidMethod on thread: " + Thread.CurrentThread.ManagedThreadId);
        NestedAsyncVoid(); 
        throw new Exception("---AsyncVoidMethod EXCEPTION"); 
    }
    
    async void NestedAsyncVoid() {
        await Task.Delay(100);
        MoreNestedAsyncVoid(); 
        await Task.Delay(100);
        Console.WriteLine("NestedAsyncVoid on thread: " + Thread.CurrentThread.ManagedThreadId);
        throw new Exception("----NestedAsyncVoid EXCEPTION"); 
    }
    
    async void MoreNestedAsyncVoid() {
        await Task.Delay(200);
        Console.WriteLine("MoreNestedAsyncVoid on thread: " + Thread.CurrentThread.ManagedThreadId);
        throw new Exception("-----MoreNestedAsyncVoid EXCEPTION");
    }
    

    The two obvious scenarios are running with and without another context. To simulate the first, I've picked the SingleThreadSynchronizationContext demonstrated in a blog post by Stephen Toub:

    async Task SingleThreadSyncContextTest() {
        var capturedContext = SynchronizationContext.Current;
        try {
            var syncCtx = new SingleThreadSynchronizationContext();
            SynchronizationContext.SetSynchronizationContext(syncCtx);
    
            var t = AsyncVoidRunner.RunAsync(AsyncVoidMethod);
            // stop the message pump once the runner has finished
            var continueTask = t.ContinueWith(
                _ => {
                    syncCtx.Complete();
                }, TaskScheduler.Default);
    
            syncCtx.RunOnCurrentThread();
            await t;
    
        } catch (Exception ex) {
            if (ex is AggregateException aggEx) {
                foreach (var element in aggEx.InnerExceptions) {
                    Console.WriteLine(element.Message);
                }
                return;
            }
            Console.WriteLine(ex.Message);
        } finally {
            SynchronizationContext.SetSynchronizationContext(capturedContext);
        }
    }
    
    async Task NoSyncContext() {
        try {
            SynchronizationContext.SetSynchronizationContext(null);
            await AsyncVoidRunner.RunAsync(AsyncVoidMethod, true);
        } catch (Exception ex) {
            if (ex is AggregateException aggEx) {
                foreach (var element in aggEx.InnerExceptions) {
                    Console.WriteLine(element.Message);
                }
                return;
            }
            Console.WriteLine(ex.Message);
        }
    }
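
    The SingleThreadSynchronizationContext used in the first test is not shown here. The following is a minimal sketch modeled on the class from Stephen Toub's post, not a verbatim copy: a BlockingCollection queues the posted callbacks, RunOnCurrentThread() pumps them on the calling thread, and Complete() ends the pump. PumpDemo is a hypothetical helper added only to exercise it.

    ```csharp
    using System;
    using System.Collections.Concurrent;
    using System.Collections.Generic;
    using System.Threading;

    public sealed class SingleThreadSynchronizationContext : SynchronizationContext
    {
        private readonly BlockingCollection<KeyValuePair<SendOrPostCallback, object>> _queue =
            new BlockingCollection<KeyValuePair<SendOrPostCallback, object>>();

        // Queue the callback; it runs later on the pumping thread.
        public override void Post(SendOrPostCallback d, object state)
        {
            _queue.Add(new KeyValuePair<SendOrPostCallback, object>(d, state));
        }

        public override void Send(SendOrPostCallback d, object state)
        {
            throw new NotSupportedException("Synchronously sending is not supported.");
        }

        // Runs queued callbacks on the calling thread until Complete() has
        // been called and the queue is empty.
        public void RunOnCurrentThread()
        {
            foreach (var workItem in _queue.GetConsumingEnumerable())
            {
                workItem.Key(workItem.Value);
            }
        }

        // Signals that no more callbacks will be posted.
        public void Complete()
        {
            _queue.CompleteAdding();
        }
    }

    public static class PumpDemo
    {
        public static List<int> Run()
        {
            var ctx = new SingleThreadSynchronizationContext();
            var order = new List<int>();
            ctx.Post(_ => order.Add(1), null);
            ctx.Post(_ => order.Add(2), null);
            ctx.Complete();            // nothing more to post
            ctx.RunOnCurrentThread();  // drains the two queued callbacks, then returns
            return order;
        }
    }
    ```

    Because every callback runs inside RunOnCurrentThread(), all continuations end up on the one thread that pumps the queue, which is consistent with the first block of results, where every line reports the same thread.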
    

    Finally:

    await SingleThreadSyncContextTest();
    Console.WriteLine();
    await NoSyncContext();
    

    and the results:

    Operation Started on thread: 1
    Operation Started on thread: 1
    AsyncVoidMethod on thread: 1
    Operation Started on thread: 1
    NestedAsyncVoid on thread: 1
    OperationCompleted on thread: 1
    AsyncVoidMethod on thread: 1
    Operation Started on thread: 1
    OperationCompleted on thread: 1
    MoreNestedAsyncVoid on thread: 1
    OperationCompleted on thread: 1
    Operation Started on thread: 1
    NestedAsyncVoid on thread: 1
    OperationCompleted on thread: 1
    MoreNestedAsyncVoid on thread: 1
    OperationCompleted on thread: 1
    ----NestedAsyncVoid EXCEPTION
    ---AsyncVoidMethod EXCEPTION
    -----MoreNestedAsyncVoid EXCEPTION
    ----NestedAsyncVoid EXCEPTION
    -----MoreNestedAsyncVoid EXCEPTION
    
    Operation Started on thread: 1
    Operation Started on thread: 1
    AsyncVoidMethod on thread: 1
    Operation Started on thread: 25
    NestedAsyncVoid on thread: 25
    OperationCompleted on thread: 25
    AsyncVoidMethod on thread: 15
    Operation Started on thread: 15
    OperationCompleted on thread: 15
    MoreNestedAsyncVoid on thread: 25
    OperationCompleted on thread: 25
    Operation Started on thread: 25
    NestedAsyncVoid on thread: 27
    OperationCompleted on thread: 27
    MoreNestedAsyncVoid on thread: 25
    OperationCompleted on thread: 25
    ----NestedAsyncVoid EXCEPTION
    ---AsyncVoidMethod EXCEPTION
    -----MoreNestedAsyncVoid EXCEPTION
    ----NestedAsyncVoid EXCEPTION
    -----MoreNestedAsyncVoid EXCEPTION