Tags: c++, clang, c++20, c++-concepts

Using the Clang library, how do I check whether a class satisfies a concept?


I have a small program that parses C++ declarations (classes, class templates, functions, and concepts). I would like to give it the name of a class and the name of a concept, and have it report whether the class satisfies the concept's requirements and, if it does not, which requirements were not satisfied.

This is the function I tried to implement, but it does not work: it always returns false, even when the class meets the concept's requirements. I'd like to find an example of how to analyze concepts with Clang's Sema:

void CheckConceptUsage(Sema &SemaRef, const ConceptDecl *Concept, const CXXRecordDecl *Class) {
    // Get the type of the class.
    QualType ClassType = SemaRef.Context.getRecordType(Class);

    // Create a TemplateArgument representing the class type.
    TemplateArgument ClassTemplateArg(ClassType);

    // Prepare the necessary data structures for the constraint check.
    ConstraintSatisfaction Satisfaction;

    // Create a MultiLevelTemplateArgumentList
    MultiLevelTemplateArgumentList TemplateArgs;
    ArrayRef<TemplateArgument> TemplateArgsRef(ClassTemplateArg);
    TemplateArgs.addOuterTemplateArguments(const_cast<CXXRecordDecl *>(Class), TemplateArgsRef, /*Final*/ true);
    // TemplateArgs.addOuterTemplateArguments(ArrayRef<TemplateArgument>(ClassTemplateArg));


    // Retrieve the constraint expression associated with the concept
    const Expr *ConstraintExpr = Concept->getConstraintExpr();
    
    if (!ConstraintExpr) {
        llvm::outs() << "The concept " << Concept->getNameAsString() 
                     << " has no constraints (requires clause) to check.\n";
        return;
    }

    // Cast the constraint expression to RequiresExpr to access its components
    if (const RequiresExpr *ReqExpr = llvm::dyn_cast<RequiresExpr>(ConstraintExpr)) {
        std::cout << "--- CheckConceptUsage if " << std::endl;
        // Get the list of requirements (constraints) in the requires expression
        llvm::SmallVector<const Expr*, 4> ConstraintExprs;

        for (const auto &Requirement : ReqExpr->getRequirements()) {
            std::cout << "--- CheckConceptUsage for " << std::endl;
            if (const auto *ExprReq = llvm::dyn_cast<clang::concepts::ExprRequirement>(Requirement)) {
                // Handle expression requirements
                std::cout << "--- CheckConceptUsage ExprRequirement" << std::endl;
                ConstraintExprs.push_back(ExprReq->getExpr());
            } else if (const auto *TypeReq = llvm::dyn_cast<clang::concepts::TypeRequirement>(Requirement)) {
                // Handle type requirements by evaluating the type's instantiation dependency
                std::cout << "--- CheckConceptUsage TypeRequirement" << std::endl;
                QualType Type = TypeReq->getType()->getType();
                QualType DependentType = TypeReq->getType()->getType();
                if (Type->isDependentType()) {
                    std::cout << "--- CheckConceptUsage isDependentType" << std::endl;
                    // Create a pseudo-expression that checks if this type exists
                    // TypeTraitExpr *TraitExpr = TypeTraitExpr::Create(
                    //     SemaRef.Context, 
                    //     DependentType,
                    //     SourceLocation(), 
                    //     UTT_IsCompleteType,  // Use a type trait like "is complete type"
                    //     ArrayRef<QualType>(DependentType),
                    //     SourceLocation(), 
                    //     SemaRef.Context.BoolTy
                    // );
                    
                    // ConstraintExprs.push_back(TraitExpr);
                }
            }
        }

        
        std::cout << "--- CheckConceptUsage ConstraintExprs size:" << ConstraintExprs.size() << std::endl;

        // Now use the updated list of constraints in the satisfaction check
        bool IsSatisfied = SemaRef.CheckConstraintSatisfaction(
            Concept, 
            ConstraintExprs, 
            TemplateArgs, 
            Class->getSourceRange(), 
            Satisfaction
        );

        if (IsSatisfied) {
            llvm::outs() << "The class " << Class->getName() << " satisfies the concept " << Concept->getName() << ".\n";
        } else {
            llvm::outs() << "The class " << Class->getName() << " does NOT satisfy the concept " << Concept->getName() << ".\n";
        }
    } else {
        llvm::outs() << "The concept " << Concept->getNameAsString() 
                     << " does not have a valid requires expression.\n";
    }
}

And this is my minimal reproducible example.

https://gist.github.com/alexst07/7dadf36ea663171e91778a77d01fbbda


Solution

  • The original code in the gist is very close to working. With a few changes, I was able to make it work on the given example and on another simple one.

    Problem: CheckConstraintSatisfaction return value

    Sema::CheckConstraintSatisfaction returns:

    true if an error occurred and satisfaction could not be checked, false otherwise.

    That is, it indicates whether the satisfaction check could be performed, not whether the constraints were satisfied.

    Consequently, the line:

            bool IsSatisfied = SemaRef.CheckConstraintSatisfaction(
    

    should instead be something like:

            bool HadError = SemaRef.CheckConstraintSatisfaction(
                ...
            );
    
            if (HadError) {
                llvm::outs() << "CheckConstraintSatisfaction reported an error.\n";
                return;
            }
    

    I'm not sure what exactly constitutes an error for this call. In my tests, HadError was always false.

    Then, to check for satisfaction, look at the IsSatisfied flag of the Satisfaction object:

            if (Satisfaction.IsSatisfied) {
                llvm::outs() << "The class " << Class->getName() << " satisfies the concept " << Concept->getName() << ".\n";
            } else {
                llvm::outs() << "The class " << Class->getName() << " does NOT satisfy the concept " << Concept->getName() << ".\n";
            }
    

    You say in a comment you tried that; the other issues probably masked this fix.
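To make the convention concrete without needing a Clang build, here is a small plain-C++20 sketch. `MockSatisfaction` and `mockCheckConstraintSatisfaction` are hypothetical stand-ins that only imitate the calling convention quoted above (return value means "an error occurred", the verdict goes into the out-parameter); they are not the Clang API, and the constraints are modelled as precomputed bools.

```cpp
#include <array>
#include <cstddef>

// Hypothetical stand-in for clang::ConstraintSatisfaction: only the field
// relevant to this illustration is modelled.
struct MockSatisfaction {
  bool IsSatisfied = false;
};

// Hypothetical stand-in for Sema::CheckConstraintSatisfaction. Same
// convention: the return value reports an error, the verdict is written
// to the out-parameter as the conjunction of all constraints.
template <std::size_t N>
constexpr bool mockCheckConstraintSatisfaction(
    const std::array<bool, N> &Constraints, MockSatisfaction &Out) {
  Out.IsSatisfied = true;
  for (bool C : Constraints)
    Out.IsSatisfied = Out.IsSatisfied && C;
  return false; // the check could be performed, so no error
}

// The pattern from the question: the error flag is mistaken for the verdict.
template <std::size_t N>
constexpr bool verdictFromReturnValue(const std::array<bool, N> &Constraints) {
  MockSatisfaction Satisfaction;
  bool IsSatisfied = mockCheckConstraintSatisfaction(Constraints, Satisfaction);
  return IsSatisfied; // false whenever no error occurred
}

// The corrected pattern: check for an error first, then read the out-parameter.
template <std::size_t N>
constexpr bool verdictFromSatisfaction(const std::array<bool, N> &Constraints) {
  MockSatisfaction Satisfaction;
  bool HadError = mockCheckConstraintSatisfaction(Constraints, Satisfaction);
  if (HadError)
    return false; // hypothetical policy: treat an error as "not satisfied"
  return Satisfaction.IsSatisfied;
}
```

With two satisfied constraints, `verdictFromReturnValue` still yields false, which is the symptom described in the question.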

    Problem: Constructing a TypeTraitExpr

    In the case of a TypeRequirement, the original code had a commented-out attempt to create a TypeTraitExpr that could then be passed to CheckConstraintSatisfaction:

                    if (Type->isDependentType()) {
                        std::cout << "--- CheckConceptUsage isDependentType" << std::endl;
                        // Create a pseudo-expression that checks if this type exists
                        // TypeTraitExpr *TraitExpr = TypeTraitExpr::Create(
                        //     SemaRef.Context, 
                        //     DependentType,
                        //     SourceLocation(), 
                        //     UTT_IsCompleteType,  // Use a type trait like "is complete type"
                        //     ArrayRef<QualType>(DependentType),
                        //     SourceLocation(), 
                        //     SemaRef.Context.BoolTy
                        // );
                        
                        // ConstraintExprs.push_back(TraitExpr);
                    }
    

    Since this was commented out, and the example uses two type requirements, the IsSatisfied flag was always true because no constraints were passed to it.

    Uncommented as-is, the code does not compile, but a few small changes are enough to fix it:

                    if (Type->isDependentType()) {
                        std::cout << "--- CheckConceptUsage isDependentType" << std::endl;
                        // Create a pseudo-expression that checks if this type exists
                        TypeTraitExpr *TraitExpr = TypeTraitExpr::Create(
                            SemaRef.Context, 
                            DependentType,
                            SourceLocation(), 
                            UTT_IsCompleteType,  // Use a type trait like "is complete type"
                            ArrayRef<TypeSourceInfo*>(TypeReq->getType()),
                            SourceLocation(), 
                            false /* Value; meaning is unclear */
                        );
                        
                        ConstraintExprs.push_back(TraitExpr);
                    }
    

    Specifically:

    • We need to pass an array of TypeSourceInfo, not QualType. The former have source location information and other details tied to a syntactic occurrence of a type description, whereas the latter is just the abstract semantic type. The TypeSourceInfo is right there in the TypeRequirement; the original code was for some reason bypassing it to get the QualType.

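    • The final argument is an ordinary bool, not a type (BoolTy is a type). What does it mean? Tracing the code, it ends up in TypeTraitExprBitFields.Value, whose comment says "If this expression is not value-dependent, this indicates whether the trait evaluated true or false." I think it is irrelevant here, since the trait expression has not been evaluated yet, and in my experiments it made no difference what value I passed. The QualType parameter, by contrast, appears to be the type of the trait expression itself; as far as I can tell, Sema passes the bool result type there when it builds type traits, so SemaRef.Context.BoolTy is probably the more accurate choice than DependentType, although DependentType worked in the tests shown here.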

    Having said all that, I'm not actually sure that constructing a TypeTraitExpr is really the right approach, as it seems rather convoluted. But it works, at least for these simple cases, so I didn't investigate whether there are better alternatives.

    Issue: Only checking some classes?

    The original code only checks SimpleClass against HasValueType. That makes it hard to tell if it works, since the code is supposed to recognize that SimpleClass does satisfy the constraint but X does not. So I added a simple nested loop to check all pairs:

            std::cout << "Testing all classes against all concepts.\n";
            for (auto const &kv1 : ConceptMap) {
                // Not named 'concept': that is a keyword in C++20.
                ConceptDecl const *theConcept = kv1.second;
                std::cout << "- concept " << theConcept->getNameAsString() << ":\n";
    
                for (auto const &kv2 : ClassMap) {
                    clang::CXXRecordDecl const *clazz = kv2.second;
                    std::cout << "-- class " << clazz->getNameAsString() << ":\n";
    
                    CheckConceptUsage(SemaRef, theConcept, clazz);
                }
            }
    

    Output after fixes

    With the above fixes, the program works on the example.cpp from the linked gist (it should have been directly in the question):

    template<typename T>
    concept HasValueType = requires {
        typename T::value_type;
        typename T::test;
    };
    
    template <class myType>
    myType GetMax (myType a, int b) {
        return (a>b?a:b);
    }
    
    template<class T>
    class A {
        A(){}
    
        using value_type = int;
    
        int func_A1(T a, int b) {
            return 5;
        }
    };
    
    class SimpleClass {
    public:
        using value_type = int;  // This is the nested type alias that satisfies the concept
        using test = float;
    };
    
    template<class T>
    class B {
        B(){}
    
        int func_A1(T a, int b) {
            return 5;
        }
    };
    
    class X{};
    
    float test(A<float> x, float b, A<B<int>> y) {
        return 4.0;
    }
    
    template<class T>
    float test2(A<T> x, float b, A<B<T>> y) {
        return 4.0;
    }
    
    int main() {
        return 0;
    }
    

    Output excerpt on example.cpp:

    ...
    The class SimpleClass satisfies the concept HasValueType.
    ...
    The class X does NOT satisfy the concept HasValueType.
    

    That output is correct, as SimpleClass has value_type and test member types, whereas X does not.

    I also tested this input, example2.cpp:

    // Simple concept that requires an `asString` method.
    template <typename T>
    concept HasAsString = requires(T v)
    {
      v.asString();
    };
    
    
    // Does not have `asString`.
    struct A {};
    
    // Has `asString`.
    struct B {
      void asString();
    };
    
    
    // Requires the concept to be satisfied.
    template <typename T>
      requires HasAsString<T>
    struct S {};
    
    
    // Does not compile due to unsatisfied constraint.
    //S<A> sa;
    
    static_assert(!HasAsString<A>);
    
    
    // Compiles since the constraint is satisfied.
    S<B> sb;
    
    static_assert(HasAsString<B>);
    

    Its output excerpt is:

    ...
    The class A does NOT satisfy the concept HasAsString.
    ...
    The class B satisfies the concept HasAsString.
    

    Again this is correct, as only B has asString().

    Complete program

    Here is the complete program from the linked gist with fixes applied:

    #include <clang/Tooling/CommonOptionsParser.h>
    #include <clang/Tooling/Tooling.h>
    #include <clang/Frontend/CompilerInstance.h>
    #include <clang/Frontend/FrontendActions.h>
    #include <clang/AST/ASTConsumer.h>
    #include <clang/AST/RecursiveASTVisitor.h>
    #include <clang/AST/Decl.h>
    #include <clang/AST/DeclTemplate.h>
    #include <llvm/Support/CommandLine.h>
    #include <clang/Sema/Sema.h>
    #include <clang/Sema/Template.h>
    #include <clang/AST/ASTContext.h>
    
    #include <iostream>
    #include <map>
    #include <string>
    
    
    using namespace clang;
    using namespace clang::tooling;
    using namespace llvm;
    
    // Define the option category for command-line options
    static llvm::cl::OptionCategory MyToolCategory("my-tool options");
    
    std::map<std::string, const CXXRecordDecl*> ClassMap;
    std::map<std::string, const ConceptDecl*> ConceptMap;
    
    void CheckConceptUsage(Sema &SemaRef, const ConceptDecl *Concept, const CXXRecordDecl *Class) {
        // Get the type of the class.
        QualType ClassType = SemaRef.Context.getRecordType(Class);
    
        // Create a TemplateArgument representing the class type.
        TemplateArgument ClassTemplateArg(ClassType);
    
        // Prepare the necessary data structures for the constraint check.
        ConstraintSatisfaction Satisfaction;
    
        // Create a MultiLevelTemplateArgumentList
        MultiLevelTemplateArgumentList TemplateArgs;
        ArrayRef<TemplateArgument> TemplateArgsRef(ClassTemplateArg);
        TemplateArgs.addOuterTemplateArguments(const_cast<CXXRecordDecl *>(Class), TemplateArgsRef, /*Final*/ true);
        // TemplateArgs.addOuterTemplateArguments(ArrayRef<TemplateArgument>(ClassTemplateArg));
    
    
        // Retrieve the constraint expression associated with the concept
        const Expr *ConstraintExpr = Concept->getConstraintExpr();
        
        if (!ConstraintExpr) {
            llvm::outs() << "The concept " << Concept->getNameAsString() 
                         << " has no constraints (requires clause) to check.\n";
            return;
        }
    
        // Cast the constraint expression to RequiresExpr to access its components
        if (const RequiresExpr *ReqExpr = llvm::dyn_cast<RequiresExpr>(ConstraintExpr)) {
            std::cout << "--- CheckConceptUsage if " << std::endl;
            // Get the list of requirements (constraints) in the requires expression
            llvm::SmallVector<const Expr*, 4> ConstraintExprs;
    
            for (const auto &Requirement : ReqExpr->getRequirements()) {
                std::cout << "--- CheckConceptUsage for " << std::endl;
                if (const auto *ExprReq = llvm::dyn_cast<clang::concepts::ExprRequirement>(Requirement)) {
                    // Handle expression requirements
                    std::cout << "--- CheckConceptUsage ExprRequirement" << std::endl;
                    ConstraintExprs.push_back(ExprReq->getExpr());
                } else if (const auto *TypeReq = llvm::dyn_cast<clang::concepts::TypeRequirement>(Requirement)) {
                    // Handle type requirements by evaluating the type's instantiation dependency
                    std::cout << "--- CheckConceptUsage TypeRequirement" << std::endl;
                    QualType Type = TypeReq->getType()->getType();
                    QualType DependentType = TypeReq->getType()->getType();
                    if (Type->isDependentType()) {
                        std::cout << "--- CheckConceptUsage isDependentType" << std::endl;
                        // Create a pseudo-expression that checks if this type exists
                        TypeTraitExpr *TraitExpr = TypeTraitExpr::Create(
                            SemaRef.Context, 
                            DependentType,
                            SourceLocation(), 
                            UTT_IsCompleteType,  // Use a type trait like "is complete type"
                            ArrayRef<TypeSourceInfo*>(TypeReq->getType()),
                            SourceLocation(), 
                            false /* Value; meaning is unclear */
                        );
                        
                        ConstraintExprs.push_back(TraitExpr);
                    }
                }
            }
    
            
            std::cout << "--- CheckConceptUsage ConstraintExprs size:" << ConstraintExprs.size() << std::endl;
    
            // Now use the updated list of constraints in the satisfaction check
            //
            // Quoting the documentation of the return value:
            //
            //   "true if an error occurred and satisfaction could not be
            //   checked, false otherwise."
            //
            bool HadError = SemaRef.CheckConstraintSatisfaction(
                Concept, 
                ConstraintExprs, 
                TemplateArgs, 
                Class->getSourceRange(), 
                Satisfaction
            );
    
            if (HadError) {
                llvm::outs() << "CheckConstraintSatisfaction reported an error.\n";
                return;
            }
    
            llvm::outs() << "ContainsErrors: " << Satisfaction.ContainsErrors << "\n";
    
            if (Satisfaction.IsSatisfied) {
                llvm::outs() << "The class " << Class->getName() << " satisfies the concept " << Concept->getName() << ".\n";
            } else {
                llvm::outs() << "The class " << Class->getName() << " does NOT satisfy the concept " << Concept->getName() << ".\n";
            }
        } else {
            llvm::outs() << "The concept " << Concept->getNameAsString() 
                         << " does not have a valid requires expression.\n";
        }
    }
    
    class FunctionConsumer : public ASTConsumer {
    public:
        explicit FunctionConsumer(CompilerInstance &CI): CI(CI) {}
    
        virtual void HandleTranslationUnit(ASTContext &Context) override {
            TranslationUnitDecl *TU = Context.getTranslationUnitDecl();
            this->Context = &Context;
            Sema &SemaRef = CI.getSema();
    
            for (Decl *D : TU->decls()) {
                printDeclKind(D);
                if (CXXRecordDecl *CRD = llvm::dyn_cast<CXXRecordDecl>(D)) {
                    // Handle classes/structs and their methods
                    handleClass(CRD);
                } else if (ConceptDecl *CD = llvm::dyn_cast<ConceptDecl>(D)) {
                    // Handle concepts
                    std::cout << "--- CONCEPT: " << std::endl;
                    printConceptDetails(CD, &Context);
                    std::cout << "--- END CONCEPT: " << std::endl;
                } 
            }
    
            // After collecting all declarations, test a class against a concept
            auto ConceptIter = ConceptMap.find("HasValueType"); // Replace with actual concept name
            auto ClassIter = ClassMap.find("SimpleClass"); // Replace with actual class name
            if (ConceptIter != ConceptMap.end() && ClassIter != ClassMap.end()) {
                CheckConceptUsage(SemaRef, ConceptIter->second, ClassIter->second);
            }
    
            std::cout << "Testing all classes against all concepts.\n";
            for (auto const &kv1 : ConceptMap) {
                // Not named 'concept': that is a keyword in C++20.
                ConceptDecl const *theConcept = kv1.second;
                std::cout << "- concept " << theConcept->getNameAsString() << ":\n";
    
                for (auto const &kv2 : ClassMap) {
                    clang::CXXRecordDecl const *clazz = kv2.second;
                    std::cout << "-- class " << clazz->getNameAsString() << ":\n";
    
                    CheckConceptUsage(SemaRef, theConcept, clazz);
                }
            }
        }
    
    private:
        ASTContext *Context;
        CompilerInstance &CI;
    
        void printDeclKind(Decl *D) {
            std::cout << "Declaration Kind: " << D->getDeclKindName() << std::endl;
        }
    
        void printConceptDetails(const clang::ConceptDecl *CD, clang::ASTContext *Context) {
            std::cout << "Concept: " << CD->getNameAsString() << std::endl;
            ConceptMap[CD->getNameAsString()] = CD;
            std::cout << std::endl;
        }
    
    
        void handleClass(CXXRecordDecl *CRD) {
            if (!CRD->isThisDeclarationADefinition()) {
                return; // Skip forward declarations
            }
    
            std::cout << "Class: " << CRD->getNameAsString() << std::endl;
    
            // Store the class in the map
            ClassMap[CRD->getNameAsString()] = CRD;
        }
    };
    
    class FindFunctionsAction : public ASTFrontendAction {
    public:
        std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, StringRef file) override {
            return std::make_unique<FunctionConsumer>(CI);
        }
    };
    
    int main(int argc, const char **argv) {
        auto ExpectedParser = CommonOptionsParser::create(argc, argv, MyToolCategory);
        if (!ExpectedParser) {
            llvm::errs() << ExpectedParser.takeError();
            return 1;
        }
        CommonOptionsParser& OptionsParser = ExpectedParser.get();
        ClangTool Tool(OptionsParser.getCompilations(), OptionsParser.getSourcePathList());
    
        // Manually add the required C++ standard flag
        std::vector<std::string> CompileFlags = {"-std=c++20"};
        Tool.appendArgumentsAdjuster(getInsertArgumentAdjuster(CompileFlags, ArgumentInsertPosition::BEGIN));
    
        return Tool.run(newFrontendActionFactory<FindFunctionsAction>().get());
    }
    

    For completeness, all of my tests were performed using Clang 16.0.0.