
F# Async Equivalent of Task.ContinueWith


I have been implementing a [<Trace>] attribute for some of our larger .NET solutions so that configurable analytics can easily be added to any functions/methods considered important. I'm using Fody with the MethodBoundaryAspect.Fody add-in to intercept the entry and exit of each function and record metrics. This works well for synchronous functions, and for methods that return Task there is a workable solution with Task.ContinueWith. For functions that return an F# Async, however, OnExit runs as soon as the Async value is returned, not when the computation is actually executed.
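
The difference comes from evaluation semantics: a Task is normally already running by the time a method returns it, so a continuation observes its completion, whereas an Async is a cold description of work that only runs when something starts it. A small sketch illustrating this (the counters and the makeTask/makeAsync functions are hypothetical):

```fsharp
open System.Threading.Tasks

// Counters recording how many times each body actually executes.
let taskRuns = ref 0
let asyncRuns = ref 0

// A Task is "hot": Task.Run schedules the body as soon as the Task is created.
let makeTask () : Task<int> =
    Task.Run(fun () ->
        taskRuns.Value <- taskRuns.Value + 1
        42)

// An Async is "cold": it only describes the work; nothing runs until it is started.
let makeAsync () : Async<int> =
    async {
        asyncRuns.Value <- asyncRuns.Value + 1
        return 42
    }

let hotTask = makeTask ()
hotTask.Wait()                                       // the body was already scheduled inside makeTask
let coldAsync = makeAsync ()                         // returning the Async does not execute its body
let asyncRunsBeforeStart = asyncRuns.Value           // nothing has executed yet
let asyncResult = Async.RunSynchronously coldAsync   // the body executes here, exactly once
```

This is why a timestamp taken in OnExit is meaningful for a returned Task (via a continuation) but measures none of the work for a returned Async.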

To capture correct metrics for Async-returning functions, I tried to come up with an equivalent of Task.ContinueWith. The closest thing I could think of was to create a new Async that binds the original one, runs the metric-capturing functions, and then returns the original result. This is further complicated by the fact that the intercepted return value (args.ReturnValue) is typed only as obj, so everything after that point has to be done reflectively: unlike Task, which every Task<T> derives from, Async<'T> has no non-generic base type I can use without knowing the exact result type.
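
When the result type is known statically, the bind-then-callback idea needs no reflection and fits in a small generic combinator. This is a sketch under that assumption; continueWith is a hypothetical name, not an FSharp.Core function. It uses try/finally instead of let! so that the callback also runs when the computation raises, which is closer to how Task.ContinueWith also fires for faulted tasks:

```fsharp
/// Hypothetical helper: returns a new Async that runs `computation` once and then
/// calls `onDone`, passing the original result through unchanged.
let continueWith (onDone: unit -> unit) (computation: Async<'T>) : Async<'T> =
    async {
        try
            // Run the original computation exactly once and return its result.
            return! computation
        finally
            // Runs after the computation completes, including when it raises.
            onDone ()
    }

// Usage: nothing runs until the wrapped computation is started.
let log = ResizeArray<string>()
let work = async { log.Add "work"; return 7 }
let traced = work |> continueWith (fun () -> log.Add "done")
let entriesBeforeStart = log.Count
let result = Async.RunSynchronously traced

// The callback also fires when the computation fails.
let failing =
    async { return (failwith "boom" : int) }
    |> continueWith (fun () -> log.Add "done after failure")

let failed =
    try
        Async.RunSynchronously failing |> ignore
        false
    with _ -> true
```

Because the wrapper is itself an Async, the callback fires only when, and each time, the wrapped computation is actually run.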

My best solution so far looks roughly like this:

open System
open System.Diagnostics
open FSharp.Reflection
open MethodBoundaryAspect.Fody.Attributes

[<AllowNullLiteral>]
[<AttributeUsage(AttributeTargets.Method ||| AttributeTargets.Property, AllowMultiple = false)>]
type TraceAttribute () =
    inherit OnMethodBoundaryAspect()

    let traceEvent (args: MethodExecutionArgs) (timestamp: int64) =
        // Capture metrics here
        ()

    override __.OnEntry (args) =
        Stopwatch.GetTimestamp() |> traceEvent args

    override __.OnExit (args) =
        let exit () = Stopwatch.GetTimestamp() |> traceEvent args
        match args.ReturnValue with
        | null -> exit () // void methods (or null results) have nothing to wait for
        | :? System.Threading.Tasks.Task as task ->
            task.ContinueWith(fun _ -> exit()) |> ignore
        | other -> // Here's where I could use some help
            let clrType = other.GetType()
            if clrType.IsGenericType && clrType.GetGenericTypeDefinition() = typedefof<Async<_>> then
                // If the return type is an F# Async, replace it with a new Async that calls exit after the original return value is computed
                let returnType = clrType.GetGenericArguments().[0]
                let functionType = FSharpType.MakeFunctionType(returnType, typedefof<Async<_>>.MakeGenericType([| returnType |]))
                let f = FSharpValue.MakeFunction(functionType, (fun _ -> exit(); other))
                let result = typeof<AsyncBuilder>.GetMethod("Bind").MakeGenericMethod([|returnType; returnType|]).Invoke(async, [|other; f|]) 
                args.ReturnValue <- result
            else
                exit()
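
One detail in the code above is worth isolating: f ignores its argument and returns other, the original Async, so the Async produced by Bind executes the original computation a second time. A minimal, reflection-free reproduction using the real async.Bind and async.Return members (the counter and value names are illustrative):

```fsharp
let runs = ref 0
let original =
    async {
        runs.Value <- runs.Value + 1
        return "result"
    }

// Same shape as the reflective Bind above: the binder ignores the value it
// receives and returns `original` itself, so `original` is executed again.
let rebound = async.Bind(original, fun _ -> original)

// Returning the bound value instead runs `original` only once.
let passedThrough = async.Bind(original, fun value -> async.Return value)

let reboundResult = Async.RunSynchronously rebound
let runsForRebound = runs.Value          // 2
runs.Value <- 0
let passedResult = Async.RunSynchronously passedThrough
let runsForPassedThrough = runs.Value    // 1
```

In the reflective version, the equivalent fix is to have the binder return AsyncBuilder.Return applied to the value it receives rather than other.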

Unfortunately, this solution is not only quite messy, but I believe the reflective construction of an Async computation is adding a non-trivial amount of overhead, especially when I'm trying to trace functions that are called in a loop or have deeply-nested Async calls. Is there a better way to achieve the same result of running a given function immediately after an Async computation is actually evaluated?
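
One common way to reduce per-call reflection cost is to pay for it once per result type: close a strongly typed generic wrapper over 'T the first time an Async<'T> is seen, turn it into a delegate, and cache that delegate. This is a sketch of that general technique, not something from the question; AsyncWrapper, tryWrapAsync and wrapperCache are hypothetical names, and the actual speedup is not measured here:

```fsharp
open System
open System.Collections.Concurrent

// Hypothetical helper type: a strongly typed wrapper that reflection only has to
// close over the result type once per 'T.
type AsyncWrapper =
    static member Wrap<'T> (computation: obj, onDone: Action) : obj =
        let typed = computation :?> Async<'T>
        let wrapped =
            async {
                try
                    return! typed
                finally
                    onDone.Invoke()
            }
        box wrapped

let wrapMethod = typeof<AsyncWrapper>.GetMethod("Wrap")

// One delegate per closed Async<'T> type, created on first use and reused afterwards.
let wrapperCache = ConcurrentDictionary<Type, Func<obj, Action, obj>>()

/// Returns Some wrappedAsync for Async values, None for null or anything else
/// (in which case the caller would run its exit callback immediately).
let tryWrapAsync (value: obj) (onDone: Action) : obj option =
    match value with
    | null -> None
    | _ ->
        let clrType = value.GetType()
        if clrType.IsGenericType && clrType.GetGenericTypeDefinition() = typedefof<Async<_>> then
            let wrap =
                wrapperCache.GetOrAdd(
                    clrType,
                    Func<Type, Func<obj, Action, obj>>(fun t ->
                        let closed = wrapMethod.MakeGenericMethod(t.GetGenericArguments())
                        closed.CreateDelegate(typeof<Func<obj, Action, obj>>) :?> Func<obj, Action, obj>))
            Some (wrap.Invoke(value, onDone))
        else
            None

// Usage: two Async<int> values share one cached delegate.
let events = ResizeArray<string>()
let first = tryWrapAsync (box (async { events.Add "first"; return 5 })) (Action(fun () -> events.Add "exit 1"))
let second = tryWrapAsync (box (async { return 6 })) (Action(fun () -> events.Add "exit 2"))
let notAsync = tryWrapAsync (box 3) (Action(fun () -> ()))

let runWrapped (wrapped: obj option) =
    match wrapped with
    | Some w -> Async.RunSynchronously (w :?> Async<int>)
    | None -> failwith "expected an Async"

let firstResult = runWrapped first
let secondResult = runWrapped second
```

The remaining per-call cost is a dictionary lookup and a delegate invocation rather than GetMethod, MakeGenericMethod and MethodInfo.Invoke on every traced call.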


Solution

  • Following @AMieres's advice, I was able to update my OnExit method to trace the async execution correctly and with much less overhead. Reusing the shared async (AsyncBuilder) instance was not actually the problem: in my first attempt the binder returned other, the original Async, instead of the bound result, so every traced computation ran twice, which accounts for the extra invocations. Here's the new solution:

    open System
    open System.Diagnostics
    open MethodBoundaryAspect.Fody.Attributes
    
    [<AllowNullLiteral>]
    [<AttributeUsage(AttributeTargets.Method ||| AttributeTargets.Property, AllowMultiple = false)>]
    type TraceAttribute () =
        inherit OnMethodBoundaryAspect()
        static let AsyncTypeDef = typedefof<Async<_>>
        static let Tracer = typeof<TraceAttribute>
        static let AsyncTracer = Tracer.GetMethod("TraceAsync")
    
        let traceEvent (args: MethodExecutionArgs) (timestamp: int64) =
            // Capture metrics here
            ()
    
        member __.TraceAsync<'T> (asyncResult: Async<'T>) (trace: unit -> unit) : Async<'T> =
            async {
                let! result = asyncResult
                trace()
                return result
            }
    
        override __.OnEntry (args) =
            Stopwatch.GetTimestamp() |> traceEvent args
    
        override this.OnExit (args) =
            let exit () = Stopwatch.GetTimestamp() |> traceEvent args
            match args.ReturnValue with
            | null -> exit () // void methods (or null results) have nothing to wait for
            | :? System.Threading.Tasks.Task as task ->
                task.ContinueWith(fun _ -> exit()) |> ignore
            | other ->
                let clrType = other.GetType()
                if clrType.IsGenericType && clrType.GetGenericTypeDefinition() = AsyncTypeDef then
                    // Close TraceAsync<'T> over the Async's result type and wrap the original computation
                    let generics = clrType.GetGenericArguments()
                    let result = AsyncTracer.MakeGenericMethod(generics).Invoke(this, [| other; box exit |])
                    args.ReturnValue <- result
                else
                    exit()
    

    This traces the Async functions correctly with significantly less overhead. I wanted to measure the total time from when the function is called, not from when the async actually starts, so I left my OnEntry implementation unchanged.
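
To check the reflective dispatch in isolation, here is a Fody-free sketch that reproduces the shape of the answer's TraceAsync/OnExit pair in a plain class. AsyncTraceDemo and Wrap are hypothetical names standing in for the attribute, with Wrap playing the role of OnExit. It relies on an F# member with curried parameters compiling to a single .NET method that takes both arguments, which is why Invoke receives a two-element array:

```fsharp
// Hypothetical stand-in for TraceAttribute: same reflective dispatch, no Fody involved.
type AsyncTraceDemo () =
    static let asyncTypeDef = typedefof<Async<_>>
    static let traceAsyncMethod = typeof<AsyncTraceDemo>.GetMethod("TraceAsync")

    // Curried member: compiles to one .NET method TraceAsync<'T>(computation, trace).
    member __.TraceAsync<'T> (computation: Async<'T>) (trace: unit -> unit) : Async<'T> =
        async {
            let! result = computation
            trace ()
            return result
        }

    // Plays the role of OnExit: wraps Async return values, otherwise calls exit at once.
    member this.Wrap (returnValue: obj) (exit: unit -> unit) : obj =
        match returnValue with
        | null ->
            exit ()
            null
        | value ->
            let clrType = value.GetType()
            if clrType.IsGenericType && clrType.GetGenericTypeDefinition() = asyncTypeDef then
                let closed = traceAsyncMethod.MakeGenericMethod(clrType.GetGenericArguments())
                closed.Invoke(this, [| value; box exit |])
            else
                exit ()
                value

// Usage: the trace runs only after the wrapped Async has actually executed.
let demo = AsyncTraceDemo()
let log = ResizeArray<string>()
let work : obj = box (async { log.Add "body"; return 21 * 2 })
let wrapped = demo.Wrap work (fun () -> log.Add "exit")
let entriesAfterWrap = log.Count
let answer = Async.RunSynchronously (wrapped :?> Async<int>)
let syncValue = demo.Wrap (box 5) (fun () -> log.Add "sync")
```

Since OnEntry still records its timestamp when the method is called, the measured interval in this design spans from the call until the Async completes, including any time before the Async is started.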