
Can you restrict the scope of an Extension Method to classes with a specific attribute?


We have a custom FileExtensionAttribute with which we decorate our model classes that are based on file persistence. It is defined as follows:

[AttributeUsage(AttributeTargets.Class, AllowMultiple=true, Inherited=true)]
public class FileExtensionAttribute : Attribute
{
    public FileExtensionAttribute(string fileExtension)
    {
        FileExtension = fileExtension;
    }

    public readonly string FileExtension;
}

We've also created the following extension methods to make retrieving those extensions more convenient:

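using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

public static class FileExtensionAttributeHelper
{
    public static IEnumerable<string> GetFileExtensions(this Type type)
    {
        // Type.CustomAttributes yields CustomAttributeData objects, not attribute instances,
        // so GetCustomAttributes<T> is needed to obtain FileExtensionAttribute objects.
        return type.GetCustomAttributes<FileExtensionAttribute>(inherit: true)
            .Select(fileExtensionAttribute => fileExtensionAttribute.FileExtension);
    }

    public static string GetPrimaryFileExtension(this Type type)
    {
        // Reflection does not document an ordering for multiple attributes,
        // so "primary" is simply the first one it reports.
        return GetFileExtensions(type).FirstOrDefault();
    }
}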

As written, for types that don't have the attribute, the two methods return an empty enumeration or null, respectively. However, we would like to be more proactive and stop such calls in the first place.
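To make that behaviour concrete, here is a minimal self-contained sketch, assuming the helpers read the attributes through `GetCustomAttributes<FileExtensionAttribute>(inherit: true)`. The model classes `TextDocument`, `LogDocument` and `PlainModel` are hypothetical and exist only for the demonstration.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = true)]
public class FileExtensionAttribute : Attribute
{
    public FileExtensionAttribute(string fileExtension) { FileExtension = fileExtension; }
    public readonly string FileExtension;
}

public static class FileExtensionAttributeHelper
{
    public static IEnumerable<string> GetFileExtensions(this Type type) =>
        type.GetCustomAttributes<FileExtensionAttribute>(inherit: true)
            .Select(a => a.FileExtension);

    public static string GetPrimaryFileExtension(this Type type) =>
        type.GetFileExtensions().FirstOrDefault();
}

// Hypothetical model classes.
[FileExtension(".txt")]
[FileExtension(".text")]
public class TextDocument { }

// Inherited = true, so this class reports both extensions of its base class.
public class LogDocument : TextDocument { }

// Not decorated: the helpers silently return an empty sequence / null.
public class PlainModel { }

public static class Program
{
    public static void Main()
    {
        // Order of the two extensions is whatever reflection reports.
        Console.WriteLine(string.Join(", ", typeof(TextDocument).GetFileExtensions()));
        Console.WriteLine(typeof(LogDocument).GetFileExtensions().Count());     // 2
        Console.WriteLine(typeof(PlainModel).GetFileExtensions().Any());        // False
        Console.WriteLine(typeof(PlainModel).GetPrimaryFileExtension() == null); // True
    }
}
```

Nothing stops the last two calls from compiling, which is exactly the gap the question is about.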

While we could easily throw an exception when no such attribute is found on the given type, I'm wondering whether there's a way to restrict the extension methods to types that have the attribute in the first place, so that misuse is a compile-time error rather than something that has to be handled at run time.
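For comparison, the run-time alternative mentioned above might look like the following minimal sketch. The names `StrictFileExtensionHelper`, `GetRequiredFileExtensions`, `GetRequiredPrimaryFileExtension`, `CsvDocument` and `PlainModel` are hypothetical, and the attribute is repeated so the block compiles on its own. It only moves the failure to the first call at run time, which is the limitation the question wants to avoid.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = true)]
public class FileExtensionAttribute : Attribute
{
    public FileExtensionAttribute(string fileExtension) { FileExtension = fileExtension; }
    public readonly string FileExtension;
}

public static class StrictFileExtensionHelper
{
    // Hypothetical helper: throws instead of returning an empty sequence.
    public static IReadOnlyList<string> GetRequiredFileExtensions(this Type type)
    {
        if (type == null) throw new ArgumentNullException(nameof(type));

        List<string> extensions = type.GetCustomAttributes<FileExtensionAttribute>(inherit: true)
            .Select(a => a.FileExtension)
            .ToList();

        if (extensions.Count == 0)
            throw new InvalidOperationException(
                $"Type {type.FullName} is not decorated with {nameof(FileExtensionAttribute)}.");

        return extensions;
    }

    // Hypothetical helper: never returns null, because the list above is non-empty.
    public static string GetRequiredPrimaryFileExtension(this Type type) =>
        type.GetRequiredFileExtensions()[0];
}

[FileExtension(".csv")]
public class CsvDocument { }

public class PlainModel { }

public static class Program
{
    public static void Main()
    {
        Console.WriteLine(typeof(CsvDocument).GetRequiredPrimaryFileExtension()); // .csv
        try
        {
            typeof(PlainModel).GetRequiredFileExtensions();
        }
        catch (InvalidOperationException ex)
        {
            // Type PlainModel is not decorated with FileExtensionAttribute.
            Console.WriteLine(ex.Message);
        }
    }
}
```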

So is it possible to restrict extension methods to only support types with a given attribute? If so, how?

Note: I'm thinking this may not be possible in pure C#, but perhaps something like PostSharp can be used for this.


Solution

  • PostSharp can indeed help you.

    Outline:

    • Create an AssemblyLevelAspect that uses ReflectionSearch to find all uses of your extension methods in the assembly. This gives the list of methods that call those extension methods.
    • For each of these methods, get the syntax tree using ISyntaxReflectionService. Note that this is a syntax tree reconstructed from the compiled IL, not the source code itself.
    • Search for patterns like typeof(X).GetFileExtensions() and variable.GetType().GetFileExtensions(), and validate that the passed type has the FileExtension attribute.
    • Emit a compile-time error if incorrect usage is found.
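The pattern-matching step of this outline can be illustrated without PostSharp. The sketch below is an analogy, not the PostSharp API: it uses `System.Linq.Expressions.ExpressionVisitor` to walk a lambda expression tree at run time (instead of an IL-derived syntax tree at build time) and applies the same two patterns, `typeof(X).GetFileExtensions()` and `expr.GetType().GetFileExtensions()`. The checker class `FileExtensionUsageChecker` and the model types are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = true)]
public class FileExtensionAttribute : Attribute
{
    public FileExtensionAttribute(string fileExtension) { FileExtension = fileExtension; }
    public readonly string FileExtension;
}

public static class FileExtensionAttributeHelper
{
    public static IEnumerable<string> GetFileExtensions(this Type type) =>
        type.GetCustomAttributes<FileExtensionAttribute>(inherit: true).Select(a => a.FileExtension);
}

// Hypothetical checker: records every GetFileExtensions() call whose receiver is
// typeof(X) or expr.GetType() and whose resolved type lacks FileExtensionAttribute.
public sealed class FileExtensionUsageChecker : ExpressionVisitor
{
    private static readonly MethodInfo ValidatedMethod =
        typeof(FileExtensionAttributeHelper).GetMethod("GetFileExtensions");
    private static readonly MethodInfo ObjectGetType = typeof(object).GetMethod("GetType");

    public List<string> Errors { get; } = new List<string>();

    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method == ValidatedMethod)
        {
            Type target = ResolveReceiverType(node.Arguments[0]);
            if (target != null && !target.IsDefined(typeof(FileExtensionAttribute), inherit: true))
                Errors.Add($"Calling {node.Method.Name} on type {target} is not allowed.");
        }
        return base.VisitMethodCall(node);
    }

    private static Type ResolveReceiverType(Expression argument)
    {
        // In a lambda expression tree, typeof(X) is a constant whose value is a System.Type.
        if (argument is ConstantExpression constant && constant.Value is Type type)
            return type;

        // For expr.GetType(), use the static type of expr, like ValidateGetTypeExpression does.
        if (argument is MethodCallExpression call && call.Method == ObjectGetType)
            return call.Object.Type;

        return null; // Any other pattern is not analysed.
    }
}

[FileExtension(".txt")]
public class TextDocument { }

public class PlainModel { }

public static class Program
{
    public static void Main()
    {
        var model = new PlainModel();
        Expression<Func<IEnumerable<string>>> allowed = () => typeof(TextDocument).GetFileExtensions();
        Expression<Func<IEnumerable<string>>> viaTypeOf = () => typeof(PlainModel).GetFileExtensions();
        Expression<Func<IEnumerable<string>>> viaGetType = () => model.GetType().GetFileExtensions();

        var checker = new FileExtensionUsageChecker();
        checker.Visit(allowed);
        checker.Visit(viaTypeOf);
        checker.Visit(viaGetType);

        // Two messages are recorded, both for PlainModel; the TextDocument call is accepted.
        checker.Errors.ForEach(Console.WriteLine);
    }
}
```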

    Source:

    [MulticastAttributeUsage(PersistMetaData = true)]
    public class FileExtensionValidationPolicy : AssemblyLevelAspect
    {
        public override bool CompileTimeValidate( Assembly assembly )
        {
            ISyntaxReflectionService reflectionService = PostSharpEnvironment.CurrentProject.GetService<ISyntaxReflectionService>();
    
            MethodInfo[] validatedMethods = new[]
            {
                typeof(FileExtensionAttributeHelper).GetMethod( "GetFileExtensions", BindingFlags.Public | BindingFlags.Static ),
                typeof(FileExtensionAttributeHelper).GetMethod( "GetPrimaryFileExtension", BindingFlags.Public | BindingFlags.Static )
            };
    
            MethodBase[] referencingMethods =
                validatedMethods
                    .SelectMany( ReflectionSearch.GetMethodsUsingDeclaration )
                    .Select( r => r.UsingMethod )
                    .Where( m => !validatedMethods.Contains( m ) )
                    .Distinct()
                    .ToArray();
    
            foreach ( MethodBase userMethod in referencingMethods )
            {
                ISyntaxMethodBody body = reflectionService.GetMethodBody( userMethod, SyntaxAbstractionLevel.ExpressionTree );
    
                ValidateMethodBody(body, userMethod, validatedMethods);
            }
    
            // Returning false tells PostSharp not to apply the aspect; it is used here only for validation.
            return false;
        }
    
        private void ValidateMethodBody(ISyntaxMethodBody methodBody, MethodBase userMethod, MethodInfo[] validatedMethods)
        {
            MethodBodyValidator validator = new MethodBodyValidator(userMethod, validatedMethods);
    
            validator.VisitMethodBody(methodBody);
        }
    
        private class MethodBodyValidator : SyntaxTreeVisitor
        {
            private MethodBase userMethod;
            private MethodInfo[] validatedMethods;
    
            public MethodBodyValidator( MethodBase userMethod, MethodInfo[] validatedMethods )
            {
                this.userMethod = userMethod;
                this.validatedMethods = validatedMethods;
            }
    
            public override object VisitMethodCallExpression( IMethodCallExpression expression )
            {
                foreach ( MethodInfo validatedMethod in this.validatedMethods )
                {
                    if ( validatedMethod != expression.Method )
                        continue;
    
                    this.ValidateTypeOfExpression(validatedMethod, expression.Arguments[0]);
                    this.ValidateGetTypeExpression(validatedMethod, expression.Arguments[0]);
                }
    
                return base.VisitMethodCallExpression( expression );
            }
    
            private void ValidateTypeOfExpression(MethodInfo validatedMethod, IExpression expression)
            {
                IMethodCallExpression callExpression = expression as IMethodCallExpression;
    
                if (callExpression == null)
                    return;
    
                if (callExpression.Method != typeof(Type).GetMethod("GetTypeFromHandle"))
                    return;
    
                IMetadataExpression metadataExpression = callExpression.Arguments[0] as IMetadataExpression;
    
                if (metadataExpression == null)
                    return;
    
                Type type = metadataExpression.Declaration as Type;
    
                if (type == null)
                    return;
    
                if (!type.GetCustomAttributes(typeof(FileExtensionAttribute)).Any())
                {
                    MessageSource.MessageSink.Write(
                        new Message(
                            MessageLocation.Of( this.userMethod ),
                            SeverityType.Error, "MYERR1",
                            String.Format( "Calling method {0} on type {1} is not allowed.", validatedMethod, type ),
                            null, null, null
                            )
                        );
                }
            }
    
            private void ValidateGetTypeExpression(MethodInfo validatedMethod, IExpression expression)
            {
                IMethodCallExpression callExpression = expression as IMethodCallExpression;
    
                if (callExpression == null)
                    return;
    
                if (callExpression.Method != typeof(object).GetMethod("GetType"))
                    return;
    
                IExpression instanceExpression = callExpression.Instance;
    
                Type type = instanceExpression.ReturnType;
    
                if (type == null)
                    return;
    
                if (!type.GetCustomAttributes(typeof(FileExtensionAttribute)).Any())
                {
                    MessageSource.MessageSink.Write(
                        new Message(
                            MessageLocation.Of(this.userMethod),
                            SeverityType.Error, "MYERR1",
                            String.Format("Calling method {0} on type {1} is not allowed.", validatedMethod, type),
                            null, null, null
                            )
                        );
                }
            }
        }
    }
    

    Usage:

    [assembly: FileExtensionValidationPolicy(
                   AttributeInheritance = MulticastInheritance.Multicast
                   )]
    

    Notes:

    • [MulticastAttributeUsage(PersistMetaData = true)] and AttributeInheritance = MulticastInheritance.Multicast are both needed to preserve the attribute on the assembly, so that the analysis is also performed on projects that reference the declaring project.
    • Deeper analysis may be needed to handle derived classes and other special cases correctly. For example, the variable.GetType() check uses the variable's declared type, and a Type value passed in through a variable or parameter is not checked at all.
    • A PostSharp Professional license is needed.
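To illustrate why derived classes need extra care, the sketch below (hypothetical `Document` and `TextDocument` types, plain reflection only, no PostSharp) shows the false positive a static-type check produces: the variable's declared type carries no FileExtensionAttribute, while the object's run-time type does.

```csharp
using System;

[AttributeUsage(AttributeTargets.Class, AllowMultiple = true, Inherited = true)]
public class FileExtensionAttribute : Attribute
{
    public FileExtensionAttribute(string fileExtension) { FileExtension = fileExtension; }
    public readonly string FileExtension;
}

// Hypothetical hierarchy: only the derived class is decorated.
public class Document { }

[FileExtension(".txt")]
public class TextDocument : Document { }

public static class Program
{
    public static void Main()
    {
        Document doc = new TextDocument();

        // The declared type of doc: what a check based on the expression's static
        // return type would inspect for doc.GetType().GetFileExtensions().
        Type staticType = typeof(Document);
        // The type GetFileExtensions() actually receives at run time.
        Type runtimeType = doc.GetType();

        Console.WriteLine(staticType.IsDefined(typeof(FileExtensionAttribute), inherit: true));  // False
        Console.WriteLine(runtimeType.IsDefined(typeof(FileExtensionAttribute), inherit: true)); // True
    }
}
```

A static analysis would therefore report this call as an error even though it succeeds at run time.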