Search code examples
Tags: c#, async-await, task, synchronizationcontext

Use the SynchronizationContext only when a Task has RanToCompletion


The code below illustrates the idea:

    /// <summary>
    /// Loads data with a 3-second timeout and shows it on <c>button1</c>.
    /// Only the successful path resumes on the UI thread; failures are handled
    /// without marshalling back to it.
    /// </summary>
    private async void button1_Click(object sender, EventArgs e)
    {
        string result;

        // The source owns a timer that requests cancellation after 3000 ms.
        // Dispose it once the operation has finished so the timer is released.
        using (var timeoutSource = new CancellationTokenSource(3000))
        {
            try
            {
                result = await GetDataAsync(timeoutSource.Token)
                    .ContextIfSuccess();    // Resume on the SynchronizationContext only if the task RanToCompletion
            }
            catch (OperationCanceledException)
            {
                /* Context is not required */
                return;
            }
            catch (Exception ex)
            {
                /* Context is not required otherwise it can slow down UI Thread a little bit.
                   NOTE: this may run on a thread-pool thread, so Log must be thread-safe. */
                Log(ex.ToString());
                return;
            }
        }

        /* UI Thread only */
        button1.Text = result;
    }

The question is: is it possible to write a method like `ContextIfSuccess()`?


Solution

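  • To get the method you want, you'll need to create a custom awaiter. It's mostly boilerplate. The key is what happens when the awaiter is asked to register a continuation: it arranges for the continuation to run on the current synchronization context if the task runs to completion, and on the default (thread-pool) scheduler if it does not. (If the task has already completed when it is awaited, the compiler never calls `OnCompleted`, and the code after `await` simply continues on the current thread.)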

    /// <summary>
    /// Awaitable (and its own awaiter) for a <see cref="Task"/>. When the task is
    /// still pending at the await, the code after the await resumes on the awaiting
    /// thread's <see cref="SynchronizationContext"/> only if the task runs to
    /// completion. If the task faults or is canceled, or there is no context, the
    /// code resumes on the thread pool instead.
    /// </summary>
    /// <remarks>
    /// If the task has already completed when awaited, <see cref="IsCompleted"/> is
    /// true, <see cref="OnCompleted"/> is never called, and execution simply continues
    /// on the current thread, whatever the task's outcome.
    /// </remarks>
    public struct CaptureContextOnSuccessAwaiter : INotifyCompletion
    {
        private readonly Task task;

        /// <summary>Wraps <paramref name="task"/>.</summary>
        /// <param name="task">The task to await.</param>
        /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
        public CaptureContextOnSuccessAwaiter(Task task)
        {
            if (task == null)
            {
                throw new ArgumentNullException("task");
            }

            this.task = task;
        }

        /// <summary>Returns this instance; the struct is its own awaiter.</summary>
        public CaptureContextOnSuccessAwaiter GetAwaiter() { return this; }

        /// <summary>True once the wrapped task has finished, in any state.</summary>
        public bool IsCompleted { get { return task.IsCompleted; } }

        /// <summary>
        /// Schedules <paramref name="continuation"/> to run after the task finishes.
        /// It is posted to the context that is current now if the task ran to
        /// completion; otherwise it is queued to the thread pool.
        /// </summary>
        /// <param name="continuation">The code to run after the await.</param>
        /// <exception cref="ArgumentNullException"><paramref name="continuation"/> is null.</exception>
        public void OnCompleted(Action continuation)
        {
            if (continuation == null)
            {
                throw new ArgumentNullException("continuation");
            }

            // Read the context now, on the awaiting thread (it is per-thread state).
            SynchronizationContext context = SynchronizationContext.Current;

            // A single continuation picks the destination once the outcome is known.
            // Registering two mutually exclusive continuations would always leave one
            // of them canceled, allocating a task that never runs. ExecuteSynchronously
            // is safe here because this step only posts or queues; it never runs
            // the caller's code inline.
            task.ContinueWith(
                completed =>
                {
                    if (context != null && completed.Status == TaskStatus.RanToCompletion)
                    {
                        context.Post(state => ((Action)state)(), continuation);
                    }
                    else
                    {
                        ThreadPool.QueueUserWorkItem(state => ((Action)state)(), continuation);
                    }
                },
                CancellationToken.None,
                TaskContinuationOptions.ExecuteSynchronously,
                TaskScheduler.Default);
        }

        /// <summary>
        /// Completes the await: rethrows the task's exception, or an
        /// <see cref="OperationCanceledException"/> if it was canceled.
        /// </summary>
        public void GetResult() { task.GetAwaiter().GetResult(); }
    }
    
    /// <summary>
    /// Awaitable (and its own awaiter) for a <see cref="Task{TResult}"/>. When the
    /// task is still pending at the await, the code after the await resumes on the
    /// awaiting thread's <see cref="SynchronizationContext"/> only if the task runs to
    /// completion. If the task faults or is canceled, or there is no context, the
    /// code resumes on the thread pool instead.
    /// </summary>
    /// <remarks>
    /// If the task has already completed when awaited, <see cref="IsCompleted"/> is
    /// true, <see cref="OnCompleted"/> is never called, and execution simply continues
    /// on the current thread, whatever the task's outcome.
    /// </remarks>
    /// <typeparam name="T">The task's result type.</typeparam>
    public struct CaptureContextOnSuccessAwaiter<T> : INotifyCompletion
    {
        private readonly Task<T> task;

        /// <summary>Wraps <paramref name="task"/>.</summary>
        /// <param name="task">The task to await.</param>
        /// <exception cref="ArgumentNullException"><paramref name="task"/> is null.</exception>
        public CaptureContextOnSuccessAwaiter(Task<T> task)
        {
            if (task == null)
            {
                throw new ArgumentNullException("task");
            }

            this.task = task;
        }

        /// <summary>Returns this instance; the struct is its own awaiter.</summary>
        public CaptureContextOnSuccessAwaiter<T> GetAwaiter() { return this; }

        /// <summary>True once the wrapped task has finished, in any state.</summary>
        public bool IsCompleted { get { return task.IsCompleted; } }

        /// <summary>
        /// Schedules <paramref name="continuation"/> to run after the task finishes.
        /// It is posted to the context that is current now if the task ran to
        /// completion; otherwise it is queued to the thread pool.
        /// </summary>
        /// <param name="continuation">The code to run after the await.</param>
        /// <exception cref="ArgumentNullException"><paramref name="continuation"/> is null.</exception>
        public void OnCompleted(Action continuation)
        {
            if (continuation == null)
            {
                throw new ArgumentNullException("continuation");
            }

            // Read the context now, on the awaiting thread (it is per-thread state).
            SynchronizationContext context = SynchronizationContext.Current;

            // A single continuation picks the destination once the outcome is known.
            // Registering two mutually exclusive continuations would always leave one
            // of them canceled, allocating a task that never runs. ExecuteSynchronously
            // is safe here because this step only posts or queues; it never runs
            // the caller's code inline.
            task.ContinueWith(
                completed =>
                {
                    if (context != null && completed.Status == TaskStatus.RanToCompletion)
                    {
                        context.Post(state => ((Action)state)(), continuation);
                    }
                    else
                    {
                        ThreadPool.QueueUserWorkItem(state => ((Action)state)(), continuation);
                    }
                },
                CancellationToken.None,
                TaskContinuationOptions.ExecuteSynchronously,
                TaskScheduler.Default);
        }

        /// <summary>
        /// Completes the await: returns the task's result, rethrows its exception, or
        /// throws an <see cref="OperationCanceledException"/> if it was canceled.
        /// </summary>
        public T GetResult() { return task.GetAwaiter().GetResult(); }
    }
    
    /// <summary>
    /// Makes <paramref name="task"/> awaitable such that, when the task is still
    /// pending at the await, the code after the await resumes on the caller's
    /// SynchronizationContext only if the task runs to completion, and on the
    /// thread pool otherwise.
    /// </summary>
    /// <param name="task">The task to await.</param>
    /// <returns>An awaitable wrapper around <paramref name="task"/>.</returns>
    public static CaptureContextOnSuccessAwaiter ContextIfSuccess(this Task task)
    {
        var awaitable = new CaptureContextOnSuccessAwaiter(task);
        return awaitable;
    }
    
    /// <summary>
    /// Makes <paramref name="task"/> awaitable such that, when the task is still
    /// pending at the await, the code after the await resumes on the caller's
    /// SynchronizationContext only if the task runs to completion, and on the
    /// thread pool otherwise. Awaiting it yields the task's result.
    /// </summary>
    /// <typeparam name="T">The task's result type.</typeparam>
    /// <param name="task">The task to await.</param>
    /// <returns>An awaitable wrapper around <paramref name="task"/>.</returns>
    public static CaptureContextOnSuccessAwaiter<T> ContextIfSuccess<T>(this Task<T> task)
    {
        var awaitable = new CaptureContextOnSuccessAwaiter<T>(task);
        return awaitable;
    }