Search code examples
c#asp.net-core

Execute request model binding manually


In ASP.NET Core you can automatically bind request data to an input model / API contract, for example with the attributes FromBody, FromQuery, and FromRoute. I want to invoke this behavior myself. Let me explain.

I want to have a custom policy requirement based on a combination of data passed to the requirement and the target entity, whose id is passed inside the request data. However, this target entity id can be in different locations: usually the body, but, for example, the route or the query string when using HttpGet. So I thought about declaring the location on the controller endpoint with an attribute. The following pseudo-code assumes that I need the BindingSource.

I would create API contracts using an interface defining the location of the target id.

/// <summary>
/// Implemented by request contracts that carry the id of the entity an authorization
/// requirement is evaluated against.
/// </summary>
public interface ITargetEntityContract {
    /// <summary>Id of the targeted entity, as sent in the request.</summary>
    string TargetEntityId { get; set; }
}

/// <summary>
/// Example request contract: carries the target entity id used for authorization plus
/// arbitrary additional payload.
/// </summary>
public class ExampleRequest : ITargetEntityContract {
    /// <inheritdoc />
    public string TargetEntityId { get; set; } = null!;

    /// <summary>Additional request payload, unrelated to the authorization check.</summary>
    public string SomeOtherData { get; set; } = null!;
}

Then I would create an attribute to define the location:

/// <summary>
/// Declares, on a controller or action, which request contract carries the target entity id and
/// from which part of the request (its <c>BindingSource</c>) that contract is bound.
/// </summary>
/// <remarks>
/// <c>BindingSource</c> values such as <c>BindingSource.Body</c> are static readonly fields, not
/// compile-time constants, and <c>BindingSource</c> is not a valid attribute parameter type. Therefore
/// the <c>(Type, BindingSource)</c> constructor can only be used programmatically. In attribute syntax
/// use the string overload, e.g.
/// <c>[TargetEntityLocation(typeof(ExampleRequest), nameof(BindingSource.Body))]</c>.
/// </remarks>
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method)]
public class TargetEntityLocationAttribute : Attribute {
    // Binding sources that can be selected by id through the string constructor.
    private static readonly BindingSource[] SelectableBindingSources = {
        BindingSource.Body,
        BindingSource.Form,
        BindingSource.Header,
        BindingSource.Path,
        BindingSource.Query
    };

    /// <summary>The request contract type; always implements <see cref="ITargetEntityContract"/>.</summary>
    public Type ContractType { get; }

    /// <summary>The part of the request the contract is bound from.</summary>
    public BindingSource BindingSource { get; }

    /// <summary>Creates the attribute from a binding source instance (programmatic construction only).</summary>
    /// <param name="contractType">Contract type implementing <see cref="ITargetEntityContract"/>.</param>
    /// <param name="bindingSource">Where the contract is bound from.</param>
    /// <exception cref="ArgumentNullException"><paramref name="contractType"/> or <paramref name="bindingSource"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="contractType"/> does not implement <see cref="ITargetEntityContract"/>.</exception>
    public TargetEntityLocationAttribute(Type contractType, BindingSource bindingSource) {
        this.ContractType = ValidateContractType(contractType);
        this.BindingSource = bindingSource ?? throw new ArgumentNullException(nameof(bindingSource));
    }

    /// <summary>Creates the attribute from a binding source id; usable in attribute syntax.</summary>
    /// <param name="contractType">Contract type implementing <see cref="ITargetEntityContract"/>.</param>
    /// <param name="bindingSourceId">
    /// Id of the binding source: "Body", "Form", "Header", "Path" (route values) or "Query",
    /// compared case-insensitively. <c>nameof(BindingSource.Body)</c> etc. produce these ids.
    /// </param>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="ArgumentException">Unknown id, or a contract type not implementing <see cref="ITargetEntityContract"/>.</exception>
    public TargetEntityLocationAttribute(Type contractType, string bindingSourceId)
        : this(contractType, ResolveBindingSource(bindingSourceId)) {
    }

    // Ensures the contract type exists and implements ITargetEntityContract.
    private static Type ValidateContractType(Type contractType) {
        if (contractType is null)
            throw new ArgumentNullException(nameof(contractType));

        if (!typeof(ITargetEntityContract).IsAssignableFrom(contractType))
            throw new ArgumentException("Contract has to implement the interface ITargetEntityContract", nameof(contractType));

        return contractType;
    }

    // Maps a binding source id to the framework's shared BindingSource instance.
    private static BindingSource ResolveBindingSource(string bindingSourceId) {
        if (bindingSourceId is null)
            throw new ArgumentNullException(nameof(bindingSourceId));

        foreach (var source in SelectableBindingSources) {
            if (string.Equals(source.Id, bindingSourceId, StringComparison.OrdinalIgnoreCase))
                return source;
        }

        var knownIds = string.Join(", ", Array.ConvertAll(SelectableBindingSources, source => source.Id));
        throw new ArgumentException($"Unknown binding source id '{bindingSourceId}'. Expected one of: {knownIds}.", nameof(bindingSourceId));
    }
}

And you would apply this onto a controller endpoint the following way:

// Example endpoint: the attribute states that ExampleRequest (which carries the target entity id)
// is bound from the request body, matching the [FromBody] parameter below.
// NOTE(review): BindingSource is a class and BindingSource.Body is a static readonly field, so it is
// not a valid attribute argument; this attribute line will not compile as written. The attribute
// needs a constructor taking a string/enum id (e.g. nameof(BindingSource.Body)) instead.
[TargetEntityLocation(typeof(ExampleRequest), BindingSource.Body)]
public async Task<IActionResult> SomeEndpointAsync([FromBody] ExampleRequest requestData) {
    // NOTE(review): body elided in this example; a real implementation must return an IActionResult.
}

Within the IAuthorizationHandler I would use these classes the following way:

/// <summary>
/// Resolves the target entity contract declared for the current endpoint via
/// <see cref="TargetEntityLocationAttribute"/>, binds it from the request, and evaluates
/// <paramref name="requirement"/> against its <see cref="ITargetEntityContract.TargetEntityId"/>.
/// </summary>
/// <exception cref="InvalidOperationException">
/// No current request or endpoint, the endpoint lacks <see cref="TargetEntityLocationAttribute"/>,
/// or the parsed model does not implement <see cref="ITargetEntityContract"/>.
/// </exception>
protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, ExampleAuthorizationRequirement requirement) {
    var httpContext = this._httpContextAccessor.HttpContext
        ?? throw new InvalidOperationException("The authorization requirement was evaluated outside of an HTTP request.");

    var endpoint = httpContext.GetEndpoint();

    if (endpoint is null)
        throw new InvalidOperationException("No endpoint has been selected for the current request.");

    var targetEntityLocation = endpoint.Metadata.GetMetadata<TargetEntityLocationAttribute>();

    if (targetEntityLocation is null)
        throw new InvalidOperationException($"Endpoint '{endpoint.DisplayName}' has no {nameof(TargetEntityLocationAttribute)}.");

    // SomeAlmightyParser stands for the request-binding mechanism this snippet is looking for.
    if (SomeAlmightyParser.ParseFromSource(httpContext, targetEntityLocation.ContractType, targetEntityLocation.BindingSource) is not ITargetEntityContract targetEntityModel)
        throw new InvalidOperationException($"The request could not be bound to {targetEntityLocation.ContractType.Name}.");

    // do something with targetEntityModel.TargetEntityId

    return Task.CompletedTask;
}

Is there a way to parse the request data of the HttpContext into the given model based on the data location?


Solution

  • It's been a few years now, but here is the answer I came up with. NOTE: This code only works when executed within a project using the Microsoft.NET.Sdk.Web SDK, because ModelBindingHelper's namespace is Microsoft.AspNetCore.Mvc.ModelBinding in the Web SDK but Microsoft.AspNetCore.Mvc.ModelBinding.Internal in the plain SDK.

    Interface:

    /// <summary>
    /// Binds a request model from the current HTTP request outside of the regular MVC model-binding
    /// pipeline.
    /// </summary>
    public interface IRequestModelBinder {
        /// <summary>Outcome of a binding attempt.</summary>
        /// <typeparam name="TRequestModel">The model type that was bound.</typeparam>
        public class Result<TRequestModel> {
            /// <summary>Whether the binding attempt succeeded.</summary>
            public bool Succeeded { get; set; }

            /// <summary>The bound model; may be <c>null</c> when binding failed.</summary>
            public TRequestModel? Model { get; set; }

            /// <summary>Errors collected while binding; initialized to an empty dictionary.</summary>
            public ModelStateDictionary ModelState { get; set; } = new ModelStateDictionary();
        }
        
        /// <summary>Binds <typeparamref name="TRequestModel"/> from the current request without a model-name prefix.</summary>
        Task<Result<TRequestModel>> TryBindModelAsync<TRequestModel>() where TRequestModel : class, new();

        /// <summary>Binds <typeparamref name="TRequestModel"/> from the current request.</summary>
        /// <param name="prefix">Model-name prefix used when looking up values and reporting errors.</param>
        Task<Result<TRequestModel>> TryBindModelAsync<TRequestModel>(string prefix) where TRequestModel : class, new();
    }
    

    Implementation:

    /// <summary>
    /// Binds a request model from the current request outside of the regular MVC action pipeline,
    /// e.g. from an authorization handler that needs a request value before the action runs.
    /// </summary>
    /// <remarks>
    /// If the current controller action binds <c>TRequestModel</c> from the body, the body is read with
    /// the first configured input formatter that can handle the request's content type (the same
    /// selection rule MVC uses). Every other case goes through
    /// <c>ModelBindingHelper.TryUpdateModelAsync&lt;TModel&gt;</c>, invoked via reflection.
    /// </remarks>
    public class RequestModelBinder : IRequestModelBinder {
        /// <summary>
        /// Open generic definition of the 7-parameter overload
        /// <c>ModelBindingHelper.TryUpdateModelAsync&lt;TModel&gt;(model, prefix, actionContext,
        /// metadataProvider, modelBinderFactory, valueProvider, objectModelValidator)</c>.
        /// </summary>
        public static readonly MethodInfo TryUpdateModelAsyncGenericMethodDefinition;
        
        protected IHttpContextAccessor HttpContextAccessor { get; }

        /// <summary>
        /// No longer used for binding (input formatters are taken from <see cref="Options"/>);
        /// kept so existing DI registrations and derived classes continue to work.
        /// </summary>
        protected IApiDescriptionGroupCollectionProvider ApiDescriptionProvider { get; }

        protected MvcOptions Options { get; }
    
        static RequestModelBinder() {
            // Looked up by name because ModelBindingHelper's namespace differs between SDKs.
            TryUpdateModelAsyncGenericMethodDefinition =
                Type.GetType("Microsoft.AspNetCore.Mvc.ModelBinding.ModelBindingHelper, Microsoft.AspNetCore.Mvc.Core")!
                    .GetMethods(BindingFlags.Static | BindingFlags.Public)
                    .First(methodInfo =>
                        methodInfo.Name == "TryUpdateModelAsync" &&
                        methodInfo.IsGenericMethod &&
                        methodInfo.GetGenericArguments().Length == 1 &&
                        methodInfo.GetParameters().Length == 7
                    );
        }
        
        public RequestModelBinder(IHttpContextAccessor httpContextAccessor, IApiDescriptionGroupCollectionProvider apiDescriptionProvider, IOptions<MvcOptions> options) {
            this.HttpContextAccessor = httpContextAccessor;
            this.ApiDescriptionProvider = apiDescriptionProvider;
            this.Options = options.Value;
        }
    
        /// <inheritdoc />
        public async Task<IRequestModelBinder.Result<TRequestModel>> TryBindModelAsync<TRequestModel>() where TRequestModel : class, new() {
            return await this.TryBindModelAsync<TRequestModel>(string.Empty);
        }
    
        /// <inheritdoc />
        /// <remarks>
        /// Binding failures, including unexpected exceptions, are reported through an unsuccessful
        /// result whose <c>ModelState</c> carries the cause.
        /// </remarks>
        /// <exception cref="InvalidOperationException">Called while there is no current <c>HttpContext</c>.</exception>
        public async Task<IRequestModelBinder.Result<TRequestModel>> TryBindModelAsync<TRequestModel>(string prefix) where TRequestModel : class, new() {
            var httpContext = this.HttpContextAccessor.HttpContext
                ?? throw new InvalidOperationException("There is no active HttpContext; request model binding must run during a request.");

            // The binders reject a null prefix; treat it like the parameterless overload does.
            prefix ??= string.Empty;

            // Null when routing selected no endpoint or the endpoint is not an MVC controller action.
            var actionDescriptor = httpContext.Features.Get<IEndpointFeature>()?.Endpoint?.Metadata.GetMetadata<ControllerActionDescriptor>();
            if (actionDescriptor is null) {
                var notAnActionState = new ModelStateDictionary();
                notAnActionState.TryAddModelError(prefix, "The current endpoint is not an MVC controller action.");
                return CreateFailedResult<TRequestModel>(notAnActionState);
            }

            var actionContext = new ActionContext(httpContext, httpContext.GetRouteData(), actionDescriptor);

            // ControllerContext shares actionContext's ModelState, so errors from either path land there.
            var controllerContext = new ControllerContext(actionContext) {
                ValueProviderFactories = this.Options.ValueProviderFactories
            };
    
            try {
                var services = httpContext.RequestServices;
                var metaDataProvider = services.GetRequiredService<IModelMetadataProvider>();

                // Body-bound models are not reachable through value providers and need an input formatter.
                // BindingInfo is null for parameters that carry no binding metadata.
                var parameter = actionDescriptor.Parameters.FirstOrDefault(p => p.ParameterType == typeof(TRequestModel));
                if (parameter?.BindingInfo?.BindingSource == BindingSource.Body)
                    return await this.BindFromBodyAsync<TRequestModel>(actionContext, metaDataProvider, prefix);

                var modelBinderFactory = services.GetRequiredService<IModelBinderFactory>();
                var objectValidator = services.GetRequiredService<IObjectModelValidator>();
                var valueProvider = await CompositeValueProvider.CreateAsync(controllerContext, controllerContext.ValueProviderFactories);

                var tryUpdateModelAsync = TryUpdateModelAsyncGenericMethodDefinition.MakeGenericMethod(typeof(TRequestModel));
                var model = new TRequestModel();
                var bindingSucceeded = await (Task<bool>) tryUpdateModelAsync.Invoke(null, new object?[] {
                    model,
                    prefix,
                    controllerContext,
                    metaDataProvider,
                    modelBinderFactory,
                    valueProvider,
                    objectValidator
                })!;
                
                return new IRequestModelBinder.Result<TRequestModel> {
                    Succeeded = bindingSucceeded,
                    Model = model,
                    ModelState = actionContext.ModelState
                };
            }
            catch (Exception exception) {
                // Try* contract: report failure instead of throwing, but keep the cause visible.
                // Reflection wraps exceptions thrown synchronously by the invoked helper.
                var cause = exception is TargetInvocationException { InnerException: { } inner } ? inner : exception;
                var failedState = new ModelStateDictionary();
                failedState.TryAddModelException(prefix, cause);
                return CreateFailedResult<TRequestModel>(failedState);
            }
        }

        /// <summary>
        /// Reads <typeparamref name="TRequestModel"/> from the request body with the first input formatter
        /// whose <c>CanRead</c> accepts the request, leaving the body readable for later consumers.
        /// </summary>
        /// <remarks>Unlike the value-provider path, the model read here is not run through validation.</remarks>
        private async Task<IRequestModelBinder.Result<TRequestModel>> BindFromBodyAsync<TRequestModel>(
            ActionContext actionContext,
            IModelMetadataProvider metadataProvider,
            string prefix) where TRequestModel : class {
            var request = actionContext.HttpContext.Request;

            // The default request body is forward-only: reading it here would leave nothing for the action's
            // own [FromBody] binding, and querying Position on it throws. Buffering makes it seekable.
            request.EnableBuffering();
            var body = request.Body;
            var originalPosition = body.Position;
            body.Position = 0;

            try {
                var formatterContext = new InputFormatterContext(
                    actionContext.HttpContext,
                    prefix,
                    actionContext.ModelState,
                    metadataProvider.GetMetadataForType(typeof(TRequestModel)),
                    // leaveOpen: formatters dispose the reader, which must not dispose the shared request body.
                    (stream, encoding) => new StreamReader(stream, encoding, detectEncodingFromByteOrderMarks: true, bufferSize: 1024, leaveOpen: true));

                var formatter = this.Options.InputFormatters.FirstOrDefault(candidate => candidate.CanRead(formatterContext));
                if (formatter is null) {
                    actionContext.ModelState.TryAddModelError(prefix, $"No input formatter can read content type '{request.ContentType}'.");
                    return CreateFailedResult<TRequestModel>(actionContext.ModelState);
                }

                var result = await formatter.ReadAsync(formatterContext);

                return new IRequestModelBinder.Result<TRequestModel> {
                    Succeeded = result.IsModelSet && !result.HasError,
                    Model = result.Model as TRequestModel,
                    ModelState = actionContext.ModelState
                };
            }
            finally {
                // Rewind even when a formatter throws, so downstream readers start where they expect.
                body.Position = originalPosition;
            }
        }

        /// <summary>Builds an unsuccessful result that carries <paramref name="modelState"/>.</summary>
        private static IRequestModelBinder.Result<TRequestModel> CreateFailedResult<TRequestModel>(ModelStateDictionary modelState) where TRequestModel : class {
            return new IRequestModelBinder.Result<TRequestModel> {
                Succeeded = false,
                Model = null,
                ModelState = modelState
            };
        }
    }