
Clang Tool that extracts the lambda body given the lambda type


I'm trying to implement a Clang tool built on RecursiveASTVisitor (following this tutorial) that applies code transformations driven by lambdas passed to a function.

For example, I want to generate something based on the lambda passed as an argument to foo:

foo([](){});

This case is easy: find all CallExprs whose callee is named foo, then find all LambdaExprs that are descendants of that CallExpr:

struct LambdaExprVisitor : public RecursiveASTVisitor<LambdaExprVisitor> {
  bool VisitLambdaExpr(LambdaExpr * lambdaExpr) {
    //Do stuff
    lambdaExpr->getCallOperator()->getBody()->dump();
    return true;
  }
};

struct CallVisitor : public RecursiveASTVisitor<CallVisitor> {
  //Find call expressions based on the given name
  bool VisitCallExpr(CallExpr * expr) {
    auto * callee = expr->getDirectCallee(); 
    if(callee != nullptr && callee->getName() == "foo") {
      visitor.TraverseCallExpr(expr);
    }
    return true;
  }
  LambdaExprVisitor visitor;
};

The problem I now have is that there are multiple ways to pass a lambda to the original function foo, e.g.:

auto bar() { return [](){}; }

int main() {
  foo(bar());
}

The earlier approach does not find the body here, because the LambdaExpr is no longer a descendant of the call to foo.

My thinking was that a lambda's body is known at compile time, so it should somehow be possible to infer the body from the Expr of the given parameter:

struct CallVisitor : public RecursiveASTVisitor<CallVisitor> {
  bool VisitCallExpr(CallExpr * expr) {
    auto * callee = expr->getDirectCallee(); 
    if(callee != nullptr && callee->getName() == "foo") {
      //Get the first argument which must be the lambda
      auto * arg = expr->getArg(0);
      //do something with the first argument
      //?
    }
    return true;
  }
};

Is there a way to get the lambda body at this point? If not, is there a way to infer the lambda body differently without having to resort to implementing all possible ways to pass a lambda body to foo?

Note: A matcher-based solution would also work for me.


Solution

  • One solution is to traverse the translation unit twice. The first traversal gathers all LambdaExprs in a map keyed by their type; the second looks up the type of foo's argument in that map to recover the lambda's body. Here are the modified visitors, which now hold a reference to this map (the type is encoded as its string representation):

    struct LambdaExprVisitor : public RecursiveASTVisitor<LambdaExprVisitor> {
      LambdaExprVisitor(std::map<std::string, LambdaExpr *>& id2Expr) : RecursiveASTVisitor<LambdaExprVisitor> {}, id2Expr { id2Expr } {} 
    
      std::map<std::string, LambdaExpr *>& id2Expr; 
    
      bool VisitLambdaExpr(LambdaExpr * lambdaExpr) {
        id2Expr.emplace(
          lambdaExpr->getType().getAsString(),
          lambdaExpr
        );
        return true;
      }
    };
    
    struct CallVisitor : public RecursiveASTVisitor<CallVisitor> {
      CallVisitor(std::map<std::string, LambdaExpr *>& id2Expr) : RecursiveASTVisitor<CallVisitor> {}, id2Expr { id2Expr } {} 
    
      std::map<std::string, LambdaExpr *>& id2Expr;
    
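      bool VisitCallExpr(CallExpr * expr) {
        auto * callee = expr->getDirectCallee();
        //getName() requires a plain identifier, so check getIdentifier() first
        //(operator calls have none); also make sure the call has an argument
        if(callee != nullptr && callee->getIdentifier() != nullptr &&
           callee->getName() == "foo" && expr->getNumArgs() > 0) {
          //Look up the lambda by the string form of the argument's type
          auto arg = expr->getArg(0)->getType().getAsString();
          if(auto iter = id2Expr.find(arg); iter != id2Expr.end()) {
            //Do stuff with the lambdaExpr
            auto * lambdaExpr = iter->second;
            lambdaExpr->dump();
          }
        }
        return true;
      }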
    };
    

    The ASTConsumer to handle this just stores this map and executes the two visitors:

    struct Consumer : public ASTConsumer {
      void HandleTranslationUnit(clang::ASTContext &Context) override {
        //First pass: collect every LambdaExpr, keyed by its type
        visitor.TraverseDecl(Context.getTranslationUnitDecl());
        //Second pass: resolve the argument types of calls to foo
        visitor2.TraverseDecl(Context.getTranslationUnitDecl());
      }

      //Declared before the visitors so it is initialized before they bind to it
      std::map<std::string, LambdaExpr *> id2Expr;

      LambdaExprVisitor visitor { id2Expr };
      CallVisitor visitor2 { id2Expr };
    };
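
The type-keyed lookup relies on a language rule: every lambda expression has its own unique closure type, and a function such as bar() returns that same type on every call. The following is a minimal standalone sketch of that property in plain C++17, with no Clang involved. The helper baz and the two constant names are hypothetical and only serve the illustration.

```cpp
#include <type_traits>

// bar mirrors the question; baz is a hypothetical second helper whose
// lambda is textually identical to the one in bar.
inline auto bar() { return [] {}; }
inline auto baz() { return [] {}; }

// Every call to bar() yields the same closure type, so the type of the
// argument in foo(bar()) identifies the lambda expression inside bar.
constexpr bool kSameFunctionSameType =
    std::is_same_v<decltype(bar()), decltype(bar())>;

// Two lambda expressions never share a closure type, even when their
// source text is identical, so a type-based key does not mix them up.
constexpr bool kDifferentLambdasDifferentTypes =
    !std::is_same_v<decltype(bar()), decltype(baz())>;

static_assert(kSameFunctionSameType);
static_assert(kDifferentLambdasDifferentTypes);
```

Both static_asserts hold at compile time, which is why looking up the argument's type in the map leads back to exactly one LambdaExpr. Note that the visitors above key on the printed form of the type rather than on the type itself, so this sketch only covers the underlying language guarantee.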