ASTMatchFinder.cpp revision 239462
1//===--- ASTMatchFinder.cpp - Structural query framework ------------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10//  Implements an algorithm to efficiently search for matches on AST nodes.
11//  Uses memoization to support recursive matches like HasDescendant.
12//
13//  The general idea is to visit all AST nodes with a RecursiveASTVisitor,
14//  calling the Matches(...) method of each matcher we are running on each
15//  AST node. The matcher can recurse via the ASTMatchFinder interface.
16//
17//===----------------------------------------------------------------------===//
18
19#include "clang/ASTMatchers/ASTMatchFinder.h"
20#include "clang/AST/ASTConsumer.h"
21#include "clang/AST/ASTContext.h"
22#include "clang/AST/RecursiveASTVisitor.h"
23#include <set>
24
25namespace clang {
26namespace ast_matchers {
27namespace internal {
28namespace {
29
// Memoization key: running the same matcher on the same AST node twice is
// wasteful, so match results are cached under this pair. The first element
// is the ID of the matcher (identifying the MatcherInterface), the second is
// the address of the AST node it was run on.
typedef std::pair<uint64_t, const void*>
    UntypedMatchInput;
35
36// Used to store the result of a match and possibly bound nodes.
37struct MemoizedMatchResult {
38  bool ResultOfMatch;
39  BoundNodesTree Nodes;
40};
41
42// A RecursiveASTVisitor that traverses all children or all descendants of
43// a node.
44class MatchChildASTVisitor
45    : public RecursiveASTVisitor<MatchChildASTVisitor> {
46public:
47  typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase;
48
49  // Creates an AST visitor that matches 'matcher' on all children or
50  // descendants of a traversed node. max_depth is the maximum depth
51  // to traverse: use 1 for matching the children and INT_MAX for
52  // matching the descendants.
53  MatchChildASTVisitor(const UntypedBaseMatcher *BaseMatcher,
54                       ASTMatchFinder *Finder,
55                       BoundNodesTreeBuilder *Builder,
56                       int MaxDepth,
57                       ASTMatchFinder::TraversalKind Traversal,
58                       ASTMatchFinder::BindKind Bind)
59      : BaseMatcher(BaseMatcher),
60        Finder(Finder),
61        Builder(Builder),
62        CurrentDepth(-1),
63        MaxDepth(MaxDepth),
64        Traversal(Traversal),
65        Bind(Bind),
66        Matches(false) {}
67
68  // Returns true if a match is found in the subtree rooted at the
69  // given AST node. This is done via a set of mutually recursive
70  // functions. Here's how the recursion is done (the  *wildcard can
71  // actually be Decl, Stmt, or Type):
72  //
73  //   - Traverse(node) calls BaseTraverse(node) when it needs
74  //     to visit the descendants of node.
75  //   - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node))
76  //     Traverse*(c) for each child c of 'node'.
77  //   - Traverse*(c) in turn calls Traverse(c), completing the
78  //     recursion.
79  template <typename T>
80  bool findMatch(const T &Node) {
81    reset();
82    traverse(Node);
83    return Matches;
84  }
85
86  // The following are overriding methods from the base visitor class.
87  // They are public only to allow CRTP to work. They are *not *part
88  // of the public API of this class.
89  bool TraverseDecl(Decl *DeclNode) {
90    return (DeclNode == NULL) || traverse(*DeclNode);
91  }
92  bool TraverseStmt(Stmt *StmtNode) {
93    const Stmt *StmtToTraverse = StmtNode;
94    if (Traversal ==
95        ASTMatchFinder::TK_IgnoreImplicitCastsAndParentheses) {
96      const Expr *ExprNode = dyn_cast_or_null<Expr>(StmtNode);
97      if (ExprNode != NULL) {
98        StmtToTraverse = ExprNode->IgnoreParenImpCasts();
99      }
100    }
101    return (StmtToTraverse == NULL) || traverse(*StmtToTraverse);
102  }
103  bool TraverseType(QualType TypeNode) {
104    return traverse(TypeNode);
105  }
106
107  bool shouldVisitTemplateInstantiations() const { return true; }
108  bool shouldVisitImplicitCode() const { return true; }
109
110private:
111  // Used for updating the depth during traversal.
112  struct ScopedIncrement {
113    explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); }
114    ~ScopedIncrement() { --(*Depth); }
115
116   private:
117    int *Depth;
118  };
119
120  // Resets the state of this object.
121  void reset() {
122    Matches = false;
123    CurrentDepth = -1;
124  }
125
126  // Forwards the call to the corresponding Traverse*() method in the
127  // base visitor class.
128  bool baseTraverse(const Decl &DeclNode) {
129    return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode));
130  }
131  bool baseTraverse(const Stmt &StmtNode) {
132    return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode));
133  }
134  bool baseTraverse(QualType TypeNode) {
135    return VisitorBase::TraverseType(TypeNode);
136  }
137
138  // Traverses the subtree rooted at 'node'; returns true if the
139  // traversal should continue after this function returns; also sets
140  // matched_ to true if a match is found during the traversal.
141  template <typename T>
142  bool traverse(const T &Node) {
143    TOOLING_COMPILE_ASSERT(IsBaseType<T>::value,
144                           traverse_can_only_be_instantiated_with_base_type);
145    ScopedIncrement ScopedDepth(&CurrentDepth);
146    if (CurrentDepth == 0) {
147      // We don't want to match the root node, so just recurse.
148      return baseTraverse(Node);
149    }
150    if (Bind != ASTMatchFinder::BK_All) {
151      if (BaseMatcher->matches(Node, Finder, Builder)) {
152        Matches = true;
153        return false;  // Abort as soon as a match is found.
154      }
155      if (CurrentDepth < MaxDepth) {
156        // The current node doesn't match, and we haven't reached the
157        // maximum depth yet, so recurse.
158        return baseTraverse(Node);
159      }
160      // The current node doesn't match, and we have reached the
161      // maximum depth, so don't recurse (but continue the traversal
162      // such that other nodes at the current level can be visited).
163      return true;
164    } else {
165      BoundNodesTreeBuilder RecursiveBuilder;
166      if (BaseMatcher->matches(Node, Finder, &RecursiveBuilder)) {
167        // After the first match the matcher succeeds.
168        Matches = true;
169        Builder->addMatch(RecursiveBuilder.build());
170      }
171      if (CurrentDepth < MaxDepth) {
172        baseTraverse(Node);
173      }
174      // In kBindAll mode we always search for more matches.
175      return true;
176    }
177  }
178
179  const UntypedBaseMatcher *const BaseMatcher;
180  ASTMatchFinder *const Finder;
181  BoundNodesTreeBuilder *const Builder;
182  int CurrentDepth;
183  const int MaxDepth;
184  const ASTMatchFinder::TraversalKind Traversal;
185  const ASTMatchFinder::BindKind Bind;
186  bool Matches;
187};
188
189// Controls the outermost traversal of the AST and allows to match multiple
190// matchers.
191class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
192                        public ASTMatchFinder {
193public:
194  MatchASTVisitor(std::vector< std::pair<const UntypedBaseMatcher*,
195                               MatchFinder::MatchCallback*> > *Triggers)
196     : Triggers(Triggers),
197       ActiveASTContext(NULL) {
198  }
199
200  void set_active_ast_context(ASTContext *NewActiveASTContext) {
201    ActiveASTContext = NewActiveASTContext;
202  }
203
204  // The following Visit*() and Traverse*() functions "override"
205  // methods in RecursiveASTVisitor.
206
207  bool VisitTypedefDecl(TypedefDecl *DeclNode) {
208    // When we see 'typedef A B', we add name 'B' to the set of names
209    // A's canonical type maps to.  This is necessary for implementing
210    // IsDerivedFrom(x) properly, where x can be the name of the base
211    // class or any of its aliases.
212    //
213    // In general, the is-alias-of (as defined by typedefs) relation
214    // is tree-shaped, as you can typedef a type more than once.  For
215    // example,
216    //
217    //   typedef A B;
218    //   typedef A C;
219    //   typedef C D;
220    //   typedef C E;
221    //
222    // gives you
223    //
224    //   A
225    //   |- B
226    //   `- C
227    //      |- D
228    //      `- E
229    //
230    // It is wrong to assume that the relation is a chain.  A correct
231    // implementation of IsDerivedFrom() needs to recognize that B and
232    // E are aliases, even though neither is a typedef of the other.
233    // Therefore, we cannot simply walk through one typedef chain to
234    // find out whether the type name matches.
235    const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr();
236    const Type *CanonicalType =  // root of the typedef tree
237        ActiveASTContext->getCanonicalType(TypeNode);
238    TypeAliases[CanonicalType].insert(DeclNode);
239    return true;
240  }
241
242  bool TraverseDecl(Decl *DeclNode);
243  bool TraverseStmt(Stmt *StmtNode);
244  bool TraverseType(QualType TypeNode);
245  bool TraverseTypeLoc(TypeLoc TypeNode);
246
247  // Matches children or descendants of 'Node' with 'BaseMatcher'.
248  template <typename T>
249  bool memoizedMatchesRecursively(const T &Node,
250                                  const UntypedBaseMatcher &BaseMatcher,
251                                  BoundNodesTreeBuilder *Builder, int MaxDepth,
252                                  TraversalKind Traversal, BindKind Bind) {
253    TOOLING_COMPILE_ASSERT((llvm::is_same<T, Decl>::value) ||
254                           (llvm::is_same<T, Stmt>::value),
255                           type_does_not_support_memoization);
256    const UntypedMatchInput input(BaseMatcher.getID(), &Node);
257    std::pair<MemoizationMap::iterator, bool> InsertResult
258      = ResultCache.insert(std::make_pair(input, MemoizedMatchResult()));
259    if (InsertResult.second) {
260      BoundNodesTreeBuilder DescendantBoundNodesBuilder;
261      InsertResult.first->second.ResultOfMatch =
262        matchesRecursively(Node, BaseMatcher, &DescendantBoundNodesBuilder,
263                           MaxDepth, Traversal, Bind);
264      InsertResult.first->second.Nodes =
265        DescendantBoundNodesBuilder.build();
266    }
267    InsertResult.first->second.Nodes.copyTo(Builder);
268    return InsertResult.first->second.ResultOfMatch;
269  }
270
271  // Matches children or descendants of 'Node' with 'BaseMatcher'.
272  template <typename T>
273  bool matchesRecursively(const T &Node, const UntypedBaseMatcher &BaseMatcher,
274                          BoundNodesTreeBuilder *Builder, int MaxDepth,
275                          TraversalKind Traversal, BindKind Bind) {
276    MatchChildASTVisitor Visitor(
277      &BaseMatcher, this, Builder, MaxDepth, Traversal, Bind);
278    return Visitor.findMatch(Node);
279  }
280
281  virtual bool classIsDerivedFrom(const CXXRecordDecl *Declaration,
282                                  const Matcher<NamedDecl> &Base,
283                                  BoundNodesTreeBuilder *Builder);
284
285  // Implements ASTMatchFinder::MatchesChildOf.
286  virtual bool matchesChildOf(const Decl &DeclNode,
287                              const UntypedBaseMatcher &BaseMatcher,
288                              BoundNodesTreeBuilder *Builder,
289                              TraversalKind Traversal,
290                              BindKind Bind) {
291    return matchesRecursively(DeclNode, BaseMatcher, Builder, 1, Traversal,
292                              Bind);
293  }
294  virtual bool matchesChildOf(const Stmt &StmtNode,
295                              const UntypedBaseMatcher &BaseMatcher,
296                              BoundNodesTreeBuilder *Builder,
297                              TraversalKind Traversal,
298                              BindKind Bind) {
299    return matchesRecursively(StmtNode, BaseMatcher, Builder, 1, Traversal,
300                              Bind);
301  }
302
303  // Implements ASTMatchFinder::MatchesDescendantOf.
304  virtual bool matchesDescendantOf(const Decl &DeclNode,
305                                   const UntypedBaseMatcher &BaseMatcher,
306                                   BoundNodesTreeBuilder *Builder,
307                                   BindKind Bind) {
308    return memoizedMatchesRecursively(DeclNode, BaseMatcher, Builder, INT_MAX,
309                                      TK_AsIs, Bind);
310  }
311  virtual bool matchesDescendantOf(const Stmt &StmtNode,
312                                   const UntypedBaseMatcher &BaseMatcher,
313                                   BoundNodesTreeBuilder *Builder,
314                                   BindKind Bind) {
315    return memoizedMatchesRecursively(StmtNode, BaseMatcher, Builder, INT_MAX,
316                                      TK_AsIs, Bind);
317  }
318
319  bool shouldVisitTemplateInstantiations() const { return true; }
320  bool shouldVisitImplicitCode() const { return true; }
321
322private:
323  // Implements a BoundNodesTree::Visitor that calls a MatchCallback with
324  // the aggregated bound nodes for each match.
325  class MatchVisitor : public BoundNodesTree::Visitor {
326  public:
327    MatchVisitor(ASTContext* Context,
328                 MatchFinder::MatchCallback* Callback)
329      : Context(Context),
330        Callback(Callback) {}
331
332    virtual void visitMatch(const BoundNodes& BoundNodesView) {
333      Callback->run(MatchFinder::MatchResult(BoundNodesView, Context));
334    }
335
336  private:
337    ASTContext* Context;
338    MatchFinder::MatchCallback* Callback;
339  };
340
341  // Returns true if 'TypeNode' has an alias that matches the given matcher.
342  bool typeHasMatchingAlias(const Type *TypeNode,
343                            const Matcher<NamedDecl> Matcher,
344                            BoundNodesTreeBuilder *Builder) {
345    const Type *const CanonicalType =
346      ActiveASTContext->getCanonicalType(TypeNode);
347    const std::set<const TypedefDecl*> &Aliases = TypeAliases[CanonicalType];
348    for (std::set<const TypedefDecl*>::const_iterator
349           It = Aliases.begin(), End = Aliases.end();
350         It != End; ++It) {
351      if (Matcher.matches(**It, this, Builder))
352        return true;
353    }
354    return false;
355  }
356
357  // Matches all registered matchers on the given node and calls the
358  // result callback for every node that matches.
359  template <typename T>
360  void match(const T &node) {
361    for (std::vector< std::pair<const UntypedBaseMatcher*,
362                      MatchFinder::MatchCallback*> >::const_iterator
363             It = Triggers->begin(), End = Triggers->end();
364         It != End; ++It) {
365      BoundNodesTreeBuilder Builder;
366      if (It->first->matches(node, this, &Builder)) {
367        BoundNodesTree BoundNodes = Builder.build();
368        MatchVisitor Visitor(ActiveASTContext, It->second);
369        BoundNodes.visitMatches(&Visitor);
370      }
371    }
372  }
373
374  std::vector< std::pair<const UntypedBaseMatcher*,
375               MatchFinder::MatchCallback*> > *const Triggers;
376  ASTContext *ActiveASTContext;
377
378  // Maps a canonical type to its TypedefDecls.
379  llvm::DenseMap<const Type*, std::set<const TypedefDecl*> > TypeAliases;
380
381  // Maps (matcher, node) -> the match result for memoization.
382  typedef llvm::DenseMap<UntypedMatchInput, MemoizedMatchResult> MemoizationMap;
383  MemoizationMap ResultCache;
384};
385
386// Returns true if the given class is directly or indirectly derived
387// from a base type with the given name.  A class is considered to be
388// also derived from itself.
389bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
390                                         const Matcher<NamedDecl> &Base,
391                                         BoundNodesTreeBuilder *Builder) {
392  if (Base.matches(*Declaration, this, Builder))
393    return true;
394  if (!Declaration->hasDefinition())
395    return false;
396  typedef CXXRecordDecl::base_class_const_iterator BaseIterator;
397  for (BaseIterator It = Declaration->bases_begin(),
398                    End = Declaration->bases_end(); It != End; ++It) {
399    const Type *TypeNode = It->getType().getTypePtr();
400
401    if (typeHasMatchingAlias(TypeNode, Base, Builder))
402      return true;
403
404    // Type::getAs<...>() drills through typedefs.
405    if (TypeNode->getAs<DependentNameType>() != NULL ||
406        TypeNode->getAs<TemplateTypeParmType>() != NULL)
407      // Dependent names and template TypeNode parameters will be matched when
408      // the template is instantiated.
409      continue;
410    CXXRecordDecl *ClassDecl = NULL;
411    TemplateSpecializationType const *TemplateType =
412      TypeNode->getAs<TemplateSpecializationType>();
413    if (TemplateType != NULL) {
414      if (TemplateType->getTemplateName().isDependent())
415        // Dependent template specializations will be matched when the
416        // template is instantiated.
417        continue;
418
419      // For template specialization types which are specializing a template
420      // declaration which is an explicit or partial specialization of another
421      // template declaration, getAsCXXRecordDecl() returns the corresponding
422      // ClassTemplateSpecializationDecl.
423      //
424      // For template specialization types which are specializing a template
425      // declaration which is neither an explicit nor partial specialization of
426      // another template declaration, getAsCXXRecordDecl() returns NULL and
427      // we get the CXXRecordDecl of the templated declaration.
428      CXXRecordDecl *SpecializationDecl =
429        TemplateType->getAsCXXRecordDecl();
430      if (SpecializationDecl != NULL) {
431        ClassDecl = SpecializationDecl;
432      } else {
433        ClassDecl = llvm::dyn_cast<CXXRecordDecl>(
434            TemplateType->getTemplateName()
435                .getAsTemplateDecl()->getTemplatedDecl());
436      }
437    } else {
438      ClassDecl = TypeNode->getAsCXXRecordDecl();
439    }
440    assert(ClassDecl != NULL);
441    assert(ClassDecl != Declaration);
442    if (classIsDerivedFrom(ClassDecl, Base, Builder))
443      return true;
444  }
445  return false;
446}
447
448bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) {
449  if (DeclNode == NULL) {
450    return true;
451  }
452  match(*DeclNode);
453  return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode);
454}
455
456bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode) {
457  if (StmtNode == NULL) {
458    return true;
459  }
460  match(*StmtNode);
461  return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode);
462}
463
464bool MatchASTVisitor::TraverseType(QualType TypeNode) {
465  match(TypeNode);
466  return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode);
467}
468
469bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLoc) {
470  match(TypeLoc.getType());
471  return RecursiveASTVisitor<MatchASTVisitor>::
472      TraverseTypeLoc(TypeLoc);
473}
474
475class MatchASTConsumer : public ASTConsumer {
476public:
477  MatchASTConsumer(std::vector< std::pair<const UntypedBaseMatcher*,
478                                MatchFinder::MatchCallback*> > *Triggers,
479                   MatchFinder::ParsingDoneTestCallback *ParsingDone)
480      : Visitor(Triggers),
481        ParsingDone(ParsingDone) {}
482
483private:
484  virtual void HandleTranslationUnit(ASTContext &Context) {
485    if (ParsingDone != NULL) {
486      ParsingDone->run();
487    }
488    Visitor.set_active_ast_context(&Context);
489    Visitor.TraverseDecl(Context.getTranslationUnitDecl());
490    Visitor.set_active_ast_context(NULL);
491  }
492
493  MatchASTVisitor Visitor;
494  MatchFinder::ParsingDoneTestCallback *ParsingDone;
495};
496
497} // end namespace
498} // end namespace internal
499
500MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes,
501                                      ASTContext *Context)
502  : Nodes(Nodes), Context(Context),
503    SourceManager(&Context->getSourceManager()) {}
504
505MatchFinder::MatchCallback::~MatchCallback() {}
506MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {}
507
508MatchFinder::MatchFinder() : ParsingDone(NULL) {}
509
510MatchFinder::~MatchFinder() {
511  for (std::vector< std::pair<const internal::UntypedBaseMatcher*,
512                    MatchFinder::MatchCallback*> >::const_iterator
513           It = Triggers.begin(), End = Triggers.end();
514       It != End; ++It) {
515    delete It->first;
516  }
517}
518
519void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch,
520                             MatchCallback *Action) {
521  Triggers.push_back(std::make_pair(
522    new internal::TypedBaseMatcher<Decl>(NodeMatch), Action));
523}
524
525void MatchFinder::addMatcher(const TypeMatcher &NodeMatch,
526                             MatchCallback *Action) {
527  Triggers.push_back(std::make_pair(
528    new internal::TypedBaseMatcher<QualType>(NodeMatch), Action));
529}
530
531void MatchFinder::addMatcher(const StatementMatcher &NodeMatch,
532                             MatchCallback *Action) {
533  Triggers.push_back(std::make_pair(
534    new internal::TypedBaseMatcher<Stmt>(NodeMatch), Action));
535}
536
537ASTConsumer *MatchFinder::newASTConsumer() {
538  return new internal::MatchASTConsumer(&Triggers, ParsingDone);
539}
540
541void MatchFinder::registerTestCallbackAfterParsing(
542    MatchFinder::ParsingDoneTestCallback *NewParsingDone) {
543  ParsingDone = NewParsingDone;
544}
545
546} // end namespace ast_matchers
547} // end namespace clang
548