1//===--- ASTMatchFinder.cpp - Structural query framework ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//  Implements an algorithm to efficiently search for matches on AST nodes.
10//  Uses memoization to support recursive matches like HasDescendant.
11//
12//  The general idea is to visit all AST nodes with a RecursiveASTVisitor,
13//  calling the Matches(...) method of each matcher we are running on each
14//  AST node. The matcher can recurse via the ASTMatchFinder interface.
15//
16//===----------------------------------------------------------------------===//
17
18#include "clang/ASTMatchers/ASTMatchFinder.h"
19#include "clang/AST/ASTConsumer.h"
20#include "clang/AST/ASTContext.h"
21#include "clang/AST/DeclCXX.h"
22#include "clang/AST/RecursiveASTVisitor.h"
23#include "llvm/ADT/DenseMap.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringMap.h"
26#include "llvm/Support/PrettyStackTrace.h"
27#include "llvm/Support/Timer.h"
28#include <deque>
29#include <memory>
30#include <set>
31
32namespace clang {
33namespace ast_matchers {
34namespace internal {
35namespace {
36
37typedef MatchFinder::MatchCallback MatchCallback;
38
// Upper bound on the number of memoized recursive-match results kept before
// the cache is dropped. The value 10k was found experimentally to balance
// speed against memory use, measured by running matchers that fire on every
// statement across a very large codebase.
//
// FIXME: Do some performance optimization in general and revisit this
// number; also, put up micro-benchmarks that we can optimize this on.
static constexpr unsigned MaxMemoizationEntries = 10000;
48
// Which kind of recursive match a memoized result belongs to; part of the
// memoization key so that child, descendant and ancestor results for the
// same matcher/node never alias each other.
enum class MatchType {
  Ancestors,   // search upwards through parents
  Descendants, // search the whole subtree
  Child,       // search direct children only
};
55
56// We use memoization to avoid running the same matcher on the same
57// AST node twice.  This struct is the key for looking up match
58// result.  It consists of an ID of the MatcherInterface (for
59// identifying the matcher), a pointer to the AST node and the
60// bound nodes before the matcher was executed.
61//
62// We currently only memoize on nodes whose pointers identify the
63// nodes (\c Stmt and \c Decl, but not \c QualType or \c TypeLoc).
64// For \c QualType and \c TypeLoc it is possible to implement
65// generation of keys for each type.
66// FIXME: Benchmark whether memoization of non-pointer typed nodes
67// provides enough benefit for the additional amount of code.
68struct MatchKey {
69  DynTypedMatcher::MatcherIDType MatcherID;
70  DynTypedNode Node;
71  BoundNodesTreeBuilder BoundNodes;
72  TraversalKind Traversal = TK_AsIs;
73  MatchType Type;
74
75  bool operator<(const MatchKey &Other) const {
76    return std::tie(Traversal, Type, MatcherID, Node, BoundNodes) <
77           std::tie(Other.Traversal, Other.Type, Other.MatcherID, Other.Node,
78                    Other.BoundNodes);
79  }
80};
81
82// Used to store the result of a match and possibly bound nodes.
83struct MemoizedMatchResult {
84  bool ResultOfMatch;
85  BoundNodesTreeBuilder Nodes;
86};
87
88// A RecursiveASTVisitor that traverses all children or all descendants of
89// a node.
90class MatchChildASTVisitor
91    : public RecursiveASTVisitor<MatchChildASTVisitor> {
92public:
93  typedef RecursiveASTVisitor<MatchChildASTVisitor> VisitorBase;
94
95  // Creates an AST visitor that matches 'matcher' on all children or
96  // descendants of a traversed node. max_depth is the maximum depth
97  // to traverse: use 1 for matching the children and INT_MAX for
98  // matching the descendants.
99  MatchChildASTVisitor(const DynTypedMatcher *Matcher, ASTMatchFinder *Finder,
100                       BoundNodesTreeBuilder *Builder, int MaxDepth,
101                       bool IgnoreImplicitChildren,
102                       ASTMatchFinder::BindKind Bind)
103      : Matcher(Matcher), Finder(Finder), Builder(Builder), CurrentDepth(0),
104        MaxDepth(MaxDepth), IgnoreImplicitChildren(IgnoreImplicitChildren),
105        Bind(Bind), Matches(false) {}
106
107  // Returns true if a match is found in the subtree rooted at the
108  // given AST node. This is done via a set of mutually recursive
109  // functions. Here's how the recursion is done (the  *wildcard can
110  // actually be Decl, Stmt, or Type):
111  //
112  //   - Traverse(node) calls BaseTraverse(node) when it needs
113  //     to visit the descendants of node.
114  //   - BaseTraverse(node) then calls (via VisitorBase::Traverse*(node))
115  //     Traverse*(c) for each child c of 'node'.
116  //   - Traverse*(c) in turn calls Traverse(c), completing the
117  //     recursion.
118  bool findMatch(const DynTypedNode &DynNode) {
119    reset();
120    if (const Decl *D = DynNode.get<Decl>())
121      traverse(*D);
122    else if (const Stmt *S = DynNode.get<Stmt>())
123      traverse(*S);
124    else if (const NestedNameSpecifier *NNS =
125             DynNode.get<NestedNameSpecifier>())
126      traverse(*NNS);
127    else if (const NestedNameSpecifierLoc *NNSLoc =
128             DynNode.get<NestedNameSpecifierLoc>())
129      traverse(*NNSLoc);
130    else if (const QualType *Q = DynNode.get<QualType>())
131      traverse(*Q);
132    else if (const TypeLoc *T = DynNode.get<TypeLoc>())
133      traverse(*T);
134    else if (const auto *C = DynNode.get<CXXCtorInitializer>())
135      traverse(*C);
136    else if (const TemplateArgumentLoc *TALoc =
137                 DynNode.get<TemplateArgumentLoc>())
138      traverse(*TALoc);
139    else if (const Attr *A = DynNode.get<Attr>())
140      traverse(*A);
141    // FIXME: Add other base types after adding tests.
142
143    // It's OK to always overwrite the bound nodes, as if there was
144    // no match in this recursive branch, the result set is empty
145    // anyway.
146    *Builder = ResultBindings;
147
148    return Matches;
149  }
150
151  // The following are overriding methods from the base visitor class.
152  // They are public only to allow CRTP to work. They are *not *part
153  // of the public API of this class.
154  bool TraverseDecl(Decl *DeclNode) {
155
156    if (DeclNode && DeclNode->isImplicit() &&
157        Finder->isTraversalIgnoringImplicitNodes())
158      return baseTraverse(*DeclNode);
159
160    ScopedIncrement ScopedDepth(&CurrentDepth);
161    return (DeclNode == nullptr) || traverse(*DeclNode);
162  }
163
164  Stmt *getStmtToTraverse(Stmt *StmtNode) {
165    Stmt *StmtToTraverse = StmtNode;
166    if (auto *ExprNode = dyn_cast_or_null<Expr>(StmtNode)) {
167      auto *LambdaNode = dyn_cast_or_null<LambdaExpr>(StmtNode);
168      if (LambdaNode && Finder->isTraversalIgnoringImplicitNodes())
169        StmtToTraverse = LambdaNode;
170      else
171        StmtToTraverse =
172            Finder->getASTContext().getParentMapContext().traverseIgnored(
173                ExprNode);
174    }
175    return StmtToTraverse;
176  }
177
178  bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr) {
179    // If we need to keep track of the depth, we can't perform data recursion.
180    if (CurrentDepth == 0 || (CurrentDepth <= MaxDepth && MaxDepth < INT_MAX))
181      Queue = nullptr;
182
183    ScopedIncrement ScopedDepth(&CurrentDepth);
184    Stmt *StmtToTraverse = getStmtToTraverse(StmtNode);
185    if (!StmtToTraverse)
186      return true;
187
188    if (IgnoreImplicitChildren && isa<CXXDefaultArgExpr>(StmtNode))
189      return true;
190
191    if (!match(*StmtToTraverse))
192      return false;
193    return VisitorBase::TraverseStmt(StmtToTraverse, Queue);
194  }
195  // We assume that the QualType and the contained type are on the same
196  // hierarchy level. Thus, we try to match either of them.
197  bool TraverseType(QualType TypeNode) {
198    if (TypeNode.isNull())
199      return true;
200    ScopedIncrement ScopedDepth(&CurrentDepth);
201    // Match the Type.
202    if (!match(*TypeNode))
203      return false;
204    // The QualType is matched inside traverse.
205    return traverse(TypeNode);
206  }
207  // We assume that the TypeLoc, contained QualType and contained Type all are
208  // on the same hierarchy level. Thus, we try to match all of them.
209  bool TraverseTypeLoc(TypeLoc TypeLocNode) {
210    if (TypeLocNode.isNull())
211      return true;
212    ScopedIncrement ScopedDepth(&CurrentDepth);
213    // Match the Type.
214    if (!match(*TypeLocNode.getType()))
215      return false;
216    // Match the QualType.
217    if (!match(TypeLocNode.getType()))
218      return false;
219    // The TypeLoc is matched inside traverse.
220    return traverse(TypeLocNode);
221  }
222  bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
223    ScopedIncrement ScopedDepth(&CurrentDepth);
224    return (NNS == nullptr) || traverse(*NNS);
225  }
226  bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS) {
227    if (!NNS)
228      return true;
229    ScopedIncrement ScopedDepth(&CurrentDepth);
230    if (!match(*NNS.getNestedNameSpecifier()))
231      return false;
232    return traverse(NNS);
233  }
234  bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit) {
235    if (!CtorInit)
236      return true;
237    ScopedIncrement ScopedDepth(&CurrentDepth);
238    return traverse(*CtorInit);
239  }
240  bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL) {
241    ScopedIncrement ScopedDepth(&CurrentDepth);
242    return traverse(TAL);
243  }
244  bool TraverseCXXForRangeStmt(CXXForRangeStmt *Node) {
245    if (!Finder->isTraversalIgnoringImplicitNodes())
246      return VisitorBase::TraverseCXXForRangeStmt(Node);
247    if (!Node)
248      return true;
249    ScopedIncrement ScopedDepth(&CurrentDepth);
250    if (auto *Init = Node->getInit())
251      if (!traverse(*Init))
252        return false;
253    if (!match(*Node->getLoopVariable()))
254      return false;
255    if (match(*Node->getRangeInit()))
256      if (!VisitorBase::TraverseStmt(Node->getRangeInit()))
257        return false;
258    if (!match(*Node->getBody()))
259      return false;
260    return VisitorBase::TraverseStmt(Node->getBody());
261  }
262  bool TraverseCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *Node) {
263    if (!Finder->isTraversalIgnoringImplicitNodes())
264      return VisitorBase::TraverseCXXRewrittenBinaryOperator(Node);
265    if (!Node)
266      return true;
267    ScopedIncrement ScopedDepth(&CurrentDepth);
268
269    return match(*Node->getLHS()) && match(*Node->getRHS());
270  }
271  bool TraverseAttr(Attr *A) {
272    if (A == nullptr ||
273        (A->isImplicit() &&
274         Finder->getASTContext().getParentMapContext().getTraversalKind() ==
275             TK_IgnoreUnlessSpelledInSource))
276      return true;
277    ScopedIncrement ScopedDepth(&CurrentDepth);
278    return traverse(*A);
279  }
280  bool TraverseLambdaExpr(LambdaExpr *Node) {
281    if (!Finder->isTraversalIgnoringImplicitNodes())
282      return VisitorBase::TraverseLambdaExpr(Node);
283    if (!Node)
284      return true;
285    ScopedIncrement ScopedDepth(&CurrentDepth);
286
287    for (unsigned I = 0, N = Node->capture_size(); I != N; ++I) {
288      const auto *C = Node->capture_begin() + I;
289      if (!C->isExplicit())
290        continue;
291      if (Node->isInitCapture(C) && !match(*C->getCapturedVar()))
292        return false;
293      if (!match(*Node->capture_init_begin()[I]))
294        return false;
295    }
296
297    if (const auto *TPL = Node->getTemplateParameterList()) {
298      for (const auto *TP : *TPL) {
299        if (!match(*TP))
300          return false;
301      }
302    }
303
304    for (const auto *P : Node->getCallOperator()->parameters()) {
305      if (!match(*P))
306        return false;
307    }
308
309    if (!match(*Node->getBody()))
310      return false;
311
312    return VisitorBase::TraverseStmt(Node->getBody());
313  }
314
315  bool shouldVisitTemplateInstantiations() const { return true; }
316  bool shouldVisitImplicitCode() const { return !IgnoreImplicitChildren; }
317
318private:
319  // Used for updating the depth during traversal.
320  struct ScopedIncrement {
321    explicit ScopedIncrement(int *Depth) : Depth(Depth) { ++(*Depth); }
322    ~ScopedIncrement() { --(*Depth); }
323
324   private:
325    int *Depth;
326  };
327
328  // Resets the state of this object.
329  void reset() {
330    Matches = false;
331    CurrentDepth = 0;
332  }
333
334  // Forwards the call to the corresponding Traverse*() method in the
335  // base visitor class.
336  bool baseTraverse(const Decl &DeclNode) {
337    return VisitorBase::TraverseDecl(const_cast<Decl*>(&DeclNode));
338  }
339  bool baseTraverse(const Stmt &StmtNode) {
340    return VisitorBase::TraverseStmt(const_cast<Stmt*>(&StmtNode));
341  }
342  bool baseTraverse(QualType TypeNode) {
343    return VisitorBase::TraverseType(TypeNode);
344  }
345  bool baseTraverse(TypeLoc TypeLocNode) {
346    return VisitorBase::TraverseTypeLoc(TypeLocNode);
347  }
348  bool baseTraverse(const NestedNameSpecifier &NNS) {
349    return VisitorBase::TraverseNestedNameSpecifier(
350        const_cast<NestedNameSpecifier*>(&NNS));
351  }
352  bool baseTraverse(NestedNameSpecifierLoc NNS) {
353    return VisitorBase::TraverseNestedNameSpecifierLoc(NNS);
354  }
355  bool baseTraverse(const CXXCtorInitializer &CtorInit) {
356    return VisitorBase::TraverseConstructorInitializer(
357        const_cast<CXXCtorInitializer *>(&CtorInit));
358  }
359  bool baseTraverse(TemplateArgumentLoc TAL) {
360    return VisitorBase::TraverseTemplateArgumentLoc(TAL);
361  }
362  bool baseTraverse(const Attr &AttrNode) {
363    return VisitorBase::TraverseAttr(const_cast<Attr *>(&AttrNode));
364  }
365
366  // Sets 'Matched' to true if 'Matcher' matches 'Node' and:
367  //   0 < CurrentDepth <= MaxDepth.
368  //
369  // Returns 'true' if traversal should continue after this function
370  // returns, i.e. if no match is found or 'Bind' is 'BK_All'.
371  template <typename T>
372  bool match(const T &Node) {
373    if (CurrentDepth == 0 || CurrentDepth > MaxDepth) {
374      return true;
375    }
376    if (Bind != ASTMatchFinder::BK_All) {
377      BoundNodesTreeBuilder RecursiveBuilder(*Builder);
378      if (Matcher->matches(DynTypedNode::create(Node), Finder,
379                           &RecursiveBuilder)) {
380        Matches = true;
381        ResultBindings.addMatch(RecursiveBuilder);
382        return false; // Abort as soon as a match is found.
383      }
384    } else {
385      BoundNodesTreeBuilder RecursiveBuilder(*Builder);
386      if (Matcher->matches(DynTypedNode::create(Node), Finder,
387                           &RecursiveBuilder)) {
388        // After the first match the matcher succeeds.
389        Matches = true;
390        ResultBindings.addMatch(RecursiveBuilder);
391      }
392    }
393    return true;
394  }
395
396  // Traverses the subtree rooted at 'Node'; returns true if the
397  // traversal should continue after this function returns.
398  template <typename T>
399  bool traverse(const T &Node) {
400    static_assert(IsBaseType<T>::value,
401                  "traverse can only be instantiated with base type");
402    if (!match(Node))
403      return false;
404    return baseTraverse(Node);
405  }
406
407  const DynTypedMatcher *const Matcher;
408  ASTMatchFinder *const Finder;
409  BoundNodesTreeBuilder *const Builder;
410  BoundNodesTreeBuilder ResultBindings;
411  int CurrentDepth;
412  const int MaxDepth;
413  const bool IgnoreImplicitChildren;
414  const ASTMatchFinder::BindKind Bind;
415  bool Matches;
416};
417
// Controls the outermost traversal of the AST and allows matching multiple
// matchers.
420class MatchASTVisitor : public RecursiveASTVisitor<MatchASTVisitor>,
421                        public ASTMatchFinder {
422public:
423  MatchASTVisitor(const MatchFinder::MatchersByType *Matchers,
424                  const MatchFinder::MatchFinderOptions &Options)
425      : Matchers(Matchers), Options(Options), ActiveASTContext(nullptr) {}
426
427  ~MatchASTVisitor() override {
428    if (Options.CheckProfiling) {
429      Options.CheckProfiling->Records = std::move(TimeByBucket);
430    }
431  }
432
433  void onStartOfTranslationUnit() {
434    const bool EnableCheckProfiling = Options.CheckProfiling.has_value();
435    TimeBucketRegion Timer;
436    for (MatchCallback *MC : Matchers->AllCallbacks) {
437      if (EnableCheckProfiling)
438        Timer.setBucket(&TimeByBucket[MC->getID()]);
439      MC->onStartOfTranslationUnit();
440    }
441  }
442
443  void onEndOfTranslationUnit() {
444    const bool EnableCheckProfiling = Options.CheckProfiling.has_value();
445    TimeBucketRegion Timer;
446    for (MatchCallback *MC : Matchers->AllCallbacks) {
447      if (EnableCheckProfiling)
448        Timer.setBucket(&TimeByBucket[MC->getID()]);
449      MC->onEndOfTranslationUnit();
450    }
451  }
452
453  void set_active_ast_context(ASTContext *NewActiveASTContext) {
454    ActiveASTContext = NewActiveASTContext;
455  }
456
457  // The following Visit*() and Traverse*() functions "override"
458  // methods in RecursiveASTVisitor.
459
460  bool VisitTypedefNameDecl(TypedefNameDecl *DeclNode) {
461    // When we see 'typedef A B', we add name 'B' to the set of names
462    // A's canonical type maps to.  This is necessary for implementing
463    // isDerivedFrom(x) properly, where x can be the name of the base
464    // class or any of its aliases.
465    //
466    // In general, the is-alias-of (as defined by typedefs) relation
467    // is tree-shaped, as you can typedef a type more than once.  For
468    // example,
469    //
470    //   typedef A B;
471    //   typedef A C;
472    //   typedef C D;
473    //   typedef C E;
474    //
475    // gives you
476    //
477    //   A
478    //   |- B
479    //   `- C
480    //      |- D
481    //      `- E
482    //
483    // It is wrong to assume that the relation is a chain.  A correct
484    // implementation of isDerivedFrom() needs to recognize that B and
485    // E are aliases, even though neither is a typedef of the other.
486    // Therefore, we cannot simply walk through one typedef chain to
487    // find out whether the type name matches.
488    const Type *TypeNode = DeclNode->getUnderlyingType().getTypePtr();
489    const Type *CanonicalType =  // root of the typedef tree
490        ActiveASTContext->getCanonicalType(TypeNode);
491    TypeAliases[CanonicalType].insert(DeclNode);
492    return true;
493  }
494
495  bool VisitObjCCompatibleAliasDecl(ObjCCompatibleAliasDecl *CAD) {
496    const ObjCInterfaceDecl *InterfaceDecl = CAD->getClassInterface();
497    CompatibleAliases[InterfaceDecl].insert(CAD);
498    return true;
499  }
500
  // Overrides of the RecursiveASTVisitor traversal entry points for each
  // kind of node the finder can match. They are only declared here; the
  // definitions live out of line.
  bool TraverseDecl(Decl *DeclNode);
  bool TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue = nullptr);
  bool TraverseType(QualType TypeNode);
  bool TraverseTypeLoc(TypeLoc TypeNode);
  bool TraverseNestedNameSpecifier(NestedNameSpecifier *NNS);
  bool TraverseNestedNameSpecifierLoc(NestedNameSpecifierLoc NNS);
  bool TraverseConstructorInitializer(CXXCtorInitializer *CtorInit);
  bool TraverseTemplateArgumentLoc(TemplateArgumentLoc TAL);
  bool TraverseAttr(Attr *AttrNode);
510
  // Data-recursion hook overriding RecursiveASTVisitor::dataTraverseNode.
  //
  // CXXForRangeStmt, CXXRewrittenBinaryOperator and LambdaExpr get custom
  // traversals that separate the parts written in source from the parts the
  // compiler generated. Every other statement goes to the base visitor. For
  // the three special kinds, results of the nested traversals are ignored
  // and true is always returned.
  //
  // NOTE(review): ASTNodeNotAsIsSourceScope / ASTNodeNotSpelledInSourceScope
  // are defined elsewhere. They presumably set TraversingASTNodeNotAsIs /
  // TraversingASTNodeNotSpelledInSource for their lifetime; the lambda
  // capture loop below reads TraversingASTNodeNotSpelledInSource to pick
  // its value.
  bool dataTraverseNode(Stmt *S, DataRecursionQueue *Queue) {
    if (auto *RF = dyn_cast<CXXForRangeStmt>(S)) {
      {
        // The user-written pieces: init-statement, loop variable and range
        // initializer.
        ASTNodeNotAsIsSourceScope RAII(this, true);
        TraverseStmt(RF->getInit());
        // Don't traverse under the loop variable
        match(*RF->getLoopVariable());
        TraverseStmt(RF->getRangeInit());
      }
      {
        // All remaining children except the body are visited as "not spelled
        // in source" (presumably the generated range/begin/end/condition/
        // increment machinery).
        ASTNodeNotSpelledInSourceScope RAII(this, true);
        for (auto *SubStmt : RF->children()) {
          if (SubStmt != RF->getBody())
            TraverseStmt(SubStmt);
        }
      }
      // The body is traversed in the caller's current scope.
      TraverseStmt(RF->getBody());
      return true;
    } else if (auto *RBO = dyn_cast<CXXRewrittenBinaryOperator>(S)) {
      {
        // The operands as the user wrote them.
        ASTNodeNotAsIsSourceScope RAII(this, true);
        TraverseStmt(const_cast<Expr *>(RBO->getLHS()));
        TraverseStmt(const_cast<Expr *>(RBO->getRHS()));
      }
      {
        // The node's children (presumably its rewritten semantic form),
        // marked as not spelled in source.
        ASTNodeNotSpelledInSourceScope RAII(this, true);
        for (auto *SubStmt : RBO->children()) {
          TraverseStmt(SubStmt);
        }
      }
      return true;
    } else if (auto *LE = dyn_cast<LambdaExpr>(S)) {
      // A capture is "not spelled in source" if it is implicit, or if we are
      // already inside such a region.
      for (auto I : llvm::zip(LE->captures(), LE->capture_inits())) {
        auto C = std::get<0>(I);
        ASTNodeNotSpelledInSourceScope RAII(
            this, TraversingASTNodeNotSpelledInSource || !C.isExplicit());
        TraverseLambdaCapture(LE, &C, std::get<1>(I));
      }

      {
        // The closure class is visited as "not spelled in source".
        ASTNodeNotSpelledInSourceScope RAII(this, true);
        TraverseDecl(LE->getLambdaClass());
      }
      {
        ASTNodeNotAsIsSourceScope RAII(this, true);

        // We need to poke around to find the bits that might be explicitly
        // written.
        TypeLoc TL = LE->getCallOperator()->getTypeSourceInfo()->getTypeLoc();
        FunctionProtoTypeLoc Proto = TL.getAsAdjusted<FunctionProtoTypeLoc>();

        if (auto *TPL = LE->getTemplateParameterList()) {
          for (NamedDecl *D : *TPL) {
            TraverseDecl(D);
          }
          if (Expr *RequiresClause = TPL->getRequiresClause()) {
            TraverseStmt(RequiresClause);
          }
        }

        if (LE->hasExplicitParameters()) {
          // Visit parameters.
          for (ParmVarDecl *Param : Proto.getParams())
            TraverseDecl(Param);
        }

        const auto *T = Proto.getTypePtr();
        for (const auto &E : T->exceptions())
          TraverseType(E);

        // NOTE(review): this is the only nested traversal here that forwards
        // 'Queue'; all the others recurse immediately. Confirm that this is
        // intended.
        if (Expr *NE = T->getNoexceptExpr())
          TraverseStmt(NE, Queue);

        if (LE->hasExplicitResultType())
          TraverseTypeLoc(Proto.getReturnLoc());
        TraverseStmt(LE->getTrailingRequiresClause());
      }

      // The body is traversed explicitly here, in the caller's current scope,
      // because shouldVisitLambdaBody() returns false.
      TraverseStmt(LE->getBody());
      return true;
    }
    return RecursiveASTVisitor<MatchASTVisitor>::dataTraverseNode(S, Queue);
  }
594
595  // Matches children or descendants of 'Node' with 'BaseMatcher'.
596  bool memoizedMatchesRecursively(const DynTypedNode &Node, ASTContext &Ctx,
597                                  const DynTypedMatcher &Matcher,
598                                  BoundNodesTreeBuilder *Builder, int MaxDepth,
599                                  BindKind Bind) {
600    // For AST-nodes that don't have an identity, we can't memoize.
601    if (!Node.getMemoizationData() || !Builder->isComparable())
602      return matchesRecursively(Node, Matcher, Builder, MaxDepth, Bind);
603
604    MatchKey Key;
605    Key.MatcherID = Matcher.getID();
606    Key.Node = Node;
607    // Note that we key on the bindings *before* the match.
608    Key.BoundNodes = *Builder;
609    Key.Traversal = Ctx.getParentMapContext().getTraversalKind();
610    // Memoize result even doing a single-level match, it might be expensive.
611    Key.Type = MaxDepth == 1 ? MatchType::Child : MatchType::Descendants;
612    MemoizationMap::iterator I = ResultCache.find(Key);
613    if (I != ResultCache.end()) {
614      *Builder = I->second.Nodes;
615      return I->second.ResultOfMatch;
616    }
617
618    MemoizedMatchResult Result;
619    Result.Nodes = *Builder;
620    Result.ResultOfMatch =
621        matchesRecursively(Node, Matcher, &Result.Nodes, MaxDepth, Bind);
622
623    MemoizedMatchResult &CachedResult = ResultCache[Key];
624    CachedResult = std::move(Result);
625
626    *Builder = CachedResult.Nodes;
627    return CachedResult.ResultOfMatch;
628  }
629
630  // Matches children or descendants of 'Node' with 'BaseMatcher'.
631  bool matchesRecursively(const DynTypedNode &Node,
632                          const DynTypedMatcher &Matcher,
633                          BoundNodesTreeBuilder *Builder, int MaxDepth,
634                          BindKind Bind) {
635    bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
636                           TraversingASTChildrenNotSpelledInSource;
637
638    bool IgnoreImplicitChildren = false;
639
640    if (isTraversalIgnoringImplicitNodes()) {
641      IgnoreImplicitChildren = true;
642    }
643
644    ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal);
645
646    MatchChildASTVisitor Visitor(&Matcher, this, Builder, MaxDepth,
647                                 IgnoreImplicitChildren, Bind);
648    return Visitor.findMatch(Node);
649  }
650
  // Implements ASTMatchFinder::classIsDerivedFrom: does 'Declaration' derive
  // from a class matching 'Base'? 'Directly' presumably limits the search to
  // direct bases. Defined out of line.
  bool classIsDerivedFrom(const CXXRecordDecl *Declaration,
                          const Matcher<NamedDecl> &Base,
                          BoundNodesTreeBuilder *Builder,
                          bool Directly) override;

private:
  // Worker for classIsDerivedFrom. 'Visited' presumably tracks records
  // already examined during the base-class walk, to avoid revisiting them.
  // Defined out of line.
  bool
  classIsDerivedFromImpl(const CXXRecordDecl *Declaration,
                         const Matcher<NamedDecl> &Base,
                         BoundNodesTreeBuilder *Builder, bool Directly,
                         llvm::SmallPtrSetImpl<const CXXRecordDecl *> &Visited);

public:
  // Objective-C counterpart of classIsDerivedFrom, operating on interface
  // declarations. Defined out of line.
  bool objcClassIsDerivedFrom(const ObjCInterfaceDecl *Declaration,
                              const Matcher<NamedDecl> &Base,
                              BoundNodesTreeBuilder *Builder,
                              bool Directly) override;

public:
670  // Implements ASTMatchFinder::matchesChildOf.
671  bool matchesChildOf(const DynTypedNode &Node, ASTContext &Ctx,
672                      const DynTypedMatcher &Matcher,
673                      BoundNodesTreeBuilder *Builder, BindKind Bind) override {
674    if (ResultCache.size() > MaxMemoizationEntries)
675      ResultCache.clear();
676    return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, 1, Bind);
677  }
678  // Implements ASTMatchFinder::matchesDescendantOf.
679  bool matchesDescendantOf(const DynTypedNode &Node, ASTContext &Ctx,
680                           const DynTypedMatcher &Matcher,
681                           BoundNodesTreeBuilder *Builder,
682                           BindKind Bind) override {
683    if (ResultCache.size() > MaxMemoizationEntries)
684      ResultCache.clear();
685    return memoizedMatchesRecursively(Node, Ctx, Matcher, Builder, INT_MAX,
686                                      Bind);
687  }
688  // Implements ASTMatchFinder::matchesAncestorOf.
689  bool matchesAncestorOf(const DynTypedNode &Node, ASTContext &Ctx,
690                         const DynTypedMatcher &Matcher,
691                         BoundNodesTreeBuilder *Builder,
692                         AncestorMatchMode MatchMode) override {
693    // Reset the cache outside of the recursive call to make sure we
694    // don't invalidate any iterators.
695    if (ResultCache.size() > MaxMemoizationEntries)
696      ResultCache.clear();
697    if (MatchMode == AncestorMatchMode::AMM_ParentOnly)
698      return matchesParentOf(Node, Matcher, Builder);
699    return matchesAnyAncestorOf(Node, Ctx, Matcher, Builder);
700  }
701
702  // Matches all registered matchers on the given node and calls the
703  // result callback for every node that matches.
704  void match(const DynTypedNode &Node) {
705    // FIXME: Improve this with a switch or a visitor pattern.
706    if (auto *N = Node.get<Decl>()) {
707      match(*N);
708    } else if (auto *N = Node.get<Stmt>()) {
709      match(*N);
710    } else if (auto *N = Node.get<Type>()) {
711      match(*N);
712    } else if (auto *N = Node.get<QualType>()) {
713      match(*N);
714    } else if (auto *N = Node.get<NestedNameSpecifier>()) {
715      match(*N);
716    } else if (auto *N = Node.get<NestedNameSpecifierLoc>()) {
717      match(*N);
718    } else if (auto *N = Node.get<TypeLoc>()) {
719      match(*N);
720    } else if (auto *N = Node.get<CXXCtorInitializer>()) {
721      match(*N);
722    } else if (auto *N = Node.get<TemplateArgumentLoc>()) {
723      match(*N);
724    } else if (auto *N = Node.get<Attr>()) {
725      match(*N);
726    }
727  }
728
729  template <typename T> void match(const T &Node) {
730    matchDispatch(&Node);
731  }
732
733  // Implements ASTMatchFinder::getASTContext.
734  ASTContext &getASTContext() const override { return *ActiveASTContext; }
735
736  bool shouldVisitTemplateInstantiations() const { return true; }
737  bool shouldVisitImplicitCode() const { return true; }
738
739  // We visit the lambda body explicitly, so instruct the RAV
740  // to not visit it on our behalf too.
741  bool shouldVisitLambdaBody() const { return false; }
742
743  bool IsMatchingInASTNodeNotSpelledInSource() const override {
744    return TraversingASTNodeNotSpelledInSource;
745  }
746  bool isMatchingChildrenNotSpelledInSource() const override {
747    return TraversingASTChildrenNotSpelledInSource;
748  }
749  void setMatchingChildrenNotSpelledInSource(bool Set) override {
750    TraversingASTChildrenNotSpelledInSource = Set;
751  }
752
753  bool IsMatchingInASTNodeNotAsIs() const override {
754    return TraversingASTNodeNotAsIs;
755  }
756
757  bool TraverseTemplateInstantiations(ClassTemplateDecl *D) {
758    ASTNodeNotSpelledInSourceScope RAII(this, true);
759    return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
760        D);
761  }
762
763  bool TraverseTemplateInstantiations(VarTemplateDecl *D) {
764    ASTNodeNotSpelledInSourceScope RAII(this, true);
765    return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
766        D);
767  }
768
769  bool TraverseTemplateInstantiations(FunctionTemplateDecl *D) {
770    ASTNodeNotSpelledInSourceScope RAII(this, true);
771    return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateInstantiations(
772        D);
773  }
774
775private:
776  bool TraversingASTNodeNotSpelledInSource = false;
777  bool TraversingASTNodeNotAsIs = false;
778  bool TraversingASTChildrenNotSpelledInSource = false;
779
  // Records which callback is currently running and on which node (plus the
  // bound nodes, when set). The node pointer is kept as a tagged pointer: the
  // callback pointer's spare bit selects which of two PointerUnions is
  // active.
  class CurMatchData {
// We don't have enough free low bits in 32bit builds to discriminate 8 pointer
// types in PointerUnion, so split the union in 2 using a free bit from the
// callback pointer.
#define CMD_TYPES_0                                                            \
  const QualType *, const TypeLoc *, const NestedNameSpecifier *,              \
      const NestedNameSpecifierLoc *
#define CMD_TYPES_1                                                            \
  const CXXCtorInitializer *, const TemplateArgumentLoc *, const Attr *,       \
      const DynTypedNode *

// IMPL(Index) generates, for the node types in CMD_TYPES_<Index>:
//  - SetCallbackAndRawNode(CB, N): requires an empty state, stores CB with
//    tag bit <Index> and stores &N in Node<Index>;
//  - getNode<T>(): requires a held state, returns the stored node as a
//    'const T *' when the tag bit is <Index>, otherwise nullptr.
#define IMPL(Index)                                                            \
  template <typename NodeType>                                                 \
  std::enable_if_t<                                                            \
      llvm::is_one_of<const NodeType *, CMD_TYPES_##Index>::value>             \
  SetCallbackAndRawNode(const MatchCallback *CB, const NodeType &N) {          \
    assertEmpty();                                                             \
    Callback.setPointerAndInt(CB, Index);                                      \
    Node##Index = &N;                                                          \
  }                                                                            \
                                                                               \
  template <typename T>                                                        \
  std::enable_if_t<llvm::is_one_of<const T *, CMD_TYPES_##Index>::value,       \
                   const T *>                                                  \
  getNode() const {                                                            \
    assertHoldsState();                                                        \
    return Callback.getInt() == (Index) ? Node##Index.dyn_cast<const T *>()    \
                                        : nullptr;                             \
  }

  public:
    CurMatchData() : Node0(nullptr) {}

    IMPL(0)
    IMPL(1)

    // The callback currently being run, or null when no match is active.
    const MatchCallback *getCallback() const { return Callback.getPointer(); }

    // Attaches the bound nodes of the current match (stored by address; the
    // caller keeps ownership).
    void SetBoundNodes(const BoundNodes &BN) {
      assertHoldsState();
      BNodes = &BN;
    }

    void clearBoundNodes() {
      assertHoldsState();
      BNodes = nullptr;
    }

    const BoundNodes *getBoundNodes() const {
      assertHoldsState();
      return BNodes;
    }

    // Returns to the empty state. NOTE(review): BNodes is not cleared here,
    // but assertEmpty() requires it to be null, so callers must call
    // clearBoundNodes() before the next SetCallbackAndRawNode().
    void reset() {
      assertHoldsState();
      Callback.setPointerAndInt(nullptr, 0);
      Node0 = nullptr;
    }

  private:
    // NOTE(review): these checks always read Node0, even when Node1 is the
    // active union member. This relies on both PointerUnions sharing a
    // layout.
    void assertHoldsState() const {
      assert(Callback.getPointer() != nullptr && !Node0.isNull());
    }

    void assertEmpty() const {
      assert(Callback.getPointer() == nullptr && Node0.isNull() &&
             BNodes == nullptr);
    }

    // Callback pointer; the int bit selects Node0 (0) or Node1 (1).
    llvm::PointerIntPair<const MatchCallback *, 1> Callback;
    union {
      llvm::PointerUnion<CMD_TYPES_0> Node0;
      llvm::PointerUnion<CMD_TYPES_1> Node1;
    };
    const BoundNodes *BNodes = nullptr;

#undef CMD_TYPES_0
#undef CMD_TYPES_1
#undef IMPL
  } CurMatchState;
860
861  struct CurMatchRAII {
862    template <typename NodeType>
863    CurMatchRAII(MatchASTVisitor &MV, const MatchCallback *CB,
864                 const NodeType &NT)
865        : MV(MV) {
866      MV.CurMatchState.SetCallbackAndRawNode(CB, NT);
867    }
868
869    ~CurMatchRAII() { MV.CurMatchState.reset(); }
870
871  private:
872    MatchASTVisitor &MV;
873  };
874
875public:
876  class TraceReporter : llvm::PrettyStackTraceEntry {
877    static void dumpNode(const ASTContext &Ctx, const DynTypedNode &Node,
878                         raw_ostream &OS) {
879      if (const auto *D = Node.get<Decl>()) {
880        OS << D->getDeclKindName() << "Decl ";
881        if (const auto *ND = dyn_cast<NamedDecl>(D)) {
882          ND->printQualifiedName(OS);
883          OS << " : ";
884        } else
885          OS << ": ";
886        D->getSourceRange().print(OS, Ctx.getSourceManager());
887      } else if (const auto *S = Node.get<Stmt>()) {
888        OS << S->getStmtClassName() << " : ";
889        S->getSourceRange().print(OS, Ctx.getSourceManager());
890      } else if (const auto *T = Node.get<Type>()) {
891        OS << T->getTypeClassName() << "Type : ";
892        QualType(T, 0).print(OS, Ctx.getPrintingPolicy());
893      } else if (const auto *QT = Node.get<QualType>()) {
894        OS << "QualType : ";
895        QT->print(OS, Ctx.getPrintingPolicy());
896      } else {
897        OS << Node.getNodeKind().asStringRef() << " : ";
898        Node.getSourceRange().print(OS, Ctx.getSourceManager());
899      }
900    }
901
902    static void dumpNodeFromState(const ASTContext &Ctx,
903                                  const CurMatchData &State, raw_ostream &OS) {
904      if (const DynTypedNode *MatchNode = State.getNode<DynTypedNode>()) {
905        dumpNode(Ctx, *MatchNode, OS);
906      } else if (const auto *QT = State.getNode<QualType>()) {
907        dumpNode(Ctx, DynTypedNode::create(*QT), OS);
908      } else if (const auto *TL = State.getNode<TypeLoc>()) {
909        dumpNode(Ctx, DynTypedNode::create(*TL), OS);
910      } else if (const auto *NNS = State.getNode<NestedNameSpecifier>()) {
911        dumpNode(Ctx, DynTypedNode::create(*NNS), OS);
912      } else if (const auto *NNSL = State.getNode<NestedNameSpecifierLoc>()) {
913        dumpNode(Ctx, DynTypedNode::create(*NNSL), OS);
914      } else if (const auto *CtorInit = State.getNode<CXXCtorInitializer>()) {
915        dumpNode(Ctx, DynTypedNode::create(*CtorInit), OS);
916      } else if (const auto *TAL = State.getNode<TemplateArgumentLoc>()) {
917        dumpNode(Ctx, DynTypedNode::create(*TAL), OS);
918      } else if (const auto *At = State.getNode<Attr>()) {
919        dumpNode(Ctx, DynTypedNode::create(*At), OS);
920      }
921    }
922
923  public:
924    TraceReporter(const MatchASTVisitor &MV) : MV(MV) {}
925    void print(raw_ostream &OS) const override {
926      const CurMatchData &State = MV.CurMatchState;
927      const MatchCallback *CB = State.getCallback();
928      if (!CB) {
929        OS << "ASTMatcher: Not currently matching\n";
930        return;
931      }
932
933      assert(MV.ActiveASTContext &&
934             "ActiveASTContext should be set if there is a matched callback");
935
936      ASTContext &Ctx = MV.getASTContext();
937
938      if (const BoundNodes *Nodes = State.getBoundNodes()) {
939        OS << "ASTMatcher: Processing '" << CB->getID() << "' against:\n\t";
940        dumpNodeFromState(Ctx, State, OS);
941        const BoundNodes::IDToNodeMap &Map = Nodes->getMap();
942        if (Map.empty()) {
943          OS << "\nNo bound nodes\n";
944          return;
945        }
946        OS << "\n--- Bound Nodes Begin ---\n";
947        for (const auto &Item : Map) {
948          OS << "    " << Item.first << " - { ";
949          dumpNode(Ctx, Item.second, OS);
950          OS << " }\n";
951        }
952        OS << "--- Bound Nodes End ---\n";
953      } else {
954        OS << "ASTMatcher: Matching '" << CB->getID() << "' against:\n\t";
955        dumpNodeFromState(Ctx, State, OS);
956        OS << '\n';
957      }
958    }
959
960  private:
961    const MatchASTVisitor &MV;
962  };
963
964private:
965  struct ASTNodeNotSpelledInSourceScope {
966    ASTNodeNotSpelledInSourceScope(MatchASTVisitor *V, bool B)
967        : MV(V), MB(V->TraversingASTNodeNotSpelledInSource) {
968      V->TraversingASTNodeNotSpelledInSource = B;
969    }
970    ~ASTNodeNotSpelledInSourceScope() {
971      MV->TraversingASTNodeNotSpelledInSource = MB;
972    }
973
974  private:
975    MatchASTVisitor *MV;
976    bool MB;
977  };
978
979  struct ASTNodeNotAsIsSourceScope {
980    ASTNodeNotAsIsSourceScope(MatchASTVisitor *V, bool B)
981        : MV(V), MB(V->TraversingASTNodeNotAsIs) {
982      V->TraversingASTNodeNotAsIs = B;
983    }
984    ~ASTNodeNotAsIsSourceScope() { MV->TraversingASTNodeNotAsIs = MB; }
985
986  private:
987    MatchASTVisitor *MV;
988    bool MB;
989  };
990
991  class TimeBucketRegion {
992  public:
993    TimeBucketRegion() = default;
994    ~TimeBucketRegion() { setBucket(nullptr); }
995
996    /// Start timing for \p NewBucket.
997    ///
998    /// If there was a bucket already set, it will finish the timing for that
999    /// other bucket.
1000    /// \p NewBucket will be timed until the next call to \c setBucket() or
1001    /// until the \c TimeBucketRegion is destroyed.
1002    /// If \p NewBucket is the same as the currently timed bucket, this call
1003    /// does nothing.
1004    void setBucket(llvm::TimeRecord *NewBucket) {
1005      if (Bucket != NewBucket) {
1006        auto Now = llvm::TimeRecord::getCurrentTime(true);
1007        if (Bucket)
1008          *Bucket += Now;
1009        if (NewBucket)
1010          *NewBucket -= Now;
1011        Bucket = NewBucket;
1012      }
1013    }
1014
1015  private:
1016    llvm::TimeRecord *Bucket = nullptr;
1017  };
1018
1019  /// Runs all the \p Matchers on \p Node.
1020  ///
1021  /// Used by \c matchDispatch() below.
1022  template <typename T, typename MC>
1023  void matchWithoutFilter(const T &Node, const MC &Matchers) {
1024    const bool EnableCheckProfiling = Options.CheckProfiling.has_value();
1025    TimeBucketRegion Timer;
1026    for (const auto &MP : Matchers) {
1027      if (EnableCheckProfiling)
1028        Timer.setBucket(&TimeByBucket[MP.second->getID()]);
1029      BoundNodesTreeBuilder Builder;
1030      CurMatchRAII RAII(*this, MP.second, Node);
1031      if (MP.first.matches(Node, this, &Builder)) {
1032        MatchVisitor Visitor(*this, ActiveASTContext, MP.second);
1033        Builder.visitMatches(&Visitor);
1034      }
1035    }
1036  }
1037
1038  void matchWithFilter(const DynTypedNode &DynNode) {
1039    auto Kind = DynNode.getNodeKind();
1040    auto it = MatcherFiltersMap.find(Kind);
1041    const auto &Filter =
1042        it != MatcherFiltersMap.end() ? it->second : getFilterForKind(Kind);
1043
1044    if (Filter.empty())
1045      return;
1046
1047    const bool EnableCheckProfiling = Options.CheckProfiling.has_value();
1048    TimeBucketRegion Timer;
1049    auto &Matchers = this->Matchers->DeclOrStmt;
1050    for (unsigned short I : Filter) {
1051      auto &MP = Matchers[I];
1052      if (EnableCheckProfiling)
1053        Timer.setBucket(&TimeByBucket[MP.second->getID()]);
1054      BoundNodesTreeBuilder Builder;
1055
1056      {
1057        TraversalKindScope RAII(getASTContext(), MP.first.getTraversalKind());
1058        if (getASTContext().getParentMapContext().traverseIgnored(DynNode) !=
1059            DynNode)
1060          continue;
1061      }
1062
1063      CurMatchRAII RAII(*this, MP.second, DynNode);
1064      if (MP.first.matches(DynNode, this, &Builder)) {
1065        MatchVisitor Visitor(*this, ActiveASTContext, MP.second);
1066        Builder.visitMatches(&Visitor);
1067      }
1068    }
1069  }
1070
1071  const std::vector<unsigned short> &getFilterForKind(ASTNodeKind Kind) {
1072    auto &Filter = MatcherFiltersMap[Kind];
1073    auto &Matchers = this->Matchers->DeclOrStmt;
1074    assert((Matchers.size() < USHRT_MAX) && "Too many matchers.");
1075    for (unsigned I = 0, E = Matchers.size(); I != E; ++I) {
1076      if (Matchers[I].first.canMatchNodesOfKind(Kind)) {
1077        Filter.push_back(I);
1078      }
1079    }
1080    return Filter;
1081  }
1082
1083  /// @{
1084  /// Overloads to pair the different node types to their matchers.
1085  void matchDispatch(const Decl *Node) {
1086    return matchWithFilter(DynTypedNode::create(*Node));
1087  }
1088  void matchDispatch(const Stmt *Node) {
1089    return matchWithFilter(DynTypedNode::create(*Node));
1090  }
1091
1092  void matchDispatch(const Type *Node) {
1093    matchWithoutFilter(QualType(Node, 0), Matchers->Type);
1094  }
1095  void matchDispatch(const TypeLoc *Node) {
1096    matchWithoutFilter(*Node, Matchers->TypeLoc);
1097  }
1098  void matchDispatch(const QualType *Node) {
1099    matchWithoutFilter(*Node, Matchers->Type);
1100  }
1101  void matchDispatch(const NestedNameSpecifier *Node) {
1102    matchWithoutFilter(*Node, Matchers->NestedNameSpecifier);
1103  }
1104  void matchDispatch(const NestedNameSpecifierLoc *Node) {
1105    matchWithoutFilter(*Node, Matchers->NestedNameSpecifierLoc);
1106  }
1107  void matchDispatch(const CXXCtorInitializer *Node) {
1108    matchWithoutFilter(*Node, Matchers->CtorInit);
1109  }
1110  void matchDispatch(const TemplateArgumentLoc *Node) {
1111    matchWithoutFilter(*Node, Matchers->TemplateArgumentLoc);
1112  }
1113  void matchDispatch(const Attr *Node) {
1114    matchWithoutFilter(*Node, Matchers->Attr);
1115  }
1116  void matchDispatch(const void *) { /* Do nothing. */ }
1117  /// @}
1118
1119  // Returns whether a direct parent of \p Node matches \p Matcher.
1120  // Unlike matchesAnyAncestorOf there's no memoization: it doesn't save much.
1121  bool matchesParentOf(const DynTypedNode &Node, const DynTypedMatcher &Matcher,
1122                       BoundNodesTreeBuilder *Builder) {
1123    for (const auto &Parent : ActiveASTContext->getParents(Node)) {
1124      BoundNodesTreeBuilder BuilderCopy = *Builder;
1125      if (Matcher.matches(Parent, this, &BuilderCopy)) {
1126        *Builder = std::move(BuilderCopy);
1127        return true;
1128      }
1129    }
1130    return false;
1131  }
1132
  // Returns whether an ancestor of \p Node matches \p Matcher.
  //
  // \p Ctx only supplies the traversal kind stored in memoization keys; parent
  // lookups go through ActiveASTContext. On a match, \p Builder receives the
  // bindings of the matching ancestor; on a cache hit it is replaced by the
  // cached bindings. Before returning, the result and the final \p Builder are
  // stored in ResultCache for every memoizable node of the initial
  // single-parent chain.
  //
  // The order of matching (which can lead to different nodes being bound in
  // case there are multiple matches) is breadth first search.
  //
  // To allow memoization in the very common case of having deeply nested
  // expressions inside a template function, we first walk up the AST, memoizing
  // the result of the match along the way, as long as there is only a single
  // parent.
  //
  // Once there are multiple parents, the breadth first search order does not
  // allow simple memoization on the ancestors. Thus, we only memoize as long
  // as there is a single parent.
  //
  // We avoid a recursive implementation to prevent excessive stack use on
  // very deep ASTs (similarly to RecursiveASTVisitor's data recursion).
  bool matchesAnyAncestorOf(DynTypedNode Node, ASTContext &Ctx,
                            const DynTypedMatcher &Matcher,
                            BoundNodesTreeBuilder *Builder) {

    // Memoization keys that can be updated with the result.
    // These are the memoizable nodes in the chain of unique parents, which
    // terminates when a node has multiple parents, or matches, or is the root.
    // Each key means "some ancestor of this chain node matches".
    std::vector<MatchKey> Keys;
    // When returning, update the memoization cache.
    // Every pending key gets the same verdict and the same final bindings.
    auto Finish = [&](bool Matched) {
      for (const auto &Key : Keys) {
        MemoizedMatchResult &CachedResult = ResultCache[Key];
        CachedResult.ResultOfMatch = Matched;
        CachedResult.Nodes = *Builder;
      }
      return Matched;
    };

    // Loop while there's a single parent and we want to attempt memoization.
    DynTypedNodeList Parents{ArrayRef<DynTypedNode>()}; // after loop: size != 1
    for (;;) {
      // A cache key only makes sense if memoization is possible.
      if (Builder->isComparable()) {
        Keys.emplace_back();
        Keys.back().MatcherID = Matcher.getID();
        Keys.back().Node = Node;
        Keys.back().BoundNodes = *Builder;
        Keys.back().Traversal = Ctx.getParentMapContext().getTraversalKind();
        Keys.back().Type = MatchType::Ancestors;

        // Check the cache.
        MemoizationMap::iterator I = ResultCache.find(Keys.back());
        if (I != ResultCache.end()) {
          Keys.pop_back(); // Don't populate the cache for the matching node!
          *Builder = I->second.Nodes;
          return Finish(I->second.ResultOfMatch);
        }
      }

      Parents = ActiveASTContext->getParents(Node);
      // Either no parents or multiple parents: leave chain+memoize mode and
      // enter bfs+forgetful mode.
      if (Parents.size() != 1)
        break;

      // Check the next parent.
      // Node now becomes that unique parent; it is tested on a builder copy so
      // a failed match leaves *Builder untouched.
      Node = *Parents.begin();
      BoundNodesTreeBuilder BuilderCopy = *Builder;
      if (Matcher.matches(Node, this, &BuilderCopy)) {
        *Builder = std::move(BuilderCopy);
        return Finish(true);
      }
    }
    // We reached the end of the chain.

    if (Parents.empty()) {
      // Nodes may have no parents if:
      //  a) the node is the TranslationUnitDecl
      //  b) we have a limited traversal scope that excludes the parent edges
      //  c) there is a bug in the AST, and the node is not reachable
      // Usually the traversal scope is the whole AST, which precludes b.
      // Bugs are common enough that it's worthwhile asserting when we can.
#ifndef NDEBUG
      if (!Node.get<TranslationUnitDecl>() &&
          /* Traversal scope is full AST if any of the bounds are the TU */
          llvm::any_of(ActiveASTContext->getTraversalScope(), [](Decl *D) {
            return D->getKind() == Decl::TranslationUnit;
          })) {
        llvm::errs() << "Tried to match orphan node:\n";
        Node.dump(llvm::errs(), *ActiveASTContext);
        llvm_unreachable("Parent map should be complete!");
      }
#endif
    } else {
      assert(Parents.size() > 1);
      // BFS starting from the parents not yet considered.
      // Memoization of newly visited nodes is not possible (but we still update
      // results for the elements in the chain we found above).
      // NOTE(review): the initial parents are never inserted into Visited, so
      // one that is also reachable through another initial parent can be
      // queued and matched a second time (extra work, presumably harmless).
      std::deque<DynTypedNode> Queue(Parents.begin(), Parents.end());
      llvm::DenseSet<const void *> Visited;
      while (!Queue.empty()) {
        BoundNodesTreeBuilder BuilderCopy = *Builder;
        if (Matcher.matches(Queue.front(), this, &BuilderCopy)) {
          *Builder = std::move(BuilderCopy);
          return Finish(true);
        }
        for (const auto &Parent : ActiveASTContext->getParents(Queue.front())) {
          // Make sure we do not visit the same node twice.
          // Otherwise, we'll visit the common ancestors as often as there
          // are splits on the way down.
          if (Visited.insert(Parent.getMemoizationData()).second)
            Queue.push_back(Parent);
        }
        Queue.pop_front();
      }
    }
    return Finish(false);
  }
1247
1248  // Implements a BoundNodesTree::Visitor that calls a MatchCallback with
1249  // the aggregated bound nodes for each match.
1250  class MatchVisitor : public BoundNodesTreeBuilder::Visitor {
1251    struct CurBoundScope {
1252      CurBoundScope(MatchASTVisitor::CurMatchData &State, const BoundNodes &BN)
1253          : State(State) {
1254        State.SetBoundNodes(BN);
1255      }
1256
1257      ~CurBoundScope() { State.clearBoundNodes(); }
1258
1259    private:
1260      MatchASTVisitor::CurMatchData &State;
1261    };
1262
1263  public:
1264    MatchVisitor(MatchASTVisitor &MV, ASTContext *Context,
1265                 MatchFinder::MatchCallback *Callback)
1266        : State(MV.CurMatchState), Context(Context), Callback(Callback) {}
1267
1268    void visitMatch(const BoundNodes& BoundNodesView) override {
1269      TraversalKindScope RAII(*Context, Callback->getCheckTraversalKind());
1270      CurBoundScope RAII2(State, BoundNodesView);
1271      Callback->run(MatchFinder::MatchResult(BoundNodesView, Context));
1272    }
1273
1274  private:
1275    MatchASTVisitor::CurMatchData &State;
1276    ASTContext* Context;
1277    MatchFinder::MatchCallback* Callback;
1278  };
1279
1280  // Returns true if 'TypeNode' has an alias that matches the given matcher.
1281  bool typeHasMatchingAlias(const Type *TypeNode,
1282                            const Matcher<NamedDecl> &Matcher,
1283                            BoundNodesTreeBuilder *Builder) {
1284    const Type *const CanonicalType =
1285      ActiveASTContext->getCanonicalType(TypeNode);
1286    auto Aliases = TypeAliases.find(CanonicalType);
1287    if (Aliases == TypeAliases.end())
1288      return false;
1289    for (const TypedefNameDecl *Alias : Aliases->second) {
1290      BoundNodesTreeBuilder Result(*Builder);
1291      if (Matcher.matches(*Alias, this, &Result)) {
1292        *Builder = std::move(Result);
1293        return true;
1294      }
1295    }
1296    return false;
1297  }
1298
1299  bool
1300  objcClassHasMatchingCompatibilityAlias(const ObjCInterfaceDecl *InterfaceDecl,
1301                                         const Matcher<NamedDecl> &Matcher,
1302                                         BoundNodesTreeBuilder *Builder) {
1303    auto Aliases = CompatibleAliases.find(InterfaceDecl);
1304    if (Aliases == CompatibleAliases.end())
1305      return false;
1306    for (const ObjCCompatibleAliasDecl *Alias : Aliases->second) {
1307      BoundNodesTreeBuilder Result(*Builder);
1308      if (Matcher.matches(*Alias, this, &Result)) {
1309        *Builder = std::move(Result);
1310        return true;
1311      }
1312    }
1313    return false;
1314  }
1315
  /// Accumulated matching time per matcher, keyed by the callback's ID.
  ///
  /// matchWithoutFilter() and matchWithFilter() select a bucket here (through
  /// TimeBucketRegion) only when Options.CheckProfiling is set.
  llvm::StringMap<llvm::TimeRecord> TimeByBucket;

  /// The registered matchers, grouped by the node type they apply to.
  /// NOTE(review): held by pointer; presumably owned by the MatchFinder and
  /// required to outlive this visitor — confirm.
  const MatchFinder::MatchersByType *Matchers;

  /// Filtered list of matcher indices for each node kind.
  ///
  /// \c Decl and \c Stmt top-level matchers usually apply to a specific node
  /// kind (and derived kinds), so trying every matcher on every node is a
  /// waste. getFilterForKind() computes, on first use of a kind, the indices
  /// into Matchers->DeclOrStmt of the matchers that can match that kind, and
  /// caches them here.
  llvm::DenseMap<ASTNodeKind, std::vector<unsigned short>> MatcherFiltersMap;

  /// Finder options; CheckProfiling decides whether matchers are timed.
  const MatchFinder::MatchFinderOptions &Options;
  /// The ASTContext being matched. TraceReporter::print asserts it is set
  /// whenever a callback is active.
  ASTContext *ActiveASTContext;

  // Maps a canonical type to its TypedefDecls (consulted by
  // typeHasMatchingAlias).
  llvm::DenseMap<const Type*, std::set<const TypedefNameDecl*> > TypeAliases;

  // Maps an Objective-C interface to its ObjCCompatibleAliasDecls (consulted
  // by objcClassHasMatchingCompatibilityAlias).
  llvm::DenseMap<const ObjCInterfaceDecl *,
                 llvm::SmallPtrSet<const ObjCCompatibleAliasDecl *, 2>>
      CompatibleAliases;

  // Maps a MatchKey (matcher ID, node, incoming bindings, traversal kind and
  // match type) to the memoized match result, e.g. for matchesAnyAncestorOf.
  typedef std::map<MatchKey, MemoizedMatchResult> MemoizationMap;
  MemoizationMap ResultCache;
1345};
1346
1347static CXXRecordDecl *
1348getAsCXXRecordDeclOrPrimaryTemplate(const Type *TypeNode) {
1349  if (auto *RD = TypeNode->getAsCXXRecordDecl())
1350    return RD;
1351
1352  // Find the innermost TemplateSpecializationType that isn't an alias template.
1353  auto *TemplateType = TypeNode->getAs<TemplateSpecializationType>();
1354  while (TemplateType && TemplateType->isTypeAlias())
1355    TemplateType =
1356        TemplateType->getAliasedType()->getAs<TemplateSpecializationType>();
1357
1358  // If this is the name of a (dependent) template specialization, use the
1359  // definition of the template, even though it might be specialized later.
1360  if (TemplateType)
1361    if (auto *ClassTemplate = dyn_cast_or_null<ClassTemplateDecl>(
1362          TemplateType->getTemplateName().getAsTemplateDecl()))
1363      return ClassTemplate->getTemplatedDecl();
1364
1365  return nullptr;
1366}
1367
1368// Returns true if the given C++ class is directly or indirectly derived
1369// from a base type with the given name.  A class is not considered to be
1370// derived from itself.
1371bool MatchASTVisitor::classIsDerivedFrom(const CXXRecordDecl *Declaration,
1372                                         const Matcher<NamedDecl> &Base,
1373                                         BoundNodesTreeBuilder *Builder,
1374                                         bool Directly) {
1375  llvm::SmallPtrSet<const CXXRecordDecl *, 8> Visited;
1376  return classIsDerivedFromImpl(Declaration, Base, Builder, Directly, Visited);
1377}
1378
1379bool MatchASTVisitor::classIsDerivedFromImpl(
1380    const CXXRecordDecl *Declaration, const Matcher<NamedDecl> &Base,
1381    BoundNodesTreeBuilder *Builder, bool Directly,
1382    llvm::SmallPtrSetImpl<const CXXRecordDecl *> &Visited) {
1383  if (!Declaration->hasDefinition())
1384    return false;
1385  if (!Visited.insert(Declaration).second)
1386    return false;
1387  for (const auto &It : Declaration->bases()) {
1388    const Type *TypeNode = It.getType().getTypePtr();
1389
1390    if (typeHasMatchingAlias(TypeNode, Base, Builder))
1391      return true;
1392
1393    // FIXME: Going to the primary template here isn't really correct, but
1394    // unfortunately we accept a Decl matcher for the base class not a Type
1395    // matcher, so it's the best thing we can do with our current interface.
1396    CXXRecordDecl *ClassDecl = getAsCXXRecordDeclOrPrimaryTemplate(TypeNode);
1397    if (!ClassDecl)
1398      continue;
1399    if (ClassDecl == Declaration) {
1400      // This can happen for recursive template definitions.
1401      continue;
1402    }
1403    BoundNodesTreeBuilder Result(*Builder);
1404    if (Base.matches(*ClassDecl, this, &Result)) {
1405      *Builder = std::move(Result);
1406      return true;
1407    }
1408    if (!Directly &&
1409        classIsDerivedFromImpl(ClassDecl, Base, Builder, Directly, Visited))
1410      return true;
1411  }
1412  return false;
1413}
1414
1415// Returns true if the given Objective-C class is directly or indirectly
1416// derived from a matching base class. A class is not considered to be derived
1417// from itself.
1418bool MatchASTVisitor::objcClassIsDerivedFrom(
1419    const ObjCInterfaceDecl *Declaration, const Matcher<NamedDecl> &Base,
1420    BoundNodesTreeBuilder *Builder, bool Directly) {
1421  // Check if any of the superclasses of the class match.
1422  for (const ObjCInterfaceDecl *ClassDecl = Declaration->getSuperClass();
1423       ClassDecl != nullptr; ClassDecl = ClassDecl->getSuperClass()) {
1424    // Check if there are any matching compatibility aliases.
1425    if (objcClassHasMatchingCompatibilityAlias(ClassDecl, Base, Builder))
1426      return true;
1427
1428    // Check if there are any matching type aliases.
1429    const Type *TypeNode = ClassDecl->getTypeForDecl();
1430    if (typeHasMatchingAlias(TypeNode, Base, Builder))
1431      return true;
1432
1433    if (Base.matches(*ClassDecl, this, Builder))
1434      return true;
1435
1436    // Not `return false` as a temporary workaround for PR43879.
1437    if (Directly)
1438      break;
1439  }
1440
1441  return false;
1442}
1443
1444bool MatchASTVisitor::TraverseDecl(Decl *DeclNode) {
1445  if (!DeclNode) {
1446    return true;
1447  }
1448
1449  bool ScopedTraversal =
1450      TraversingASTNodeNotSpelledInSource || DeclNode->isImplicit();
1451  bool ScopedChildren = TraversingASTChildrenNotSpelledInSource;
1452
1453  if (const auto *CTSD = dyn_cast<ClassTemplateSpecializationDecl>(DeclNode)) {
1454    auto SK = CTSD->getSpecializationKind();
1455    if (SK == TSK_ExplicitInstantiationDeclaration ||
1456        SK == TSK_ExplicitInstantiationDefinition)
1457      ScopedChildren = true;
1458  } else if (const auto *FD = dyn_cast<FunctionDecl>(DeclNode)) {
1459    if (FD->isDefaulted())
1460      ScopedChildren = true;
1461    if (FD->isTemplateInstantiation())
1462      ScopedTraversal = true;
1463  } else if (isa<BindingDecl>(DeclNode)) {
1464    ScopedChildren = true;
1465  }
1466
1467  ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal);
1468  ASTChildrenNotSpelledInSourceScope RAII2(this, ScopedChildren);
1469
1470  match(*DeclNode);
1471  return RecursiveASTVisitor<MatchASTVisitor>::TraverseDecl(DeclNode);
1472}
1473
1474bool MatchASTVisitor::TraverseStmt(Stmt *StmtNode, DataRecursionQueue *Queue) {
1475  if (!StmtNode) {
1476    return true;
1477  }
1478  bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
1479                         TraversingASTChildrenNotSpelledInSource;
1480
1481  ASTNodeNotSpelledInSourceScope RAII(this, ScopedTraversal);
1482  match(*StmtNode);
1483  return RecursiveASTVisitor<MatchASTVisitor>::TraverseStmt(StmtNode, Queue);
1484}
1485
1486bool MatchASTVisitor::TraverseType(QualType TypeNode) {
1487  match(TypeNode);
1488  return RecursiveASTVisitor<MatchASTVisitor>::TraverseType(TypeNode);
1489}
1490
1491bool MatchASTVisitor::TraverseTypeLoc(TypeLoc TypeLocNode) {
1492  // The RecursiveASTVisitor only visits types if they're not within TypeLocs.
1493  // We still want to find those types via matchers, so we match them here. Note
1494  // that the TypeLocs are structurally a shadow-hierarchy to the expressed
1495  // type, so we visit all involved parts of a compound type when matching on
1496  // each TypeLoc.
1497  match(TypeLocNode);
1498  match(TypeLocNode.getType());
1499  return RecursiveASTVisitor<MatchASTVisitor>::TraverseTypeLoc(TypeLocNode);
1500}
1501
1502bool MatchASTVisitor::TraverseNestedNameSpecifier(NestedNameSpecifier *NNS) {
1503  match(*NNS);
1504  return RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifier(NNS);
1505}
1506
1507bool MatchASTVisitor::TraverseNestedNameSpecifierLoc(
1508    NestedNameSpecifierLoc NNS) {
1509  if (!NNS)
1510    return true;
1511
1512  match(NNS);
1513
1514  // We only match the nested name specifier here (as opposed to traversing it)
1515  // because the traversal is already done in the parallel "Loc"-hierarchy.
1516  if (NNS.hasQualifier())
1517    match(*NNS.getNestedNameSpecifier());
1518  return
1519      RecursiveASTVisitor<MatchASTVisitor>::TraverseNestedNameSpecifierLoc(NNS);
1520}
1521
1522bool MatchASTVisitor::TraverseConstructorInitializer(
1523    CXXCtorInitializer *CtorInit) {
1524  if (!CtorInit)
1525    return true;
1526
1527  bool ScopedTraversal = TraversingASTNodeNotSpelledInSource ||
1528                         TraversingASTChildrenNotSpelledInSource;
1529
1530  if (!CtorInit->isWritten())
1531    ScopedTraversal = true;
1532
1533  ASTNodeNotSpelledInSourceScope RAII1(this, ScopedTraversal);
1534
1535  match(*CtorInit);
1536
1537  return RecursiveASTVisitor<MatchASTVisitor>::TraverseConstructorInitializer(
1538      CtorInit);
1539}
1540
1541bool MatchASTVisitor::TraverseTemplateArgumentLoc(TemplateArgumentLoc Loc) {
1542  match(Loc);
1543  return RecursiveASTVisitor<MatchASTVisitor>::TraverseTemplateArgumentLoc(Loc);
1544}
1545
1546bool MatchASTVisitor::TraverseAttr(Attr *AttrNode) {
1547  match(*AttrNode);
1548  return RecursiveASTVisitor<MatchASTVisitor>::TraverseAttr(AttrNode);
1549}
1550
1551class MatchASTConsumer : public ASTConsumer {
1552public:
1553  MatchASTConsumer(MatchFinder *Finder,
1554                   MatchFinder::ParsingDoneTestCallback *ParsingDone)
1555      : Finder(Finder), ParsingDone(ParsingDone) {}
1556
1557private:
1558  void HandleTranslationUnit(ASTContext &Context) override {
1559    if (ParsingDone != nullptr) {
1560      ParsingDone->run();
1561    }
1562    Finder->matchAST(Context);
1563  }
1564
1565  MatchFinder *Finder;
1566  MatchFinder::ParsingDoneTestCallback *ParsingDone;
1567};
1568
1569} // end namespace
1570} // end namespace internal
1571
1572MatchFinder::MatchResult::MatchResult(const BoundNodes &Nodes,
1573                                      ASTContext *Context)
1574  : Nodes(Nodes), Context(Context),
1575    SourceManager(&Context->getSourceManager()) {}
1576
1577MatchFinder::MatchCallback::~MatchCallback() {}
1578MatchFinder::ParsingDoneTestCallback::~ParsingDoneTestCallback() {}
1579
1580MatchFinder::MatchFinder(MatchFinderOptions Options)
1581    : Options(std::move(Options)), ParsingDone(nullptr) {}
1582
1583MatchFinder::~MatchFinder() {}
1584
1585void MatchFinder::addMatcher(const DeclarationMatcher &NodeMatch,
1586                             MatchCallback *Action) {
1587  std::optional<TraversalKind> TK;
1588  if (Action)
1589    TK = Action->getCheckTraversalKind();
1590  if (TK)
1591    Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action);
1592  else
1593    Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
1594  Matchers.AllCallbacks.insert(Action);
1595}
1596
1597void MatchFinder::addMatcher(const TypeMatcher &NodeMatch,
1598                             MatchCallback *Action) {
1599  Matchers.Type.emplace_back(NodeMatch, Action);
1600  Matchers.AllCallbacks.insert(Action);
1601}
1602
1603void MatchFinder::addMatcher(const StatementMatcher &NodeMatch,
1604                             MatchCallback *Action) {
1605  std::optional<TraversalKind> TK;
1606  if (Action)
1607    TK = Action->getCheckTraversalKind();
1608  if (TK)
1609    Matchers.DeclOrStmt.emplace_back(traverse(*TK, NodeMatch), Action);
1610  else
1611    Matchers.DeclOrStmt.emplace_back(NodeMatch, Action);
1612  Matchers.AllCallbacks.insert(Action);
1613}
1614
1615void MatchFinder::addMatcher(const NestedNameSpecifierMatcher &NodeMatch,
1616                             MatchCallback *Action) {
1617  Matchers.NestedNameSpecifier.emplace_back(NodeMatch, Action);
1618  Matchers.AllCallbacks.insert(Action);
1619}
1620
1621void MatchFinder::addMatcher(const NestedNameSpecifierLocMatcher &NodeMatch,
1622                             MatchCallback *Action) {
1623  Matchers.NestedNameSpecifierLoc.emplace_back(NodeMatch, Action);
1624  Matchers.AllCallbacks.insert(Action);
1625}
1626
1627void MatchFinder::addMatcher(const TypeLocMatcher &NodeMatch,
1628                             MatchCallback *Action) {
1629  Matchers.TypeLoc.emplace_back(NodeMatch, Action);
1630  Matchers.AllCallbacks.insert(Action);
1631}
1632
1633void MatchFinder::addMatcher(const CXXCtorInitializerMatcher &NodeMatch,
1634                             MatchCallback *Action) {
1635  Matchers.CtorInit.emplace_back(NodeMatch, Action);
1636  Matchers.AllCallbacks.insert(Action);
1637}
1638
1639void MatchFinder::addMatcher(const TemplateArgumentLocMatcher &NodeMatch,
1640                             MatchCallback *Action) {
1641  Matchers.TemplateArgumentLoc.emplace_back(NodeMatch, Action);
1642  Matchers.AllCallbacks.insert(Action);
1643}
1644
1645void MatchFinder::addMatcher(const AttrMatcher &AttrMatch,
1646                             MatchCallback *Action) {
1647  Matchers.Attr.emplace_back(AttrMatch, Action);
1648  Matchers.AllCallbacks.insert(Action);
1649}
1650
1651bool MatchFinder::addDynamicMatcher(const internal::DynTypedMatcher &NodeMatch,
1652                                    MatchCallback *Action) {
1653  if (NodeMatch.canConvertTo<Decl>()) {
1654    addMatcher(NodeMatch.convertTo<Decl>(), Action);
1655    return true;
1656  } else if (NodeMatch.canConvertTo<QualType>()) {
1657    addMatcher(NodeMatch.convertTo<QualType>(), Action);
1658    return true;
1659  } else if (NodeMatch.canConvertTo<Stmt>()) {
1660    addMatcher(NodeMatch.convertTo<Stmt>(), Action);
1661    return true;
1662  } else if (NodeMatch.canConvertTo<NestedNameSpecifier>()) {
1663    addMatcher(NodeMatch.convertTo<NestedNameSpecifier>(), Action);
1664    return true;
1665  } else if (NodeMatch.canConvertTo<NestedNameSpecifierLoc>()) {
1666    addMatcher(NodeMatch.convertTo<NestedNameSpecifierLoc>(), Action);
1667    return true;
1668  } else if (NodeMatch.canConvertTo<TypeLoc>()) {
1669    addMatcher(NodeMatch.convertTo<TypeLoc>(), Action);
1670    return true;
1671  } else if (NodeMatch.canConvertTo<CXXCtorInitializer>()) {
1672    addMatcher(NodeMatch.convertTo<CXXCtorInitializer>(), Action);
1673    return true;
1674  } else if (NodeMatch.canConvertTo<TemplateArgumentLoc>()) {
1675    addMatcher(NodeMatch.convertTo<TemplateArgumentLoc>(), Action);
1676    return true;
1677  } else if (NodeMatch.canConvertTo<Attr>()) {
1678    addMatcher(NodeMatch.convertTo<Attr>(), Action);
1679    return true;
1680  }
1681  return false;
1682}
1683
1684std::unique_ptr<ASTConsumer> MatchFinder::newASTConsumer() {
1685  return std::make_unique<internal::MatchASTConsumer>(this, ParsingDone);
1686}
1687
1688void MatchFinder::match(const clang::DynTypedNode &Node, ASTContext &Context) {
1689  internal::MatchASTVisitor Visitor(&Matchers, Options);
1690  Visitor.set_active_ast_context(&Context);
1691  Visitor.match(Node);
1692}
1693
// Runs the registered matchers over the whole AST owned by \p Context.
// A fresh visitor is built from this finder's matchers and options. It
// signals the start of the translation unit, traverses the AST, then
// signals its end.
void MatchFinder::matchAST(ASTContext &Context) {
  internal::MatchASTVisitor Visitor(&Matchers, Options);
  // Kept alive for the whole traversal. Presumably it reports matcher state
  // in crash stack traces (PrettyStackTrace is included) — confirm against
  // TraceReporter's definition.
  internal::MatchASTVisitor::TraceReporter StackTrace(Visitor);
  Visitor.set_active_ast_context(&Context);
  // Start/traverse/end order is the visitor's lifecycle; keep it intact.
  Visitor.onStartOfTranslationUnit();
  Visitor.TraverseAST(Context);
  Visitor.onEndOfTranslationUnit();
}
1702
1703void MatchFinder::registerTestCallbackAfterParsing(
1704    MatchFinder::ParsingDoneTestCallback *NewParsingDone) {
1705  ParsingDone = NewParsingDone;
1706}
1707
1708StringRef MatchFinder::MatchCallback::getID() const { return "<unknown>"; }
1709
1710std::optional<TraversalKind>
1711MatchFinder::MatchCallback::getCheckTraversalKind() const {
1712  return std::nullopt;
1713}
1714
1715} // end namespace ast_matchers
1716} // end namespace clang
1717