Tags: c#, roslyn-code-analysis, sourcegenerators

Source generator - method of getting all implementations of a symbol (including dependent assemblies)?


I have an incremental source generator and need to find all implementations of a given symbol (an interface, for the time being). Through some googling I found this question, which suggested Microsoft.CodeAnalysis.FindSymbols.SymbolFinder.FindImplementedInterfaceMembers. That does exactly what I want, but it is meant for standalone Roslyn rather than for source generators, so I don't have the structures needed to supply its parameters.

Is there a way of getting all implementations of an interface (or arbitrary symbol) using the source generator infrastructure?

The main issue I've found is that I also want to know about implementations in dependencies, which means I need to pull data out of the Compilation class too.


Solution

  • This is a functional solution that could be optimised with a little more work

    public static ImmutableArray<INamedTypeSymbol> FindImplementations<T>(Compilation compilation)
        where T : class
    {
        // This chunk of code is pretty inefficient and could be greatly improved!!
        INamedTypeSymbol[] allTypes = GetAllNamespaces(compilation)
            .SelectMany(t => t.GetTypeMembers())
            .SelectMany(AllNestedTypesAndSelf)
            .ToArray();
    
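        // Namespace is null for types declared in the global namespace
        string[] targetNamespace = typeof(T).Namespace?.Split('.') ?? Array.Empty<string>();
        string targetName = typeof(T).Name;
        // locate the symbol that corresponds to T by namespace and name
        INamedTypeSymbol? desiredType = allTypes
            .FirstOrDefault(p => IsTypeMatch(p, targetNamespace, targetName));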
    
        if (desiredType == null)
        {
            // we wouldn't find the type anyway
            return ImmutableArray<INamedTypeSymbol>.Empty;
        }
    
        // search for the applicable types
        if (typeof(T).IsInterface)
        {
            return allTypes
                .Where(t => t.AllInterfaces.Contains(desiredType, SymbolEqualityComparer.Default))
                .ToImmutableArray();
        }
        else if (typeof(T).IsClass)
        {
            return allTypes
                .Where(t => MatchesBaseType(t, desiredType))
                .ToImmutableArray();
        }
    
        // theoretically impossible because of T restrictions
        throw new InvalidOperationException("Unexpected implementation result!");
    }
    
    // extension method, so it has to be declared inside a static class
    private static IEnumerable<INamedTypeSymbol> AllNestedTypesAndSelf(this INamedTypeSymbol type)
    {
        yield return type;
        foreach (var typeMember in type.GetTypeMembers())
        {
            foreach (var nestedType in typeMember.AllNestedTypesAndSelf())
            {
                yield return nestedType;
            }
        }
    }
    
    // breadth-first walk over every namespace reachable from the global namespace
    private static ImmutableArray<INamespaceSymbol> GetAllNamespaces(Compilation compilation)
    {
        HashSet<INamespaceSymbol> seen = new HashSet<INamespaceSymbol>(SymbolEqualityComparer.Default);
        Queue<INamespaceSymbol> visit = new Queue<INamespaceSymbol>();
        visit.Enqueue(compilation.GlobalNamespace);
    
        do
        {
            INamespaceSymbol search = visit.Dequeue();
            seen.Add(search);
    
            foreach (INamespaceSymbol? space in search.GetNamespaceMembers())
            {
                if (space == null || seen.Contains(space))
                {
                    continue;
                }
    
                visit.Enqueue(space);
            }
        } while (visit.Count > 0);
    
        return seen.ToImmutableArray();
    }
    
    // true if desiredType appears anywhere in the base-type chain of symbol
    private static bool MatchesBaseType(INamedTypeSymbol symbol, INamedTypeSymbol desiredType)
    {
        return symbol.BaseType != null && 
            (SymbolEqualityComparer.Default.Equals(symbol.BaseType, desiredType) 
            || MatchesBaseType(symbol.BaseType, desiredType));
    }
    
    private static bool IsTypeMatch(INamedTypeSymbol symbol, string[] searchNamespace, string searchName)
    {
        if (symbol.Name != searchName)
        {
            return false;
        }
    
        INamespaceSymbol? currentNamespace = symbol.ContainingNamespace;
        for (int i = searchNamespace.Length - 1; i >= 0; i--)
        {
            if (searchNamespace[i] != currentNamespace?.Name)
            {
                return false;
            }
    
            // currentNamespace is non-null here, but the compiler cannot prove it
            currentNamespace = currentNamespace?.ContainingNamespace;
        }
    
        // this should be the global namespace to indicate that we have
        // reached the root of the namespace
        return currentNamespace?.IsGlobalNamespace ?? false;
    }
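
As a rough sketch of how the helper above might be called from an incremental generator, the snippet below feeds `context.CompilationProvider` into `FindImplementations<T>`. The class name `SymbolSearch` (assumed to be the static class holding the methods above), the generator name, the choice of `IDisposable` as `T`, and the hint name `Implementations.g.cs` are all illustrative assumptions, not part of the original answer.

```csharp
using System;
using System.Collections.Immutable;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;

[Generator]
public sealed class ImplementationListGenerator : IIncrementalGenerator
{
    public void Initialize(IncrementalGeneratorInitializationContext context)
    {
        // Re-runs whenever the compilation changes; the whole-compilation scan
        // is expensive, so this is a starting point rather than a tuned pipeline.
        IncrementalValueProvider<ImmutableArray<string>> names = context.CompilationProvider
            .Select((compilation, _) =>
                // SymbolSearch is a hypothetical static class containing FindImplementations<T>
                SymbolSearch.FindImplementations<IDisposable>(compilation)
                    .Select(t => t.ToDisplayString())
                    .ToImmutableArray());

        context.RegisterSourceOutput(names, (spc, found) =>
        {
            // Emit the discovered type names as comments, purely for demonstration
            var sb = new StringBuilder();
            sb.AppendLine("// Implementations found:");
            foreach (string name in found)
            {
                sb.Append("// ").AppendLine(name);
            }

            spc.AddSource("Implementations.g.cs", SourceText.From(sb.ToString(), Encoding.UTF8));
        });
    }
}
```

Only display strings are passed downstream here, on the assumption that holding on to symbols across pipeline steps is best avoided; note that `typeof(T)` requires the generator itself to reference the assembly that declares `T`.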
    

    Optimization that can be done later

    We know which assembly T lives in and which assembly we are compiling, so we should be able to discard every assembly that is referenced by both the compilation and T's assembly. Circular references between assemblies are not allowed, so anything T's assembly depends on cannot know about T. What remains is the set of assemblies that could potentially contain an implementation.

    This would greatly reduce the number of types that have to be checked and make the code much faster, but I haven't implemented this optimization yet.
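
One possible way to sketch that filtering is shown below. Instead of subtracting the assemblies shared with T's assembly, it keeps only the assemblies from which T's assembly is reachable through references, which should amount to the same set under the no-circular-references argument. The names `AssemblyFilter`, `FindCandidateAssemblies` and `CanSee` are hypothetical, and the sketch has not been measured or tested against a real solution.

```csharp
using System.Collections.Generic;
using Microsoft.CodeAnalysis;

public static class AssemblyFilter
{
    // Returns the assemblies (the one being compiled plus its references)
    // that can reach targetType's assembly and so might implement targetType.
    public static List<IAssemblySymbol> FindCandidateAssemblies(
        Compilation compilation, INamedTypeSymbol targetType)
    {
        IAssemblySymbol targetAssembly = targetType.ContainingAssembly;
        var cache = new Dictionary<IAssemblySymbol, bool>(SymbolEqualityComparer.Default);
        var candidates = new List<IAssemblySymbol>();

        if (CanSee(compilation.Assembly, targetAssembly, cache))
        {
            candidates.Add(compilation.Assembly);
        }

        foreach (IAssemblySymbol referenced in compilation.SourceModule.ReferencedAssemblySymbols)
        {
            if (CanSee(referenced, targetAssembly, cache))
            {
                candidates.Add(referenced);
            }
        }

        return candidates;
    }

    // True if 'assembly' is the target or references it, directly or transitively.
    private static bool CanSee(
        IAssemblySymbol assembly, IAssemblySymbol target, Dictionary<IAssemblySymbol, bool> cache)
    {
        if (SymbolEqualityComparer.Default.Equals(assembly, target))
        {
            return true;
        }

        if (cache.TryGetValue(assembly, out bool known))
        {
            return known;
        }

        cache[assembly] = false; // placeholder so an assembly is never re-entered

        bool result = false;
        foreach (IModuleSymbol module in assembly.Modules)
        {
            foreach (IAssemblySymbol next in module.ReferencedAssemblySymbols)
            {
                if (CanSee(next, target, cache))
                {
                    result = true;
                    break;
                }
            }

            if (result)
            {
                break;
            }
        }

        cache[assembly] = result;
        return result;
    }
}
```

With the candidates in hand, the type scan could start from each candidate's `GlobalNamespace` instead of `compilation.GlobalNamespace`, so that types from unrelated assemblies are never enumerated.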