Search code examples
c#mysqlentity-framework-corepomelo-entityframeworkcore-mysql

How to insert custom code into the generated DbContext OnConfiguring method?


I followed the answer to "Is there a way to scaffold mysql json into custom type?" to convert JSON columns into a custom type, and it works perfectly!

The only thing that bothers me is that I have to modify the Context code manually to insert builder => builder.UseNewtonsoftJson().

I am wondering whether this could be done as part of the generation process; that would be a lifesaver.

Inspired by the answer mentioned above, I tried to make it work.

What I want is

// Desired scaffolder output: OnConfiguring should chain UseNewtonsoftJson (as a provider option),
// EnableSensitiveDataLogging and LogTo onto the generated UseMySql call.
public partial class spckContext : DbContext
{
    ...
    protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
    {
        if (!optionsBuilder.IsConfigured)
        {
#warning To protect potentially sensitive information in your connection string, you should move it out of source code. You can avoid scaffolding the connection string by using the Name= syntax to read it from configuration - see https://go.microsoft.com/fwlink/?linkid=2131148. For more guidance on storing connection strings, see http://go.microsoft.com/fwlink/?LinkId=723263.
            optionsBuilder
                .UseMySql("server=localhost;port=3306;database=spck;user=root;password=;treattinyasboolean=true", Microsoft.EntityFrameworkCore.ServerVersion.Parse("8.0.29-mysql"), builder => builder .UseNewtonsoftJson())
                .EnableSensitiveDataLogging()
                .LogTo(Log, LogFilter, DbContextLoggerOptions.DefaultWithLocalTime); // <= stuck here: how to pass a method as a parameter?
        }
    }
    ...
}

I add these to my project:

using System.Drawing;
using Microsoft.Extensions.Logging;
using Console = Colorful.Console;

/// <summary>
/// Hand-written half of the scaffolded context: supplies the logging callbacks
/// referenced by the LogTo call in OnConfiguring.
/// </summary>
public partial class spckContext
{
    /// <summary>Writes one log message to the console in aqua.</summary>
    /// <param name="content">The message text to print.</param>
    public static void Log(string content)
        => Console.WriteLineFormatted(content, Color.Aqua);

    /// <summary>Decides whether an event of the given level is logged.</summary>
    /// <param name="id">The event id; not inspected.</param>
    /// <param name="level">The severity of the event.</param>
    /// <returns>
    /// <c>true</c> for Information, Error and Critical; <c>false</c> for every other level.
    /// </returns>
    public static bool LogFilter(Microsoft.Extensions.Logging.EventId id, LogLevel level)
        => level switch
        {
            LogLevel.Information => true,
            LogLevel.Error => true,
            LogLevel.Critical => true,
            // Trace, Debug, Warning, None and any unknown value are suppressed.
            _ => false,
        };
}
/// <summary>
/// Design-time service registrations from the question's attempt: swaps in the custom
/// type mapping source and the custom provider configuration code generator.
/// </summary>
public class MyDesignTimeServices : IDesignTimeServices
{
    /// <summary>Adds the custom services to the design-time service collection.</summary>
    /// <param name="services">The collection the registrations are added to.</param>
    public void ConfigureDesignTimeServices(IServiceCollection services)
    {
        ...
        //Type Mapping
        services.AddSingleton<IRelationalTypeMappingSource, CustomTypeMappingSource>();    // <= add this line

        //Option Generator
        services.AddSingleton<IProviderConfigurationCodeGenerator, ProviderConfigurationCodeGenerator>();    // <= and this line
        ...
    }
}
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Design;
using Microsoft.EntityFrameworkCore.Design.Internal;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Scaffolding;
using Microsoft.Extensions.Logging;
using Pomelo.EntityFrameworkCore.MySql.Infrastructure.Internal;
using Pomelo.EntityFrameworkCore.MySql.Scaffolding.Internal;
using Pomelo.EntityFrameworkCore.MySql.Storage.Internal;

/// <summary>
/// The question's attempt at customizing the generated UseMySql call chain by subclassing
/// MySqlCodeGenerator. The UseNewtonsoftJson and EnableSensitiveDataLogging parts work;
/// the LogTo part does not (see the inline markers in GenerateUseProvider).
/// </summary>
public class ProviderConfigurationCodeGenerator : MySqlCodeGenerator 
{
    // DbContextOptionsBuilder.EnableSensitiveDataLogging(bool)
    private static readonly MethodInfo _enableSensitiveDataLoggingMethodInfo = typeof(DbContextOptionsBuilder).GetRequiredRuntimeMethod(
        nameof(DbContextOptionsBuilder.EnableSensitiveDataLogging),
        typeof(bool));
    
    // UseNewtonsoftJson(MySqlDbContextOptionsBuilder, MySqlCommonJsonChangeTrackingOptions)
    private static readonly MethodInfo _useNewtonJsonMethodInfo = typeof(MySqlJsonNewtonsoftDbContextOptionsBuilderExtensions).GetRequiredRuntimeMethod(
        nameof(MySqlJsonNewtonsoftDbContextOptionsBuilderExtensions.UseNewtonsoftJson),
        typeof(MySqlDbContextOptionsBuilder),
        typeof(MySqlCommonJsonChangeTrackingOptions));
    
    // DbContextOptionsBuilder.LogTo(Action<string>, Func<EventId, LogLevel, bool>, DbContextLoggerOptions?)
    private static readonly MethodInfo _logToMethodInfo = typeof(DbContextOptionsBuilder).GetRequiredRuntimeMethod(
        nameof(DbContextOptionsBuilder.LogTo),
        typeof(Action<string>),
        typeof(Func<EventId, LogLevel, bool>),
        typeof(DbContextLoggerOptions?));
    
    // spckContext.Log(string)
    private static readonly MethodInfo _logMethodInfo = typeof(spckContext).GetRequiredRuntimeMethod(
        nameof(spckContext.Log),
        typeof(string));
    
    // spckContext.LogFilter(EventId, LogLevel)
    private static readonly MethodInfo _logFilterMethodInfo = typeof(spckContext).GetRequiredRuntimeMethod(
        nameof(spckContext.LogFilter),
        typeof(EventId),
        typeof(LogLevel));

    // NOTE(review): both fields are assigned in the constructor but not read anywhere in the visible code.
    private readonly ProviderCodeGeneratorDependencies _dependencies;
    private readonly IMySqlOptions _options;
    
    /// <summary>Forwards both arguments to the base generator and keeps local copies.</summary>
    public ProviderConfigurationCodeGenerator(ProviderCodeGeneratorDependencies dependencies, IMySqlOptions options) : base(dependencies, options)
    {
        _dependencies = dependencies;
        _options = options;
    }
    
    /// <summary>
    /// Builds the provider call fragment, adding UseNewtonsoftJson to the provider options and
    /// chaining EnableSensitiveDataLogging and LogTo after the base fragment.
    /// </summary>
    /// <param name="connectionString">Passed through to the base implementation.</param>
    /// <param name="providerOptions">Existing provider options fragment, or null when there is none.</param>
    /// <returns>The fragment with the extra calls chained on.</returns>
    public override MethodCallCodeFragment GenerateUseProvider(string connectionString, MethodCallCodeFragment? providerOptions)
    {
        // Start the provider options with UseNewtonsoftJson, or append it to the existing chain.
        if (providerOptions == null)
        {
            providerOptions = new MethodCallCodeFragment(_useNewtonJsonMethodInfo);
        }
        else
        {
            providerOptions = providerOptions.Chain(new MethodCallCodeFragment(_useNewtonJsonMethodInfo));
        }
        var fragment = base.GenerateUseProvider(connectionString, providerOptions); //works
        fragment = fragment.Chain(_enableSensitiveDataLoggingMethodInfo); //works
        fragment = fragment.Chain(_logToMethodInfo, 
            new NestedClosureCodeFragment("str", new MethodCallCodeFragment(_logMethodInfo)), // <= tried and failed: it is emitted as `str => str.Log()`
            new MethodCall(_logFilterMethodInfo), // <= tried and failed: an error is reported
            DbContextLoggerOptions.DefaultWithLocalTime);

        return fragment;
    }
}

/// <summary>Reflection helpers used when building code-generation fragments.</summary>
public static class TypeExtensions
{
    /// <summary>
    /// Resolves a method on <paramref name="type"/> via GetRuntimeMethod and fails loudly
    /// instead of returning null when no match exists.
    /// </summary>
    /// <param name="type">The type to search.</param>
    /// <param name="name">The method name.</param>
    /// <param name="parameters">The parameter types of the wanted overload.</param>
    /// <returns>The matching method.</returns>
    /// <exception cref="InvalidOperationException">No matching method was found.</exception>
    public static MethodInfo GetRequiredRuntimeMethod(this Type type, string name, params Type[] parameters)
    {
        var resolved = type.GetTypeInfo().GetRuntimeMethod(name, parameters);
        if (resolved == null)
        {
            throw new InvalidOperationException($"Could not find method '{name}' on type '{type}'");
        }

        return resolved;
    }
}
using Microsoft.EntityFrameworkCore.Storage;
using Pomelo.EntityFrameworkCore.MySql.Infrastructure.Internal;
using Pomelo.EntityFrameworkCore.MySql.Storage.Internal;

/// <summary>
/// MySQL type mapping source that additionally recognizes the <see cref="MethodCall"/> CLR type.
/// </summary>
public class CustomTypeMappingSource : MySqlTypeMappingSource
{
    /// <summary>Passes all dependencies straight through to the base mapping source.</summary>
    public CustomTypeMappingSource(
        TypeMappingSourceDependencies dependencies,
        RelationalTypeMappingSourceDependencies relationalDependencies,
        IMySqlOptions options)
        : base(dependencies, relationalDependencies, options)
    {
    }

    /// <summary>
    /// Returns a <see cref="MethodCallTypeMapping"/> when the requested CLR type is
    /// <see cref="MethodCall"/>; every other request is delegated to the base implementation.
    /// </summary>
    protected override RelationalTypeMapping FindMapping(in RelationalTypeMappingInfo mappingInfo)
        => mappingInfo.ClrType == typeof(MethodCall)
            ? new MethodCallTypeMapping()
            : base.FindMapping(mappingInfo);
}
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.EntityFrameworkCore.Storage;

/// <summary>Carrier object that holds the method a code-generation literal should refer to.</summary>
public class MethodCall
{
    /// <summary>The wrapped method.</summary>
    public MethodInfo Method;

    /// <summary>Stores <paramref name="info"/> in <see cref="Method"/>.</summary>
    public MethodCall(MethodInfo info)
        => Method = info;
}

/// <summary>
/// Type mapping for <see cref="MethodCall"/> with a placeholder store type; it can produce a
/// code literal but refuses to produce a SQL literal.
/// </summary>
public class MethodCallTypeMapping : RelationalTypeMapping
{
    // Placeholder store type name handed to the base mapping parameters.
    private const string DummyStoreType = "clrOnly";

    /// <summary>Creates the mapping for the <see cref="MethodCall"/> CLR type.</summary>
    public MethodCallTypeMapping()
        : base(new RelationalTypeMappingParameters(new CoreTypeMappingParameters(typeof(MethodCall)), DummyStoreType))
    {
    }

    /// <summary>Creates the mapping from already-built parameters (used by <see cref="Clone"/>).</summary>
    protected MethodCallTypeMapping(RelationalTypeMappingParameters parameters)
        : base(parameters)
    {
    }

    /// <summary>Returns a new mapping of this type built from <paramref name="parameters"/>.</summary>
    protected override RelationalTypeMapping Clone(RelationalTypeMappingParameters parameters)
    {
        return new MethodCallTypeMapping(parameters);
    }

    /// <summary>Always throws.</summary>
    /// <exception cref="InvalidOperationException">Thrown on every call.</exception>
    public override string GenerateSqlLiteral(object value)
    {
        throw new InvalidOperationException("This type mapping exists for code generation only.");
    }

    /// <summary>
    /// Builds a parameterless call expression for the wrapped method, or returns null when
    /// <paramref name="value"/> is not a <see cref="MethodCall"/>.
    /// </summary>
    public override Expression GenerateCodeLiteral(object value)
    {
        if (value is MethodCall wrapped)
        {
            return Expression.Call(wrapped.Method); // <= not working, how to fix this?
        }

        return null;
    }
}

So my question is: how do I build a MethodCallCodeFragment that takes a method as a parameter? I searched Google but couldn't find anything useful, and MSDN has no sample code for this feature.


Solution

  • Injecting the .UseNewtonsoftJson() and .EnableSensitiveDataLogging() calls can simply be done by providing the design time services with your own IProviderCodeGeneratorPlugin implementation:

    /// <summary>
    /// Design-time services: registers the code generator plugin and the Newtonsoft JSON
    /// support for Pomelo MySQL.
    /// </summary>
    public class MyDesignTimeServices : IDesignTimeServices
    {
        /// <summary>Adds the registrations to <paramref name="services"/>.</summary>
        public void ConfigureDesignTimeServices(IServiceCollection services)
            => services
                .AddSingleton<IProviderCodeGeneratorPlugin, CustomProviderCodeGeneratorPlugin>()
                .AddEntityFrameworkMySqlJsonNewtonsoft();
    }
    
    /// <summary>
    /// Code generator plugin that contributes a UseNewtonsoftJson provider option and an
    /// EnableSensitiveDataLogging context option.
    /// </summary>
    public class CustomProviderCodeGeneratorPlugin : IProviderCodeGeneratorPlugin
    {
        // DbContextOptionsBuilder.EnableSensitiveDataLogging(bool)
        private static readonly MethodInfo s_enableSensitiveDataLogging =
            typeof(DbContextOptionsBuilder).GetRequiredRuntimeMethod(
                nameof(DbContextOptionsBuilder.EnableSensitiveDataLogging),
                typeof(bool));

        // UseNewtonsoftJson(MySqlDbContextOptionsBuilder, MySqlCommonJsonChangeTrackingOptions)
        private static readonly MethodInfo s_useNewtonsoftJson =
            typeof(MySqlJsonNewtonsoftDbContextOptionsBuilderExtensions).GetRequiredRuntimeMethod(
                nameof(MySqlJsonNewtonsoftDbContextOptionsBuilderExtensions.UseNewtonsoftJson),
                typeof(MySqlDbContextOptionsBuilder),
                typeof(MySqlCommonJsonChangeTrackingOptions));

        /// <summary>Returns a fragment for a UseNewtonsoftJson call with no explicit arguments.</summary>
        public MethodCallCodeFragment GenerateProviderOptions()
        {
            return new MethodCallCodeFragment(s_useNewtonsoftJson);
        }

        /// <summary>Returns a fragment for an EnableSensitiveDataLogging call with no explicit arguments.</summary>
        public MethodCallCodeFragment GenerateContextOptions()
        {
            return new MethodCallCodeFragment(s_enableSensitiveDataLogging);
        }
    }
    

    Implementing the complex .LogTo(Log, LogFilter, DbContextLoggerOptions.DefaultWithLocalTime) call is not as straightforward, because EF Core's logic for translating a code generation expression tree to C# code is very basic at best.

    Implementing a dummy type mapping to return a complex expression will not work in the end, because EF Core will not be able to translate the lambda expressions content => Log(content) and (id, level) => LogFilter(id, level). You could try to trick it, but the simplest solution is to just circumvent the whole expression translation mechanism.

    To output any string as C# code, just override ICSharpHelper.UnknownLiteral(object value) in your own implementation.

    Here is a fully working example:

    using System;
    using System.Diagnostics;
    using System.Reflection;
    using Microsoft.EntityFrameworkCore;
    using Microsoft.EntityFrameworkCore.Design;
    using Microsoft.EntityFrameworkCore.Design.Internal;
    using Microsoft.EntityFrameworkCore.Diagnostics;
    using Microsoft.EntityFrameworkCore.Infrastructure;
    using Microsoft.EntityFrameworkCore.Scaffolding;
    using Microsoft.EntityFrameworkCore.Storage;
    using Microsoft.Extensions.DependencyInjection;
    using Microsoft.Extensions.Logging;
    
    namespace IssueConsoleTemplate;
    
    /// <summary>
    /// Design-time services: registers the code generator plugin, the custom C# helper and
    /// the Newtonsoft JSON support for Pomelo MySQL.
    /// </summary>
    public class MyDesignTimeServices : IDesignTimeServices
    {
        /// <summary>Adds the registrations to <paramref name="services"/>, in this order.</summary>
        public void ConfigureDesignTimeServices(IServiceCollection services)
            => services
                .AddSingleton<IProviderCodeGeneratorPlugin, CustomProviderCodeGeneratorPlugin>()
                .AddSingleton<ICSharpHelper, CustomCSharpHelper>()
                .AddEntityFrameworkMySqlJsonNewtonsoft();
    }
    
    /// <summary>Reflection helpers used when building code-generation fragments.</summary>
    public static class TypeExtensions
    {
        /// <summary>
        /// Resolves a method on <paramref name="type"/> via GetRuntimeMethod and fails loudly
        /// instead of returning null when no match exists.
        /// </summary>
        /// <param name="type">The type to search.</param>
        /// <param name="name">The method name.</param>
        /// <param name="parameters">The parameter types of the wanted overload.</param>
        /// <returns>The matching method.</returns>
        /// <exception cref="InvalidOperationException">No matching method was found.</exception>
        public static MethodInfo GetRequiredRuntimeMethod(this Type type, string name, params Type[] parameters)
        {
            var resolved = type.GetTypeInfo().GetRuntimeMethod(name, parameters);
            if (resolved == null)
            {
                throw new InvalidOperationException($"Could not find method '{name}' on type '{type}'");
            }

            return resolved;
        }
    }
    
    /// <summary>
    /// Code generator plugin that contributes a UseNewtonsoftJson provider option and the
    /// EnableSensitiveDataLogging + LogTo context options.
    /// </summary>
    public class CustomProviderCodeGeneratorPlugin : IProviderCodeGeneratorPlugin
    {
        // DbContextOptionsBuilder.EnableSensitiveDataLogging(bool)
        private static readonly MethodInfo s_enableSensitiveDataLogging =
            typeof(DbContextOptionsBuilder).GetRequiredRuntimeMethod(
                nameof(DbContextOptionsBuilder.EnableSensitiveDataLogging),
                typeof(bool));

        // UseNewtonsoftJson(MySqlDbContextOptionsBuilder, MySqlCommonJsonChangeTrackingOptions)
        private static readonly MethodInfo s_useNewtonsoftJson =
            typeof(MySqlJsonNewtonsoftDbContextOptionsBuilderExtensions).GetRequiredRuntimeMethod(
                nameof(MySqlJsonNewtonsoftDbContextOptionsBuilderExtensions.UseNewtonsoftJson),
                typeof(MySqlDbContextOptionsBuilder),
                typeof(MySqlCommonJsonChangeTrackingOptions));

        // DbContextOptionsBuilder.LogTo(Action<string>, Func<EventId, LogLevel, bool>, DbContextLoggerOptions?)
        private static readonly MethodInfo s_logTo =
            typeof(DbContextOptionsBuilder).GetRequiredRuntimeMethod(
                nameof(DbContextOptionsBuilder.LogTo),
                typeof(Action<string>),
                typeof(Func<EventId, LogLevel, bool>),
                typeof(DbContextLoggerOptions?));

        /// <summary>Returns a fragment for a UseNewtonsoftJson call with no explicit arguments.</summary>
        public MethodCallCodeFragment GenerateProviderOptions()
        {
            return new MethodCallCodeFragment(s_useNewtonsoftJson);
        }

        /// <summary>
        /// Returns a fragment for EnableSensitiveDataLogging with the LogTo call chained after it.
        /// </summary>
        public MethodCallCodeFragment GenerateContextOptions()
        {
            var enableSensitiveDataLogging = new MethodCallCodeFragment(s_enableSensitiveDataLogging);
            return enableSensitiveDataLogging.Chain(BuildLogToCall());
        }

        // Builds the LogTo fragment; each argument is a raw C# expression string wrapper
        // rather than a value for EF Core to translate.
        private static MethodCallCodeFragment BuildLogToCall()
        {
            var logAction = new CSharpCodeGenerationExpressionString("Log");
            var logFilter = new CSharpCodeGenerationExpressionString("LogFilter");
            var loggerOptions = new CSharpCodeGenerationExpressionString(
                "Microsoft.EntityFrameworkCore.Diagnostics.DbContextLoggerOptions.DefaultWithLocalTime");

            return new MethodCallCodeFragment(s_logTo, logAction, logFilter, loggerOptions);
        }
    }
    
    /// <summary>
    /// Wraps a piece of C# source text so it can be passed as a code-generation argument and
    /// recognized by <c>CustomCSharpHelper.UnknownLiteral</c>, which returns the text unchanged.
    /// </summary>
    public class CSharpCodeGenerationExpressionString
    {
        /// <summary>The C# source text to output; never null.</summary>
        public string ExpressionString { get; }

        /// <summary>Creates a wrapper around <paramref name="expressionString"/>.</summary>
        /// <param name="expressionString">The C# source text to output as-is.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="expressionString"/> is null. Rejected here so that a null is never
        /// handed back later as the "code" to emit.
        /// </exception>
        public CSharpCodeGenerationExpressionString(string expressionString)
            => ExpressionString = expressionString
                ?? throw new ArgumentNullException(nameof(expressionString));
    }
    
    /// <summary>
    /// C# helper that outputs <see cref="CSharpCodeGenerationExpressionString"/> values as their
    /// raw text and defers to the base helper for everything else.
    /// </summary>
    public class CustomCSharpHelper : CSharpHelper
    {
        /// <summary>Passes the type mapping source through to the base helper.</summary>
        public CustomCSharpHelper(ITypeMappingSource typeMappingSource)
            : base(typeMappingSource)
        {
        }

        /// <summary>
        /// Returns the wrapped expression string unchanged when <paramref name="value"/> is a
        /// <see cref="CSharpCodeGenerationExpressionString"/>; otherwise the base result.
        /// </summary>
        public override string UnknownLiteral(object value)
        {
            if (value is CSharpCodeGenerationExpressionString rawExpression)
            {
                return rawExpression.ExpressionString;
            }

            return base.UnknownLiteral(value);
        }
    }
    
    /// <summary>
    /// Hand-written half of the context: the logging callbacks named in the generated LogTo call.
    /// </summary>
    public partial class Context
    {
        /// <summary>Writes the log message to the console without appending a newline.</summary>
        public static void Log(string content)
        {
            Console.Write(content);
        }

        /// <summary>Accepts events whose level is Information or higher.</summary>
        /// <param name="id">The event id; not inspected.</param>
        /// <param name="level">The severity of the event.</param>
        public static bool LogFilter(EventId id, LogLevel level)
        {
            return !(level < LogLevel.Information);
        }
    }
    
    /// <summary>Application entry point holder for the example project.</summary>
    internal static class Program
    {
        // Main has no body; presumably the project only needs to build so the design-time
        // services can be picked up by the scaffolding tool — TODO confirm.
        private static void Main()
        {
        }
    }
    

    We basically just create our own type called CSharpCodeGenerationExpressionString to hold the C# code string that we want to output and then tell the CustomCSharpHelper.UnknownLiteral() method to return it as is.

    The generated OnConfiguring() method looks like this:

    // Scaffolded output: OnConfiguring as generated with the plugin and helper above in place.
    // UseNewtonsoftJson appears as a provider option; EnableSensitiveDataLogging and LogTo are
    // chained after UseMySql.
    public partial class Context : DbContext
    {
        // ...
    
        protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder)
        {
            if (!optionsBuilder.IsConfigured)
            {
    #warning To protect potentially sensitive information in your connection string, you should move it out of source code. You can avoid scaffolding the connection string by using the Name= syntax to read it from configuration - see https://go.microsoft.com/fwlink/?linkid=2131148. For more guidance on storing connection strings, see http://go.microsoft.com/fwlink/?LinkId=723263.
                optionsBuilder
                    .UseMySql("server=127.0.0.1;port=3306;user=root;database=So73163124_01", Microsoft.EntityFrameworkCore.ServerVersion.Parse("8.0.29-mysql"), x => x.UseNewtonsoftJson())
                    .EnableSensitiveDataLogging()
                    .LogTo(Log, LogFilter, Microsoft.EntityFrameworkCore.Diagnostics.DbContextLoggerOptions.DefaultWithLocalTime);
            }
        }
    
        // ...
    }