c# .net async-await ado.net task

Is it possible to not open a nested ExecutionContext in an async method in dotnet?


I have a class similar to this:

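// Named differently from the property below: a C# member cannot share its enclosing type's name.
class AmbientTransactionContext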
{
    public AsyncLocal<IDbTransaction> AmbientTransaction { get; } = new();

    public IDbTransaction BeginAndSetAmbientTransaction(IDbConnection connection)
    {
        var tx = connection.BeginTransaction();
        AmbientTransaction.Value = tx;
        return tx;
    }
}

The intent is that a transaction is started and stored as the ambient transaction.

This works fine, but I want to add the following method:

// This method does NOT work as intended.
// Note: BeginTransactionAsync is defined on DbConnection, not on IDbConnection.
public async Task<IDbTransaction> BeginAndSetAmbientTransactionAsync(DbConnection connection)
{
    var tx = await connection.BeginTransactionAsync();
    AmbientTransaction.Value = tx;
    return tx;
}

Unfortunately, this does not work: because the method is async, it gets a nested ExecutionContext, and any changes to that context (including my AsyncLocal) are not preserved when the method returns.

Is it possible to somehow have an async method which preserves changes to its ExecutionContext when it returns? Or is it possible to achieve what I want in any other way?


Solution

  • As the answer to the related question says:

    No, not with AsyncLocal. async methods set their value context to "copy-on-write", so if it's written to, a copy will be created. And the copies never "flow" upwards.

    The whole async infrastructure enforces this behavior (more about it in this blog post by Stephen Toub).
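
    To make the copy-on-write behavior concrete, here is a minimal sketch. The class and method names (AsyncLocalFlowDemo, SetInAsyncMethod, SetInNonAsyncMethod) are made up for illustration and are not from the question. A value written inside an async method is not seen by the caller, while a value written inside a non-async method that merely returns a Task is.

    ```csharp
    using System;
    using System.Threading;
    using System.Threading.Tasks;

    public static class AsyncLocalFlowDemo
    {
        private static readonly AsyncLocal<int> Local = new();

        // async method: the write happens on a copy of the context
        // that is discarded when the method returns to its caller.
        private static async Task SetInAsyncMethod()
        {
            Local.Value = 1;
            await Task.Yield();
        }

        // Non-async method returning a Task: no state machine is involved,
        // so the write lands on the caller's context.
        private static Task SetInNonAsyncMethod()
        {
            Local.Value = 2;
            return Task.CompletedTask;
        }

        public static async Task<(int AfterAsync, int AfterNonAsync)> RunAsync()
        {
            Local.Value = 0;

            await SetInAsyncMethod();
            int afterAsync = Local.Value;      // still 0

            await SetInNonAsyncMethod();
            int afterNonAsync = Local.Value;   // now 2

            return (afterAsync, afterNonAsync);
        }

        public static async Task Main()
        {
            var (afterAsync, afterNonAsync) = await RunAsync();
            Console.WriteLine($"{afterAsync} {afterNonAsync}"); // prints "0 2"
        }
    }
    ```

    The second case is what option one below relies on.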

    So, in my opinion, we have two options, neither of which is great.

    Option one is to avoid the async method infrastructure by making BeginAndSetAmbientTransactionAsync still return Task<IDbTransaction> but run synchronously in its entirety. The usual caveats about blocking on asynchronous code apply:

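    public Task<IDbTransaction> BeginAndSetAmbientTransactionAsync(DbConnection connection) {
        // Declared outside the try block so it is still in scope at the final return.
        IDbTransaction tx;
        try {
            // We block here, so the AsyncLocal is written on the caller's ExecutionContext.
            // AsTask() is used because a ValueTask should not be blocked on directly.
            tx = connection.BeginTransactionAsync().AsTask().GetAwaiter().GetResult();
            AmbientTransaction.Value = tx;
        } catch (Exception ex) {
            return Task.FromException<IDbTransaction>(ex);
        }
        return Task.FromResult(tx);
    }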
    

    Option two is to plug into the async method infrastructure and change its behavior, but I am not sure you can do so with any confidence.

    Here is a proof of concept of how you could achieve this: a custom async method builder, ECFlowBackMethodBuilder<T>, that wraps the built-in AsyncTaskMethodBuilder<T> and changes its default behavior. Starting with .NET 6, using a custom method builder is a bit easier, because you only need to apply the AsyncMethodBuilder attribute to tell the compiler to use it for a particular async method.

    The first place to change is the Start method, which covers the synchronous case where our state machine returns immediately. I've mostly reused the code from AsyncMethodBuilderCore.Start, but without the part that restores the caller's ExecutionContext.
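
    To isolate what the Start change does on its own, here is a stripped-down sketch. NoRestoreMethodBuilder<T> and StartDemo are hypothetical names for this illustration; the builder forwards everything to AsyncTaskMethodBuilder<T> and only changes Start. It covers just the case where the async method completes synchronously; the asynchronous case still needs the SetResult handling in the full code.

    ```csharp
    using System;
    using System.Runtime.CompilerServices;
    using System.Threading;
    using System.Threading.Tasks;

    // Hypothetical pass-through builder: identical to the default one,
    // except that Start does not restore the caller's ExecutionContext.
    public struct NoRestoreMethodBuilder<T>
    {
        private AsyncTaskMethodBuilder<T> _builder;

        public static NoRestoreMethodBuilder<T> Create() => default;
        public Task<T> Task => _builder.Task;

        public void Start<TStateMachine>(ref TStateMachine stateMachine)
            where TStateMachine : IAsyncStateMachine
            => stateMachine.MoveNext(); // no ExecutionContext save/restore here

        public void SetStateMachine(IAsyncStateMachine stateMachine)
            => _builder.SetStateMachine(stateMachine);
        public void SetResult(T result) => _builder.SetResult(result);
        public void SetException(Exception exception) => _builder.SetException(exception);

        public void AwaitOnCompleted<TAwaiter, TStateMachine>(
            ref TAwaiter awaiter, ref TStateMachine stateMachine)
            where TAwaiter : INotifyCompletion
            where TStateMachine : IAsyncStateMachine
            => _builder.AwaitOnCompleted(ref awaiter, ref stateMachine);

        public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
            ref TAwaiter awaiter, ref TStateMachine stateMachine)
            where TAwaiter : ICriticalNotifyCompletion
            where TStateMachine : IAsyncStateMachine
            => _builder.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine);
    }

    public static class StartDemo
    {
        public static readonly AsyncLocal<int> Local = new();

        // Default builder: the write is rolled back when Start returns.
        public static async Task<int> WithDefaultBuilder()
        {
            Local.Value = 1;
            await Task.CompletedTask; // already completed, so the method runs synchronously
            return 0;
        }

        // Custom builder: the write survives, because Start does not restore the context.
        [AsyncMethodBuilder(typeof(NoRestoreMethodBuilder<>))]
        public static async Task<int> WithCustomBuilder()
        {
            Local.Value = 2;
            await Task.CompletedTask;
            return 0;
        }

        public static (int AfterDefault, int AfterCustom) Run()
        {
            Local.Value = 0;
            WithDefaultBuilder().GetAwaiter().GetResult();
            int afterDefault = Local.Value; // 0
            WithCustomBuilder().GetAwaiter().GetResult();
            int afterCustom = Local.Value;  // 2
            return (afterDefault, afterCustom);
        }

        public static void Main()
        {
            var (afterDefault, afterCustom) = Run();
            Console.WriteLine($"{afterDefault} {afterCustom}"); // prints "0 2"
        }
    }
    ```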

    The second place, I think, is the SetResult method, where the state machine has finished and we are ready to return. At this point we can capture the current ExecutionContext and use it to replace AsyncStateMachineBox.Context, via a lot of reflection over private fields (obviously not good). This Context is used to run the continuation, i.e. the rest of the method that awaited us.

    This is the code, which I tested on .NET 6, .NET 7 and .NET 8. I am sure there are things I missed, but I am not sure the approach is entirely untenable. One could perhaps use a custom awaitable, a custom state machine, etc. instead of reflection.

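    using System;
    using System.Linq;
    using System.Reflection;
    using System.Runtime.CompilerServices;
    using System.Threading;
    using System.Threading.Tasks;

    class Program {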
        async static Task Main() {
        
            asyncLocal.Value = 55;
            try {
                var result = await SpecialECTask();
    
            } catch (Exception ex) {
    
                Console.WriteLine(ex.Message);
            }
    
            Console.WriteLine("After returning from SpecialTask: " + asyncLocal.Value);
            // 111
        }
    
    
        public static AsyncLocal<int> asyncLocal = new();
    
        public static async Task<int> NormalTask() {
            //await Task.CompletedTask;
            //return 33;
    
            Console.WriteLine("In normal task");
            await Task.Delay(3000);
    
            //throw new Exception("Test Exception");
            Console.WriteLine("Value in normal task " + asyncLocal.Value);
            asyncLocal.Value = 44; // not gonna be flown back
            Console.WriteLine("Changed in normal task " + asyncLocal.Value);
    
            return 111;
        }
    
        [AsyncMethodBuilder(typeof(ECFlowBackMethodBuilder<>))]
        public static async Task<int> SpecialECTask() {
            Console.WriteLine("At beginning of SpecialECTask " + asyncLocal.Value);
    
            var result = await NormalTask(); // needs capture in my implementation
            Console.WriteLine("After awaiting Normaltask " + asyncLocal.Value);
            asyncLocal.Value = result;
            Console.WriteLine("Changed in special ec task " + asyncLocal.Value);
    
            //throw new Exception("Special Exception");
            return result;
        }
    
    }
    
    
    public struct ECFlowBackMethodBuilder<T> {
        // we reuse the AsyncTaskMethodBuilder and forward to it
        private AsyncTaskMethodBuilder<T> _builder;
        public Task<T> Task => _builder.Task;
        public static ECFlowBackMethodBuilder<T> Create() => default;
    
        public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine {
            if (stateMachine == null) {
                throw new ArgumentNullException(nameof(stateMachine));
            }
            SynchronizationContext synchronizationContext =
            SynchronizationContext.Current;
            try {
                stateMachine.MoveNext();
            } finally {
                if (synchronizationContext != SynchronizationContext.Current) {
                    SynchronizationContext.SetSynchronizationContext(synchronizationContext);
                }
            }
        }
    
        public void SetResult(T result) {
    
            var builderTask = _builder.Task;
    
            var contObjectFieldInfo = typeof(Task)
            .GetField("m_continuationObject", BindingFlags.Instance | BindingFlags.NonPublic);
    
            var cont = contObjectFieldInfo.GetValue(builderTask);
            //cont.GetType().Dump();
            if (cont is Task contStateMachineBox) {
                // AsyncStateMachineBox derives from Task; its interface
                // (IAsyncStateMachineBox) is internal, so we test for Task.
                //Console.WriteLine("RELEASE");
                ReplaceECInStateMachine(contStateMachineBox);
    
            } else if (cont is Action contAction) {
                // Debugging encapsulation
                // in ContinuationWrapper
    
                //Console.WriteLine("DEBUG");
                // ContinuationWrapper
                var contWrapperInstance = contAction.Target;
                var contWrapperContinuatinField =
                contWrapperInstance.GetType()
                    .GetFields(BindingFlags.NonPublic | BindingFlags.Instance)
                    .Where(x => x.Name == "_continuation")
                    .FirstOrDefault();
    
                if (contWrapperContinuatinField is null) {
                    throw new NotSupportedException();
                }
    
                var realContinuation =
                contWrapperContinuatinField
                .GetValue(contWrapperInstance) as Action;
                var realContinuationInstance = realContinuation.Target;
    
                ReplaceECInStateMachine(realContinuationInstance);
    
            } else if (cont is null) {
                // synchronous case
            } else if (cont is IThreadPoolWorkItem) {
                // SynchronizationContextAwaitContinuation is internal,
                // so we check for IThreadPoolWorkItem instead
    
                var m_action = cont.GetType()
                .GetField("m_action",
                BindingFlags.NonPublic | BindingFlags.Instance);
    
                var m_actionInstance = m_action.GetValue(cont) as Action;
    
                // Release
                if (m_actionInstance.Target is Task stateMachine) {
                    ReplaceECInStateMachine(stateMachine);
                    _builder.SetResult(result);
                    return;
                }
    
                // is continuation wrapper
                // DEBUG
                var contWrapperInstance = m_actionInstance.Target;
                var contWrapperContinuatinField =
                contWrapperInstance.GetType()
                    .GetFields(BindingFlags.NonPublic | BindingFlags.Instance)
                    .Where(x => x.Name == "_continuation")
                    .FirstOrDefault();
    
                if (contWrapperContinuatinField is null) {
                    throw new NotSupportedException();
                }
    
                var realContinuation =
                contWrapperContinuatinField
                .GetValue(contWrapperInstance) as Action;
                var realContinuationInstance = realContinuation.Target;
                ReplaceECInStateMachine(realContinuationInstance);
            } else {
                throw new NotSupportedException();
    
            }
    
            _builder.SetResult(result);
    
            void ReplaceECInStateMachine(object stateMachineBox) {
                if (stateMachineBox is null) {
                    throw new ArgumentNullException(nameof(stateMachineBox));
                }
    
                var currentContext = ExecutionContext.Capture();
    
                var contextFieldInfo = stateMachineBox
                    .GetType()
                    .GetTypeInfo().DeclaredFields
                    .Where(x => x.Name == "Context")
                    .FirstOrDefault();
    
                if (contextFieldInfo is not null) {
                    contextFieldInfo.SetValue(stateMachineBox, currentContext);
                    return;
                }
    
                // check for property i.e. .NET 8
                // showing how brittle things are really
                var contextPropertyInfo = stateMachineBox
                    .GetType()
                    .GetTypeInfo().DeclaredProperties
                    .Where(x => x.Name == "Context")
                    .FirstOrDefault();
    
                if (contextPropertyInfo is null) {
                    throw new NotSupportedException();
                }
    
                // ref returning Property...(reflection limitations)
                // need to modify the field it references
                // which is the m_stateObject of Task
                var ec = contextPropertyInfo
                    .GetValue(stateMachineBox);
    
                if (ec is null) {
                    // shouldn't really be the case
                    return;
                }
    
                var casted = ec as ExecutionContext;
                if (casted is null) {
                    throw new NotSupportedException();
                }
    
                var m_StateObject = typeof(Task)
                    .GetFields(
                    BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance)
                    .Where(x => x.Name == "m_stateObject")
                    .FirstOrDefault();
    
                var ecInStateObject = m_StateObject.
                     GetValue(stateMachineBox);
    
                if (casted != ecInStateObject) {
                    throw new NotSupportedException();
    
                }
    
                m_StateObject.SetValue(stateMachineBox, currentContext);
            }
    
        }
    
        // simple forwarding nothing special
        public void AwaitOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
        where TAwaiter : INotifyCompletion
        where TStateMachine : IAsyncStateMachine {
            _builder.AwaitOnCompleted(ref awaiter, ref stateMachine);
        }
    
        public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(
        ref TAwaiter awaiter, ref TStateMachine stateMachine)
            where TAwaiter : ICriticalNotifyCompletion
            where TStateMachine : IAsyncStateMachine {
            _builder.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine);
        }
    
        public void SetException(Exception exception) {
            _builder.SetException(exception);
        }
    
        public void SetStateMachine(IAsyncStateMachine stateMachine) {
            _builder.SetStateMachine(stateMachine);
        }
    
    }
    

    The output is:

    At beginning of SpecialECTask 55
    In normal task
    Value in normal task 55
    Changed in normal task 44
    After awaiting Normaltask 55
    Changed in special ec task 111
    After returning from SpecialTask: 111