Tags: c#, .net, roslyn, sourcegenerators

Find all derived classes in a source generator


My source generator needs to find other classes that derive from the inspected class to know whether special code needs to be added. So far, I have only inspected the class being augmented and followed references from there, such as the types of its properties, but not any other classes in the project that uses the source generator. I can't seem to find any methods, interfaces or web documentation for this. Is this even possible, and how would it work?

I'm looking for something like this:

public static IEnumerable<string> GetDerivedTypes(this ITypeSymbol typeSymbol)
{
    // TODO: Find all available classes,
    // then I can proceed with inheritance checks and further tests
    // Following is made-up code:
    var derivedTypeNames = typeSymbol.ContainingAssembly.AllTypes
        .Where(t => t.IsDerivedFrom(typeSymbol))
        .Select(t => t.Name);
    return derivedTypeNames;
}

Solution

  • I found out how this works. The answer from rotabor contained a hint, and after more trial and error I was able to build a solution to the described problem around it.

    using System.Collections.Generic;
    using System.Linq;
    using Microsoft.CodeAnalysis;

    public static class TypeSymbolExtensions
    {
        public static string[] GetDerivedTypes(this ITypeSymbol typeSymbol)
        {
            // Scan every type of the assembly that contains typeSymbol
            var derivedTypes = GetAllTypes(typeSymbol.ContainingAssembly.GlobalNamespace)
                .Where(t => IsDerivedFrom(t, typeSymbol));
            return derivedTypes
                .Select(dt => dt.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat))
                .ToArray();

            // Types declared in this namespace, plus those of all child namespaces (recursive)
            static IEnumerable<INamedTypeSymbol> GetAllTypes(INamespaceSymbol nsSymbol) =>
                nsSymbol.GetTypeMembers()
                    .Concat(nsSymbol.GetNamespaceMembers().SelectMany(GetAllTypes));

            // Walks up the BaseType chain until it finds baseTypeSymbol or reaches the end
            static bool IsDerivedFrom(ITypeSymbol typeSymbol, ITypeSymbol baseTypeSymbol) =>
                typeSymbol.BaseType != null &&
                (typeSymbol.BaseType.Equals(baseTypeSymbol, SymbolEqualityComparer.Default) ||
                    IsDerivedFrom(typeSymbol.BaseType, baseTypeSymbol));
        }
    }
    

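    The local helper function GetAllTypes recursively collects the types of every namespace in the assembly. Because GetTypeMembers() on a namespace returns only the types declared directly in that namespace, classes nested inside other classes are not included. It also will not find derived types in other assemblies; I can live with that limitation. Is it efficient? I can't tell.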
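To also pick up nested classes, the namespace walk can descend into each type's own GetTypeMembers(). The following is a minimal sketch under a few assumptions: it runs as a .NET console program with top-level statements, references the Microsoft.CodeAnalysis.CSharp NuGet package, and builds an in-memory compilation instead of running inside a generator. GetAllTypesIncludingNested and WithNested are hypothetical names, not Roslyn APIs.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;

// In-memory source with one nested class (assumption: stands in for the user's project)
var source = @"
namespace Demo
{
    public class Base { }
    public class Outer : Base
    {
        public class Inner : Base { }
    }
}";
var compilation = CSharpCompilation.Create("DemoAssembly",
    new[] { CSharpSyntaxTree.ParseText(source) },
    new[] { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) });

var names = GetAllTypesIncludingNested(compilation.Assembly.GlobalNamespace)
    .Select(t => t.ToDisplayString())
    .ToList();
Console.WriteLine(string.Join(", ", names));

// Hypothetical variant of GetAllTypes: every type in ns and its child namespaces,
// including types nested inside other types.
static IEnumerable<INamedTypeSymbol> GetAllTypesIncludingNested(INamespaceSymbol ns) =>
    ns.GetTypeMembers().SelectMany(WithNested)
        .Concat(ns.GetNamespaceMembers().SelectMany(GetAllTypesIncludingNested));

// A type followed by all of its nested types, recursively.
static IEnumerable<INamedTypeSymbol> WithNested(INamedTypeSymbol type) =>
    new[] { type }.Concat(type.GetTypeMembers().SelectMany(WithNested));
```

Unlike the namespace-only walk, the printed list includes Demo.Outer.Inner; the result can be filtered with the same IsDerivedFrom check.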

    The other local helper function, IsDerivedFrom, determines whether a type derives from a base type, directly or transitively, by recursively walking up the BaseType chain. It is used to filter the set of all types in the assembly so that only types derived from the specified type remain.
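One caveat of comparing with SymbolEqualityComparer.Default: for class A : Base<int>, A.BaseType is the constructed type Base<int>, which is not equal to the generic definition Base<T>. If derivations of a generic base class should count as well, comparing OriginalDefinition is one option. The sketch below assumes a .NET console program with top-level statements and the Microsoft.CodeAnalysis.CSharp NuGet package; IsDerivedFromDefinition is a hypothetical name, and it uses a loop instead of recursion.

```csharp
using System;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;

var source = @"
namespace Demo
{
    public class Base<T> { }
    public class A : Base<int> { }
    public class B : A { }
    public class C { }
}";
var compilation = CSharpCompilation.Create("DemoAssembly",
    new[] { CSharpSyntaxTree.ParseText(source) },
    new[] { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) });

var baseDef = compilation.GetTypeByMetadataName("Demo.Base`1")!; // the definition Base<T>
var a = compilation.GetTypeByMetadataName("Demo.A")!;
var b = compilation.GetTypeByMetadataName("Demo.B")!;
var c = compilation.GetTypeByMetadataName("Demo.C")!;

Console.WriteLine(IsDerivedFrom(a, baseDef));           // False: A's base is Base<int>, not Base<T>
Console.WriteLine(IsDerivedFromDefinition(a, baseDef)); // True
Console.WriteLine(IsDerivedFromDefinition(b, baseDef)); // True (via A)
Console.WriteLine(IsDerivedFromDefinition(c, baseDef)); // False

// Same logic as the answer's recursive helper: exact symbol comparison.
static bool IsDerivedFrom(ITypeSymbol typeSymbol, ITypeSymbol baseTypeSymbol) =>
    typeSymbol.BaseType != null &&
    (typeSymbol.BaseType.Equals(baseTypeSymbol, SymbolEqualityComparer.Default) ||
        IsDerivedFrom(typeSymbol.BaseType, baseTypeSymbol));

// Hypothetical variant: loops up the BaseType chain and compares generic definitions.
static bool IsDerivedFromDefinition(ITypeSymbol type, ITypeSymbol baseType)
{
    for (var current = type.BaseType; current != null; current = current.BaseType)
    {
        if (SymbolEqualityComparer.Default.Equals(current.OriginalDefinition, baseType.OriginalDefinition))
            return true;
    }
    return false;
}
```

Comparing definitions is deliberately looser: passing Base<string> would also match A : Base<int>, so whether that is wanted depends on what the generator needs to emit.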

    I took some inspiration from functional programming here and I think it fits nicely.
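On the efficiency question: each call to GetDerivedTypes walks every namespace and type of the assembly again, plus each type's base chain. If a generator calls it for many classes in one run, one option is to build a base-type-to-derived-types map once per compilation and query it afterwards. This is a sketch under the same kind of assumptions (console program with top-level statements, Microsoft.CodeAnalysis.CSharp NuGet package, in-memory compilation); BuildDerivedTypeLookup is a hypothetical helper, not a Roslyn API.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;

var source = @"
namespace Demo
{
    public class Animal { }
    public class Dog : Animal { }
    public class Puppy : Dog { }
    public class Cat : Animal { }
}";
var compilation = CSharpCompilation.Create("DemoAssembly",
    new[] { CSharpSyntaxTree.ParseText(source) },
    new[] { MetadataReference.CreateFromFile(typeof(object).Assembly.Location) });

// Built once per compilation, then queried for every class the generator inspects.
var lookup = BuildDerivedTypeLookup(compilation.Assembly.GlobalNamespace);

var animal = compilation.GetTypeByMetadataName("Demo.Animal")!;
var derivedNames = lookup.TryGetValue(animal, out var derived)
    ? derived.Select(t => t.ToDisplayString()).OrderBy(n => n).ToList()
    : new List<string>();
Console.WriteLine(string.Join(", ", derivedNames)); // Demo.Cat, Demo.Dog, Demo.Puppy

// Hypothetical helper: maps each base class to all types deriving from it, directly or transitively.
static Dictionary<INamedTypeSymbol, List<INamedTypeSymbol>> BuildDerivedTypeLookup(INamespaceSymbol globalNamespace)
{
    var result = new Dictionary<INamedTypeSymbol, List<INamedTypeSymbol>>(SymbolEqualityComparer.Default);
    foreach (var type in GetAllTypes(globalNamespace))
    {
        // Register the type under every class in its base chain.
        for (var baseType = type.BaseType; baseType != null; baseType = baseType.BaseType)
        {
            if (!result.TryGetValue(baseType, out var list))
                result[baseType] = list = new List<INamedTypeSymbol>();
            list.Add(type);
        }
    }
    return result;
}

// Same traversal as the answer's GetAllTypes (types declared directly in each namespace).
static IEnumerable<INamedTypeSymbol> GetAllTypes(INamespaceSymbol ns) =>
    ns.GetTypeMembers().Concat(ns.GetNamespaceMembers().SelectMany(GetAllTypes));
```

The map is keyed by the exact base symbol, so the generic-definition caveat of IsDerivedFrom applies here too.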