Tags: entity-framework-core, graphql, dataloader, hotchocolate

HotChocolate v13: [UseProjection] attribute does not work with DataLoaders


I have the following GraphQL query:

query {
  listTenants {
    totalCount
    items {
      tenantId
      name
      sites {
        totalCount
        items {
          siteId
          cmxName
          cmxState
          hosts(
            order: { hostId: ASC }
            where: { hostName: { neq: "ans" } }
            skip: 4
            take: 2
          ) {
            totalCount
            items {
              hostId
              hostName
              siteId
            }
          }
        }
      }
    }
  }
}

I want to use projections on the Host object, so that only hostId, hostName and siteId are fetched from the database for each host. I extend the Site object type as follows:

namespace dataGraphAPI.Types.Nodes
{
    [ExtendObjectType<Site>]
    public static class SiteNode
    {
        [GraphQLName(SchemaConstants.Hosts)]
        [ListQueries]
        public static async Task<IEnumerable<Host>> GetHostsAsync(
            [Parent] Site site,
            IHostsBySiteIdDataLoader dataLoader,
            CancellationToken ct)
            =>  await dataLoader.LoadAsync(site.SiteId.ToString(), ct);
    }
}

My custom [ListQueries] attribute applies the following attributes:

using HotChocolate.Types.Descriptors;
using System.Reflection;

namespace dataGraphAPI.Types
{
    public sealed class ListQueriesAttribute : ObjectFieldDescriptorAttribute
    {
        protected override void OnConfigure(IDescriptorContext context, IObjectFieldDescriptor descriptor, MemberInfo member)
        {
            ApplyAttribute(
                context, 
                descriptor, 
                member, 
                new UseOffsetPagingAttribute() 
                { 
                    IncludeTotalCount = true,
                });

            ApplyAttribute(
                context,
                descriptor,
                member,
                new UseProjectionAttribute());

            ApplyAttribute(
                context,
                descriptor,
                member,
                new UseFilteringAttribute());

            ApplyAttribute(
                context,
                descriptor,
                member,
                new UseSortingAttribute());
        }
    }
}

My data loader is as follows:

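[DataLoader]
internal static async Task<ILookup<string, Host>> GetHostsBySiteIdAsync(
    IReadOnlyList<string> siteIds,
    CmxDbContext dbContext,
    CancellationToken ct)
{
    // Load all hosts of the requested sites asynchronously.
    // Note: every Host column is selected here, because no projection is applied.
    var hosts = await dbContext.Hosts
        .Where(x => siteIds.Contains(x.SiteId.ToString()))
        .ToListAsync(ct);

    return hosts.ToLookup(x => x.SiteId.ToString()!);
}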

My Program.cs is as follows:

var builder = WebApplication.CreateBuilder(args);

builder.Services.AddDbContext<CmxDbContext>(options =>
    options.UseNpgsql(builder.Configuration.GetConnectionString("CMXContext"))
           .UseQueryTrackingBehavior(QueryTrackingBehavior.NoTracking)
           .LogTo(Console.WriteLine, LogLevel.Information));

builder.Services
                .AddGraphQLServer()
                .AddTypes()
                .AddType<AggregateResult>()
                .AddType<CountResult>()
                .AddType<DistinctResult>()
                .AddDirectiveType<AggregateDirectiveType>()
                .AddDirectiveType<CountDirectiveType>()
                .AddDirectiveType<DistinctDirectiveType>()
                .AddFiltering()
                .AddSorting()
                .AddProjections()
                .RegisterDbContext<CmxDbContext>();

var app = builder.Build();

app.MapGraphQL();

app.Run();

The problem is that when I fetch the hosts from the database, the UseProjectionAttribute applied by ListQueriesAttribute has no effect: every column of each host is retrieved, not only hostId, hostName and siteId. I am fairly new to data loaders and HotChocolate, so I may be doing something wrong. I understand that [UseProjection] only reaches the database when the resolver returns an IQueryable, but there seems to be no way to return an IQueryable from a data loader, so I cannot apply the attribute there. Any suggestions on how to make this work with projections?


Solution

  • So, I think I now understand what's going on, and I found a solution to the problem:

    1. The HotChocolate middleware for projections, filtering and sorting is only translated into a database query when the resolver returns an IQueryable. On an IEnumerable it runs in memory, after all rows and columns have already been loaded, so it does not help with a layered design or with data loaders that return materialised results.
    2. Data loaders always require you to return a specific type of result: group data loaders must return Task<ILookup<TKey, TValue>> and batch data loaders must return Task<IReadOnlyDictionary<TKey, TValue>>. So if you want projections, filtering and sorting to run inside the database, you must apply them before materialising the IQueryable into an IEnumerable.
    3. The dataLoader.LoadAsync() method does not accept an IResolverContext argument, and there is no way to access the resolver context from inside the data loader to read the selection set and the filtering and sorting arguments. So you need to pass it to the data loader together with the key and enrich the request there.
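    Points 1 and 2 can be illustrated without a database. The sketch below is only an illustration under stated assumptions: an in-memory array wrapped with AsQueryable() stands in for dbContext.Hosts, and the Notes column is hypothetical. Operators chained on the IQueryable are recorded in its expression tree, which a provider such as EF Core can translate into SQL; once the data is materialised with ToLookup (the shape a group data loader must return), anything applied afterwards runs in memory on rows that were already loaded.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// In-memory stand-in for dbContext.Hosts. "Notes" is a hypothetical extra column
// that the GraphQL query does not ask for.
var hostTable = new[]
{
    new { HostId = 1, HostName = "alpha", SiteId = "site-1", Notes = "large text" },
    new { HostId = 2, HostName = "beta",  SiteId = "site-1", Notes = "large text" },
    new { HostId = 3, HostName = "ans",   SiteId = "site-2", Notes = "large text" },
};

var siteIds = new List<string> { "site-1", "site-2" };

// Filtering and projection composed on the IQueryable become part of its expression tree.
// With EF Core, that tree is what gets translated into a SELECT of only these columns.
var query = hostTable
    .AsQueryable()
    .Where(h => siteIds.Contains(h.SiteId) && h.HostName != "ans")
    .Select(h => new { h.HostId, h.HostName, h.SiteId });

// Prints the recorded Where/Select method calls.
Console.WriteLine(query.Expression);

// Materialising into the ILookup<TKey, TValue> shape a group data loader returns.
// From here on, any further filtering or projection happens in memory.
var bySite = query.ToLookup(h => h.SiteId);

foreach (var group in bySite)
{
    Console.WriteLine($"{group.Key}: {string.Join(", ", group.Select(h => h.HostName))}");
}
```

    Projecting on the IQueryable before ToLookup is what keeps the unused column out of the query; the same Select applied after ToLookup would only trim objects that were already fully loaded.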

    So here is how I did it. I know it's not a perfect solution, but it's the only one I've got at the moment. I defined a record as follows:

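    public record Request
    {
        // Id of the parent entity (here the Site) whose children should be loaded.
        public Guid? Id { get; set; }
    
        // Resolver context of the field being resolved; used inside the data loader
        // to read the selection set and the "where" and "order" arguments.
        public IResolverContext ResolverContext { get; set; } = null!;
    
        // Only Id contributes to the hash code. The compiler-generated Equals still
        // compares ResolverContext as well, so keys coming from different fields stay distinct.
        public override int GetHashCode()
        {
            return HashCode.Combine(Id);
        }
    }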
    

    I pass the Request record as a key to the dataLoader.LoadAsync(key, cancellationToken) method:

        [GraphQLName(SchemaConstants.Hosts)]
        [ListQueries]
        public static async Task<IEnumerable<Host>> GetHostsAsync(
            [Parent] Site site,
            IResolverContext resolverContext,
            IHostsBySiteIdDataLoader dataLoader,
            CancellationToken ct)
        {
            var key = new Request { Id = site.SiteId, ResolverContext = resolverContext };
            var result = await dataLoader.LoadAsync(key, ct);
    
            return result;
        }
    

    Inside the data loader I enrich the request with a helper method (ResolverHelpers.EnrichRequest) that extracts the selection set, the filtering clause and the sorting clause from the resolver context. Since the selections, filtering and sorting are the same for every occurrence of a given field in the GraphQL query, it is enough to read them from the first Request key. I then pass them to my GetRequestedEntitiesAsync(...) method, which applies all the necessary projections, filtering and sorting:

        [DataLoader]
        internal static async Task<ILookup<Request, Host>> GetHostsBySiteIdAsync(
            IReadOnlyList<Request> keys,
            CmxDbContext dbContext,
            CancellationToken ct)
        {
            var parentIds = keys.Select(x => x.Id).ToList();
    
            // Assumes the selection set, filter and sort are identical for every key in the
            // batch, so they are read from the first key's resolver context only.
            var (selections, filtration, sorting) = ResolverHelpers.EnrichRequest(keys[0]);
    
            var entities = dbContext.Hosts.Where(x => parentIds.Contains(x.SiteId));
            var list = await ResolverHelpers.GetRequestedEntitiesAsync<Host>(entities, selections, filtration, sorting, ct);
    
            // Map every loaded host back to the key of the site it belongs to.
            var result = list.ToLookup(x => keys.Single(k => k.Id == x.SiteId));
    
            return result;
        }
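    Mapping rows back with keys.Single(k => k.Id == x.SiteId) scans the whole key list for every row, and it could throw if two distinct keys share the same Id (for example, the same site reached through two different paths in the query, each with its own resolver context). A possible alternative, sketched here with hypothetical tuple keys and rows standing in for Request and Host, is to group the loaded rows by parent Id once and then fan them out to every key:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-ins: a key carries the parent Id plus per-field request data
// (here just a path label), and each row carries the foreign key of its parent.
var siteA = Guid.NewGuid();
var siteB = Guid.NewGuid();
var keys = new List<(Guid Id, string Field)>
{
    (siteA, "listTenants/sites/hosts"),
    (siteB, "listTenants/sites/hosts"),
    (siteA, "otherPath/hosts"), // same site requested from a second place
};

var rows = new List<(int HostId, string HostName, Guid SiteId)>
{
    (1, "alpha", siteA),
    (2, "beta", siteA),
    (3, "gamma", siteB),
};

// Group the rows once by parent Id, then hand each key the rows for its Id.
// Every key gets its own entry even when several keys share an Id, and keys
// without rows resolve to an empty sequence in the resulting lookup.
ILookup<(Guid Id, string Field), (int HostId, string HostName, Guid SiteId)> BuildLookup(
    IReadOnlyList<(Guid Id, string Field)> requestKeys,
    IEnumerable<(int HostId, string HostName, Guid SiteId)> loaded)
{
    var rowsByParent = loaded.ToLookup(r => r.SiteId);
    return requestKeys
        .SelectMany(k => rowsByParent[k.Id], (k, r) => (Key: k, Row: r))
        .ToLookup(p => p.Key, p => p.Row);
}

var lookup = BuildLookup(keys, rows);
foreach (var key in keys)
{
    Console.WriteLine($"{key.Field} {key.Id}: {lookup[key].Count()} host(s)");
}
```

    This runs in time proportional to keys plus rows instead of keys times rows, and it never has to pick a single "owner" key for a row.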
    

    To make it more generic I use Dynamic LINQ (System.Linq.Dynamic.Core). Note also that, to extract the selection set the way I do, you need the HotChocolate.PreProcessingExtensions package.

    using dataGraphAPI.Common;
    using HotChocolate.Language;
    using HotChocolate.PreProcessingExtensions.Selections;
    using Microsoft.EntityFrameworkCore;
    using ServiceStack;
    using System.Linq.Dynamic.Core;
    
    namespace dataGraphAPI.Types
    {
        public static class ResolverHelpers
        {
            public static async Task<IEnumerable<T>> GetRequestedEntitiesAsync<T>(
                IQueryable entities, 
                string selections, 
                string? filtration, 
                string? sorting, 
                CancellationToken cancellationToken) where T : class
            {
                if (filtration != null)
                {
                    entities = entities.Where(filtration);
                }
    
                if (sorting != null)
                {
                    entities = entities.OrderBy(sorting);
                }
    
                var result = await entities
                    .Select<T>($"new {{{selections}}}")
                    .ToListAsync(cancellationToken);
    
                return result;
            }
    
            public static (string, string?, string?) EnrichRequest(Request key)
            {
                var selections = GetSelections(key);
                var filtration = GetFilteringClause(key);
                var sorting = GetSortingClause(key);
    
                return (selections, filtration, sorting);
            }
    
            private static string GetSelections(Request key)
            {
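                // Assumes the child's foreign key is named "<ParentType>Id" (e.g. SiteId); it must be
                // selected so that the results can be grouped by parent afterwards.
                var parent = key.ResolverContext.Parent<object>().GetType().Name;
                var parentId = $"{parent}Id";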
    
                var selections = $"{parentId}, {string.Join(", ", key.ResolverContext.GetPreProcessingSelections()!
                    .Select(s => s.SelectionName)
                    .Distinct(StringComparer.OrdinalIgnoreCase))}";
    
                return selections;
            }
    
            private static string? GetSortingClause(Request key)
            {
                var sortings = new List<string>();
                var orderClause = key.ResolverContext.ArgumentLiteral<IValueNode>(SchemaConstants.Order).Value as IEnumerable<ObjectFieldNode>;
    
                if (orderClause != null)
                {
                    foreach (var order in orderClause)
                    {
                        sortings.Add($"{order.Name} {order.Value}");
                    }
    
                    var orderLinq = string.Join(',', sortings);
    
                    return orderLinq;
                }
                else
                {
                    return null;
                }
            }
    
            private static string? GetFilteringClause(Request key)
            {
                var filterings = new List<string>();
                var whereClause = key.ResolverContext.ArgumentLiteral<IValueNode>(SchemaConstants.Where).Value as IEnumerable<ObjectFieldNode>;
    
                if (whereClause != null)
                {
                    foreach (var filter in whereClause)
                    {
                        var filtratingField = filter.Name.ToString();
                        var input = filter.Value as ObjectValueNode;
    
                        foreach (var field in input!.Fields)
                        {
                            var name = field.Name.ToString();
                            var value = field.Value.ToString();
    
                            var fieldName = name switch
                            {
                                "eq" => $"{filtratingField}=={value}",
                                "neq" => $"{filtratingField}!={value}",
                                "gt" => $"{filtratingField}>{value}",
                                "gte" => $"{filtratingField}>={value}",
                                "lt" => $"{filtratingField}<{value}",
                                "lte" => $"{filtratingField}<={value}",
                                // "in"/"nin" test whether the field's value is one of the listed values,
                                // using the Dynamic LINQ "in (...)" operator.
                                "in" => $"{filtratingField} in ({string.Join(", ", ((ListValueNode)field.Value).Items)})",
                                "nin" => $"!({filtratingField} in ({string.Join(", ", ((ListValueNode)field.Value).Items)}))",
                                "startsWith" => $"{filtratingField}.StartsWith({value})",
                                "nstartsWith" => $"!{filtratingField}.StartsWith({value})",
                                _ => throw new NotSupportedException($"Filter operator '{name}' is not supported.")
                            };
    
                            filterings.Add(fieldName);
                        }
                    }
    
                    // Sibling fields in a GraphQL "where" object must all match, so combine them with AND.
                    var filteringsLinq = string.Join(" && ", filterings);
    
                    return filteringsLinq;
                }
                else
                {
                    return null;
                }
            }
        }
    }
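    The core of the filter translation is turning each field: { operator: value } pair into a Dynamic LINQ predicate and AND-ing the results, because sibling fields inside a GraphQL where object must all match. The sketch below isolates that step on plain strings so it can be checked without HotChocolate. The tuple input format and the ToDynamicLinq name are hypothetical, and the in/nin mapping assumes System.Linq.Dynamic.Core's "in (...)" operator.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical, simplified input: one entry per "field: { operator: value }" pair,
// with values already printed as literals (string values keep their quotes).
string ToDynamicLinq(IEnumerable<(string Field, string Op, string[] Values)> filters)
{
    var predicates = filters.Select(f => f.Op switch
    {
        "eq" => $"{f.Field} == {f.Values[0]}",
        "neq" => $"{f.Field} != {f.Values[0]}",
        "gt" => $"{f.Field} > {f.Values[0]}",
        "gte" => $"{f.Field} >= {f.Values[0]}",
        "lt" => $"{f.Field} < {f.Values[0]}",
        "lte" => $"{f.Field} <= {f.Values[0]}",
        // "in"/"nin" test membership of the field's value in the listed values.
        "in" => $"{f.Field} in ({string.Join(", ", f.Values)})",
        "nin" => $"!({f.Field} in ({string.Join(", ", f.Values)}))",
        "startsWith" => $"{f.Field}.StartsWith({f.Values[0]})",
        "nstartsWith" => $"!{f.Field}.StartsWith({f.Values[0]})",
        _ => throw new NotSupportedException($"Operator '{f.Op}' is not supported.")
    });

    // Sibling conditions in a GraphQL "where" object must all hold, so join them with AND.
    return string.Join(" && ", predicates);
}

var clause = ToDynamicLinq(new[]
{
    ("hostName", "neq", new[] { "\"ans\"" }),
    ("hostId", "in", new[] { "1", "2", "3" }),
});

// prints: hostName != "ans" && hostId in (1, 2, 3)
Console.WriteLine(clause);
```

    Keeping the translation as a pure string function like this makes it easy to unit-test separately from the resolver context; nested and/or filters are not handled in this sketch.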
    

    If anyone has better suggestions (there probably are, since I am a junior developer at the moment), I would like to see them, but this is what I have for now.