// ASTMatchFinder.cpp revision 239462
1//===--- ASTMatchFinder.cpp - Structural query framework ------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// Implements an algorithm to efficiently search for matches on AST nodes. 11// Uses memoization to support recursive matches like HasDescendant. 12// 13// The general idea is to visit all AST nodes with a RecursiveASTVisitor, 14// calling the Matches(...) method of each matcher we are running on each 15// AST node. The matcher can recurse via the ASTMatchFinder interface. 16// 17//===----------------------------------------------------------------------===// 18 19#include "clang/ASTMatchers/ASTMatchFinder.h" 20#include "clang/AST/ASTConsumer.h" 21#include "clang/AST/ASTContext.h" 22#include "clang/AST/RecursiveASTVisitor.h" 23#include <set> 24 25namespace clang { 26namespace ast_matchers { 27namespace internal { 28namespace { 29 30// We use memoization to avoid running the same matcher on the same 31// AST node twice. This pair is the key for looking up match 32// result. It consists of an ID of the MatcherInterface (for 33// identifying the matcher) and a pointer to the AST node. 34typedef std::pair<uint64_t, const void*> UntypedMatchInput; 35 36// Used to store the result of a match and possibly bound nodes. 37struct MemoizedMatchResult { 38 bool ResultOfMatch; 39 BoundNodesTree Nodes; 40}; 41 42// A RecursiveASTVisitor that traverses all children or all descendants of 43// a node. 44class MatchChildASTVisitor 45 : public RecursiveASTVisitor<MatchChildASTVisitor> { 46public: 47 typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase; 48 49 // Creates an AST visitor that matches 'matcher' on all children or 50 // descendants of a traversed node. 
max_depth is the maximum depth 51 // to traverse: use 1 for matching the children and INT_MAX for 52 // matching the descendants. 53 MatchChildASTVisitor(const UntypedBaseMatcher *BaseMatcher, 54 ASTMatchFinder *Finder, 55 BoundNodesTreeBuilder *Builder, 56 int MaxDepth, 57 ASTMatchFinder::TraversalKind Traversal, 58 ASTMatchFinder::BindKind Bind) 59 : BaseMatcher(BaseMatcher), 60 Finder(Finder), 61 Builder(Builder), 62 CurrentDepth(-1), 63 MaxDepth(MaxDepth), 64 Traversal(Traversal), 65 Bind(Bind), 66 Matches(false) {} 67 68 // Returns true if a match is found in the subtree rooted at the 69 // given AST node. This is done via a set of mutually recursive 70 // functions. Here's how the recursion is done (the *wildcard can 71 // actually be Decl, Stmt, or Type): 72 // 73 // - Traverse(node) calls BaseTraverse(node) when it needs 74 // to visit the descendants of node. 75 // - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node)) 76 // Traverse*(c) for each child c of 'node'. 77 // - Traverse*(c) in turn calls Traverse(c), completing the 78 // recursion. 79 template <typename T> 80 bool findMatch(const T &Node) { 81 reset(); 82 traverse(Node); 83 return Matches; 84 } 85 86 // The following are overriding methods from the base visitor class. 87 // They are public only to allow CRTP to work. They are *not *part 88 // of the public API of this class. 
89 bool TraverseDecl(Decl *DeclNode) { 90 return (DeclNode == NULL) || traverse(*DeclNode); 91 } 92 bool TraverseStmt(Stmt *StmtNode) { 93 const Stmt *StmtToTraverse = StmtNode; 94 if (Traversal == 95 ASTMatchFinder::TK_IgnoreImplicitCastsAndParentheses) { 96 const Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode); 97 if (ExprNode != NULL) { 98 StmtToTraverse = ExprNode->IgnoreParenImpCasts(); 99 } 100 } 101 return (StmtToTraverse == NULL) || traverse(*StmtToTraverse); 102 } 103 bool TraverseType(QualType TypeNode) { 104 return traverse(TypeNode); 105 } 106 107 bool shouldVisitTemplateInstantiations() const { return true; } 108 bool shouldVisitImplicitCode() const { return true; } 109 110private: 111 // Used for updating the depth during traversal. 112 struct ScopedIncrement { 113 explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); } 114 ~ScopedIncrement() { --(*Depth); } 115 116 private: 117 int *Depth; 118 }; 119 120 // Resets the state of this object. 121 void reset() { 122 Matches = false; 123 CurrentDepth = -1; 124 } 125 126 // Forwards the call to the corresponding Traverse*() method in the 127 // base visitor class. 128 bool baseTraverse(const Decl &DeclNode) { 129 return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode)); 130 } 131 bool baseTraverse(const Stmt &StmtNode) { 132 return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode)); 133 } 134 bool baseTraverse(QualType TypeNode) { 135 return VisitorBase::TraverseType(TypeNode); 136 } 137 138 // Traverses the subtree rooted at 'node'; returns true if the 139 // traversal should continue after this function returns; also sets 140 // matched_ to true if a match is found during the traversal. 
141 template <typename T> 142 bool traverse(const T &Node) { 143 TOOLING_COMPILE_ASSERT(IsBaseType<T>::value, 144 traverse_can_only_be_instantiated_with_base_type); 145 ScopedIncrement ScopedDepth(&CurrentDepth); 146 if (CurrentDepth == 0) { 147 // We don't want to match the root node, so just recurse. 148 return baseTraverse(Node); 149 } 150 if (Bind != ASTMatchFinder::BK_All) { 151 if (BaseMatcher->matches(Node, Finder, Builder)) { 152 Matches = true; 153 return false; // Abort as soon as a match is found. 154 } 155 if (CurrentDepth < MaxDepth) { 156 // The current node doesn't match, and we haven't reached the 157 // maximum depth yet, so recurse. 158 return baseTraverse(Node); 159 } 160 // The current node doesn't match, and we have reached the 161 // maximum depth, so don't recurse (but continue the traversal 162 // such that other nodes at the current level can be visited). 163 return true; 164 } else { 165 BoundNodesTreeBuilder RecursiveBuilder; 166 if (BaseMatcher->matches(Node, Finder, &RecursiveBuilder)) { 167 // After the first match the matcher succeeds. 168 Matches = true; 169 Builder->addMatch(RecursiveBuilder.build()); 170 } 171 if (CurrentDepth < MaxDepth) { 172 baseTraverse(Node); 173 } 174 // In kBindAll mode we always search for more matches. 175 return true; 176 } 177 } 178 179 const UntypedBaseMatcher *const BaseMatcher; 180 ASTMatchFinder *const Finder; 181 BoundNodesTreeBuilder *const Builder; 182 int CurrentDepth; 183 const int MaxDepth; 184 const ASTMatchFinder::TraversalKind Traversal; 185 const ASTMatchFinder::BindKind Bind; 186 bool Matches; 187}; 188 189// Controls the outermost traversal of the AST and allows to match multiple 190// matchers. 
191class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>, 192 public ASTMatchFinder { 193public: 194 MatchASTVisitor(std::vector< std::pair<const UntypedBaseMatcher*, 195 MatchFinder::MatchCallback*> > *Triggers) 196 : Triggers(Triggers), 197 ActiveASTContext(NULL) { 198 } 199 200 void set_active_ast_context(ASTContext *NewActiveASTContext) { 201 ActiveASTContext = NewActiveASTContext; 202 } 203 204 // The following Visit*() and Traverse*() functions "override" 205 // methods in RecursiveASTVisitor. 206 207 bool VisitTypedefDecl(TypedefDecl *DeclNode) { 208 // When we see 'typedef A B', we add name 'B' to the set of names 209 // A's canonical type maps to. This is necessary for implementing 210 // IsDerivedFrom(x) properly, where x can be the name of the base 211 // class or any of its aliases. 212 // 213 // In general, the is-alias-of (as defined by typedefs) relation 214 // is tree-shaped, as you can typedef a type more than once. For 215 // example, 216 // 217 // typedef A B; 218 // typedef A C; 219 // typedef C D; 220 // typedef C E; 221 // 222 // gives you 223 // 224 // A 225 // |- B 226 // `- C 227 // |- D 228 // `- E 229 // 230 // It is wrong to assume that the relation is a chain. A correct 231 // implementation of IsDerivedFrom() needs to recognize that B and 232 // E are aliases, even though neither is a typedef of the other. 233 // Therefore, we cannot simply walk through one typedef chain to 234 // find out whether the type name matches. 235 const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr(); 236 const Type *CanonicalType = // root of the typedef tree 237 ActiveASTContext->getCanonicalType(TypeNode); 238 TypeAliases[CanonicalType].insert(DeclNode); 239 return true; 240 } 241 242 bool TraverseDecl(Decl *DeclNode); 243 bool TraverseStmt(Stmt *StmtNode); 244 bool TraverseType(QualType TypeNode); 245 bool TraverseTypeLoc(TypeLoc TypeNode); 246 247 // Matches children or descendants of 'Node' with 'BaseMatcher'. 
248 template <typename T> 249 bool memoizedMatchesRecursively(const T &Node, 250 const UntypedBaseMatcher &BaseMatcher, 251 BoundNodesTreeBuilder *Builder, int MaxDepth, 252 TraversalKind Traversal, BindKind Bind) { 253 TOOLING_COMPILE_ASSERT((llvm::is_same<T, Decl>::value) || 254 (llvm::is_same<T, Stmt>::value), 255 type_does_not_support_memoization); 256 const UntypedMatchInput input(BaseMatcher.getID(), &Node); 257 std::pair<MemoizationMap::iterator, bool> InsertResult 258 = ResultCache.insert(std::make_pair(input, MemoizedMatchResult())); 259 if (InsertResult.second) { 260 BoundNodesTreeBuilder DescendantBoundNodesBuilder; 261 InsertResult.first->second.ResultOfMatch = 262 matchesRecursively(Node, BaseMatcher, &DescendantBoundNodesBuilder, 263 MaxDepth, Traversal, Bind); 264 InsertResult.first->second.Nodes = 265 DescendantBoundNodesBuilder.build(); 266 } 267 InsertResult.first->second.Nodes.copyTo(Builder); 268 return InsertResult.first->second.ResultOfMatch; 269 } 270 271 // Matches children or descendants of 'Node' with 'BaseMatcher'. 272 template <typename T> 273 bool matchesRecursively(const T &Node, const UntypedBaseMatcher &BaseMatcher, 274 BoundNodesTreeBuilder *Builder, int MaxDepth, 275 TraversalKind Traversal, BindKind Bind) { 276 MatchChildASTVisitor Visitor( 277 &BaseMatcher, this, Builder, MaxDepth, Traversal, Bind); 278 return Visitor.findMatch(Node); 279 } 280 281 virtual bool classIsDerivedFrom(const CXXRecordDecl *Declaration, 282 const Matcher<NamedDecl> &Base, 283 BoundNodesTreeBuilder *Builder); 284 285 // Implements ASTMatchFinder::MatchesChildOf. 
286 virtual bool matchesChildOf(const Decl &DeclNode, 287 const UntypedBaseMatcher &BaseMatcher, 288 BoundNodesTreeBuilder *Builder, 289 TraversalKind Traversal, 290 BindKind Bind) { 291 return matchesRecursively(DeclNode, BaseMatcher, Builder, 1, Traversal, 292 Bind); 293 } 294 virtual bool matchesChildOf(const Stmt &StmtNode, 295 const UntypedBaseMatcher &BaseMatcher, 296 BoundNodesTreeBuilder *Builder, 297 TraversalKind Traversal, 298 BindKind Bind) { 299 return matchesRecursively(StmtNode, BaseMatcher, Builder, 1, Traversal, 300 Bind); 301 } 302 303 // Implements ASTMatchFinder::MatchesDescendantOf. 304 virtual bool matchesDescendantOf(const Decl &DeclNode, 305 const UntypedBaseMatcher &BaseMatcher, 306 BoundNodesTreeBuilder *Builder, 307 BindKind Bind) { 308 return memoizedMatchesRecursively(DeclNode, BaseMatcher, Builder, INT_MAX, 309 TK_AsIs, Bind); 310 } 311 virtual bool matchesDescendantOf(const Stmt &StmtNode, 312 const UntypedBaseMatcher &BaseMatcher, 313 BoundNodesTreeBuilder *Builder, 314 BindKind Bind) { 315 return memoizedMatchesRecursively(StmtNode, BaseMatcher, Builder, INT_MAX, 316 TK_AsIs, Bind); 317 } 318 319 bool shouldVisitTemplateInstantiations() const { return true; } 320 bool shouldVisitImplicitCode() const { return true; } 321 322private: 323 // Implements a BoundNodesTree::Visitor that calls a MatchCallback with 324 // the aggregated bound nodes for each match. 325 class MatchVisitor : public BoundNodesTree::Visitor { 326 public: 327 MatchVisitor(ASTContext* Context, 328 MatchFinder::MatchCallback* Callback) 329 : Context(Context), 330 Callback(Callback) {} 331 332 virtual void visitMatch(const BoundNodes& BoundNodesView) { 333 Callback->run(MatchFinder::MatchResult(BoundNodesView, Context)); 334 } 335 336 private: 337 ASTContext* Context; 338 MatchFinder::MatchCallback* Callback; 339 }; 340 341 // Returns true if 'TypeNode' has an alias that matches the given matcher. 
342 bool typeHasMatchingAlias(const Type *TypeNode, 343 const Matcher<NamedDecl> Matcher, 344 BoundNodesTreeBuilder *Builder) { 345 const Type *const CanonicalType = 346 ActiveASTContext->getCanonicalType(TypeNode); 347 const std::set<const TypedefDecl*> &Aliases = TypeAliases[CanonicalType]; 348 for (std::set<const TypedefDecl*>::const_iterator 349 It = Aliases.begin(), End = Aliases.end(); 350 It != End; ++It) { 351 if (Matcher.matches(**It, this, Builder)) 352 return true; 353 } 354 return false; 355 } 356 357 // Matches all registered matchers on the given node and calls the 358 // result callback for every node that matches. 359 template <typename T> 360 void match(const T &node) { 361 for (std::vector< std::pair<const UntypedBaseMatcher*, 362 MatchFinder::MatchCallback*> >::const_iterator 363 It = Triggers->begin(), End = Triggers->end(); 364 It != End; ++It) { 365 BoundNodesTreeBuilder Builder; 366 if (It->first->matches(node, this, &Builder)) { 367 BoundNodesTree BoundNodes = Builder.build(); 368 MatchVisitor Visitor(ActiveASTContext, It->second); 369 BoundNodes.visitMatches(&Visitor); 370 } 371 } 372 } 373 374 std::vector< std::pair<const UntypedBaseMatcher*, 375 MatchFinder::MatchCallback*> > *const Triggers; 376 ASTContext *ActiveASTContext; 377 378 // Maps a canonical type to its TypedefDecls. 379 llvm::DenseMap<const Type*, std::set<const TypedefDecl*> > TypeAliases; 380 381 // Maps (matcher, node) -> the match result for memoization. 382 typedef llvm::DenseMap<UntypedMatchInput, MemoizedMatchResult> MemoizationMap; 383 MemoizationMap ResultCache; 384}; 385 386// Returns true if the given class is directly or indirectly derived 387// from a base type with the given name. A class is considered to be 388// also derived from itself. 
389bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration, 390 const Matcher<NamedDecl> &Base, 391 BoundNodesTreeBuilder *Builder) { 392 if (Base.matches(*Declaration, this, Builder)) 393 return true; 394 if (!Declaration->hasDefinition()) 395 return false; 396 typedef CXXRecordDecl::base_class_const_iterator BaseIterator; 397 for (BaseIterator It = Declaration->bases_begin(), 398 End = Declaration->bases_end(); It != End; ++It) { 399 const Type *TypeNode = It->getType().getTypePtr(); 400 401 if (typeHasMatchingAlias(TypeNode, Base, Builder)) 402 return true; 403 404 // Type::getAs<...>() drills through typedefs. 405 if (TypeNode->getAs<DependentNameType>() != NULL || 406 TypeNode->getAs<TemplateTypeParmType>() != NULL) 407 // Dependent names and template TypeNode parameters will be matched when 408 // the template is instantiated. 409 continue; 410 CXXRecordDecl *ClassDecl = NULL; 411 TemplateSpecializationType const *TemplateType = 412 TypeNode->getAs<TemplateSpecializationType>(); 413 if (TemplateType != NULL) { 414 if (TemplateType->getTemplateName().isDependent()) 415 // Dependent template specializations will be matched when the 416 // template is instantiated. 417 continue; 418 419 // For template specialization types which are specializing a template 420 // declaration which is an explicit or partial specialization of another 421 // template declaration, getAsCXXRecordDecl() returns the corresponding 422 // ClassTemplateSpecializationDecl. 423 // 424 // For template specialization types which are specializing a template 425 // declaration which is neither an explicit nor partial specialization of 426 // another template declaration, getAsCXXRecordDecl() returns NULL and 427 // we get the CXXRecordDecl of the templated declaration. 
428 CXXRecordDecl *SpecializationDecl = 429 TemplateType->getAsCXXRecordDecl(); 430 if (SpecializationDecl != NULL) { 431 ClassDecl = SpecializationDecl; 432 } else { 433 ClassDecl = llvm::dyn_cast<CXXRecordDecl>( 434 TemplateType->getTemplateName() 435 .getAsTemplateDecl()->getTemplatedDecl()); 436 } 437 } else { 438 ClassDecl = TypeNode->getAsCXXRecordDecl(); 439 } 440 assert(ClassDecl != NULL); 441 assert(ClassDecl != Declaration); 442 if (classIsDerivedFrom(ClassDecl, Base, Builder)) 443 return true; 444 } 445 return false; 446} 447 448bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) { 449 if (DeclNode == NULL) { 450 return true; 451 } 452 match(*DeclNode); 453 return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode); 454} 455 456bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode) { 457 if (StmtNode == NULL) { 458 return true; 459 } 460 match(*StmtNode); 461 return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode); 462} 463 464bool MatchASTVisitor::TraverseType(QualType TypeNode) { 465 match(TypeNode); 466 return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode); 467} 468 469bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLoc) { 470 match(TypeLoc.getType()); 471 return RecursiveASTVisitor<MatchASTVisitor>:: 472 TraverseTypeLoc(TypeLoc); 473} 474 475class MatchASTConsumer : public ASTConsumer { 476public: 477 MatchASTConsumer(std::vector< std::pair<const UntypedBaseMatcher*, 478 MatchFinder::MatchCallback*> > *Triggers, 479 MatchFinder::ParsingDoneTestCallback *ParsingDone) 480 : Visitor(Triggers), 481 ParsingDone(ParsingDone) {} 482 483private: 484 virtual void HandleTranslationUnit(ASTContext &Context) { 485 if (ParsingDone != NULL) { 486 ParsingDone->run(); 487 } 488 Visitor.set_active_ast_context(&Context); 489 Visitor.TraverseDecl(Context.getTranslationUnitDecl()); 490 Visitor.set_active_ast_context(NULL); 491 } 492 493 MatchASTVisitor Visitor; 494 MatchFinder::ParsingDoneTestCallback *ParsingDone; 495}; 496 497} // end 
namespace 498} // end namespace internal 499 500MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes, 501 ASTContext *Context) 502 : Nodes(Nodes), Context(Context), 503 SourceManager(&Context->getSourceManager()) {} 504 505MatchFinder::MatchCallback::~MatchCallback() {} 506MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {} 507 508MatchFinder::MatchFinder() : ParsingDone(NULL) {} 509 510MatchFinder::~MatchFinder() { 511 for (std::vector< std::pair<const internal::UntypedBaseMatcher*, 512 MatchFinder::MatchCallback*> >::const_iterator 513 It = Triggers.begin(), End = Triggers.end(); 514 It != End; ++It) { 515 delete It->first; 516 } 517} 518 519void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch, 520 MatchCallback *Action) { 521 Triggers.push_back(std::make_pair( 522 new internal::TypedBaseMatcher<Decl>(NodeMatch), Action)); 523} 524 525void MatchFinder::addMatcher(const TypeMatcher &NodeMatch, 526 MatchCallback *Action) { 527 Triggers.push_back(std::make_pair( 528 new internal::TypedBaseMatcher<QualType>(NodeMatch), Action)); 529} 530 531void MatchFinder::addMatcher(const StatementMatcher &NodeMatch, 532 MatchCallback *Action) { 533 Triggers.push_back(std::make_pair( 534 new internal::TypedBaseMatcher<Stmt>(NodeMatch), Action)); 535} 536 537ASTConsumer *MatchFinder::newASTConsumer() { 538 return new internal::MatchASTConsumer(&Triggers, ParsingDone); 539} 540 541void MatchFinder::registerTestCallbackAfterParsing( 542 MatchFinder::ParsingDoneTestCallback *NewParsingDone) { 543 ParsingDone = NewParsingDone; 544} 545 546} // end namespace ast_matchers 547} // end namespace clang 548