USRFindingAction.cpp revision 360660
1//===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// Provides an action to find USR for the symbol at <offset>, as well as
11/// all additional USRs.
12///
13//===----------------------------------------------------------------------===//
14
15#include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
16#include "clang/AST/AST.h"
17#include "clang/AST/ASTConsumer.h"
18#include "clang/AST/ASTContext.h"
19#include "clang/AST/Decl.h"
20#include "clang/AST/RecursiveASTVisitor.h"
21#include "clang/Basic/FileManager.h"
22#include "clang/Frontend/CompilerInstance.h"
23#include "clang/Frontend/FrontendAction.h"
24#include "clang/Lex/Lexer.h"
25#include "clang/Lex/Preprocessor.h"
26#include "clang/Tooling/CommonOptionsParser.h"
27#include "clang/Tooling/Refactoring.h"
28#include "clang/Tooling/Refactoring/Rename/USRFinder.h"
29#include "clang/Tooling/Tooling.h"
30
31#include <algorithm>
32#include <set>
33#include <string>
34#include <vector>
35
36using namespace llvm;
37
38namespace clang {
39namespace tooling {
40
41const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
42  // If FoundDecl is a constructor or destructor, we want to instead take
43  // the Decl of the corresponding class.
44  if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
45    FoundDecl = CtorDecl->getParent();
46  else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
47    FoundDecl = DtorDecl->getParent();
48  // FIXME: (Alex L): Canonicalize implicit template instantions, just like
49  // the indexer does it.
50
51  // Note: please update the declaration's doc comment every time the
52  // canonicalization rules are changed.
53  return FoundDecl;
54}
55
56namespace {
57// NamedDeclFindingConsumer should delegate finding USRs of given Decl to
58// AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
59// Decl refers to class and adds USRs of all overridden methods if Decl refers
60// to virtual method.
61class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
62public:
63  AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
64      : FoundDecl(FoundDecl), Context(Context) {}
65
66  std::vector<std::string> Find() {
67    // Fill OverriddenMethods and PartialSpecs storages.
68    TraverseDecl(Context.getTranslationUnitDecl());
69    if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
70      addUSRsOfOverridenFunctions(MethodDecl);
71      for (const auto &OverriddenMethod : OverriddenMethods) {
72        if (checkIfOverriddenFunctionAscends(OverriddenMethod))
73          USRSet.insert(getUSRForDecl(OverriddenMethod));
74      }
75      addUSRsOfInstantiatedMethods(MethodDecl);
76    } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
77      handleCXXRecordDecl(RecordDecl);
78    } else if (const auto *TemplateDecl =
79                   dyn_cast<ClassTemplateDecl>(FoundDecl)) {
80      handleClassTemplateDecl(TemplateDecl);
81    } else {
82      USRSet.insert(getUSRForDecl(FoundDecl));
83    }
84    return std::vector<std::string>(USRSet.begin(), USRSet.end());
85  }
86
87  bool shouldVisitTemplateInstantiations() const { return true; }
88
89  bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
90    if (MethodDecl->isVirtual())
91      OverriddenMethods.push_back(MethodDecl);
92    if (MethodDecl->getInstantiatedFromMemberFunction())
93      InstantiatedMethods.push_back(MethodDecl);
94    return true;
95  }
96
97  bool VisitClassTemplatePartialSpecializationDecl(
98      const ClassTemplatePartialSpecializationDecl *PartialSpec) {
99    PartialSpecs.push_back(PartialSpec);
100    return true;
101  }
102
103private:
104  void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
105    RecordDecl = RecordDecl->getDefinition();
106    if (const auto *ClassTemplateSpecDecl =
107            dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
108      handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
109    addUSRsOfCtorDtors(RecordDecl);
110  }
111
112  void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
113    for (const auto *Specialization : TemplateDecl->specializations())
114      addUSRsOfCtorDtors(Specialization);
115
116    for (const auto *PartialSpec : PartialSpecs) {
117      if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
118        addUSRsOfCtorDtors(PartialSpec);
119    }
120    addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
121  }
122
123  void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl) {
124    RecordDecl = RecordDecl->getDefinition();
125
126    // Skip if the CXXRecordDecl doesn't have definition.
127    if (!RecordDecl)
128      return;
129
130    for (const auto *CtorDecl : RecordDecl->ctors())
131      USRSet.insert(getUSRForDecl(CtorDecl));
132
133    USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
134    USRSet.insert(getUSRForDecl(RecordDecl));
135  }
136
137  void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
138    USRSet.insert(getUSRForDecl(MethodDecl));
139    // Recursively visit each OverridenMethod.
140    for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
141      addUSRsOfOverridenFunctions(OverriddenMethod);
142  }
143
144  void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
145    // For renaming a class template method, all references of the instantiated
146    // member methods should be renamed too, so add USRs of the instantiated
147    // methods to the USR set.
148    USRSet.insert(getUSRForDecl(MethodDecl));
149    if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
150      USRSet.insert(getUSRForDecl(FT));
151    for (const auto *Method : InstantiatedMethods) {
152      if (USRSet.find(getUSRForDecl(
153              Method->getInstantiatedFromMemberFunction())) != USRSet.end())
154        USRSet.insert(getUSRForDecl(Method));
155    }
156  }
157
158  bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
159    for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
160      if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
161        return true;
162      return checkIfOverriddenFunctionAscends(OverriddenMethod);
163    }
164    return false;
165  }
166
167  const Decl *FoundDecl;
168  ASTContext &Context;
169  std::set<std::string> USRSet;
170  std::vector<const CXXMethodDecl *> OverriddenMethods;
171  std::vector<const CXXMethodDecl *> InstantiatedMethods;
172  std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
173};
174} // namespace
175
176std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
177                                               ASTContext &Context) {
178  AdditionalUSRFinder Finder(ND, Context);
179  return Finder.Find();
180}
181
182class NamedDeclFindingConsumer : public ASTConsumer {
183public:
184  NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
185                           ArrayRef<std::string> QualifiedNames,
186                           std::vector<std::string> &SpellingNames,
187                           std::vector<std::vector<std::string>> &USRList,
188                           bool Force, bool &ErrorOccurred)
189      : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
190        SpellingNames(SpellingNames), USRList(USRList), Force(Force),
191        ErrorOccurred(ErrorOccurred) {}
192
193private:
194  bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
195                  unsigned SymbolOffset, const std::string &QualifiedName) {
196    DiagnosticsEngine &Engine = Context.getDiagnostics();
197    const FileID MainFileID = SourceMgr.getMainFileID();
198
199    if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
200      ErrorOccurred = true;
201      unsigned InvalidOffset = Engine.getCustomDiagID(
202          DiagnosticsEngine::Error,
203          "SourceLocation in file %0 at offset %1 is invalid");
204      Engine.Report(SourceLocation(), InvalidOffset)
205          << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
206      return false;
207    }
208
209    const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
210                                     .getLocWithOffset(SymbolOffset);
211    const NamedDecl *FoundDecl = QualifiedName.empty()
212                                     ? getNamedDeclAt(Context, Point)
213                                     : getNamedDeclFor(Context, QualifiedName);
214
215    if (FoundDecl == nullptr) {
216      if (QualifiedName.empty()) {
217        FullSourceLoc FullLoc(Point, SourceMgr);
218        unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
219            DiagnosticsEngine::Error,
220            "clang-rename could not find symbol (offset %0)");
221        Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
222        ErrorOccurred = true;
223        return false;
224      }
225
226      if (Force) {
227        SpellingNames.push_back(std::string());
228        USRList.push_back(std::vector<std::string>());
229        return true;
230      }
231
232      unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
233          DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
234      Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
235      ErrorOccurred = true;
236      return false;
237    }
238
239    FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
240    SpellingNames.push_back(FoundDecl->getNameAsString());
241    AdditionalUSRFinder Finder(FoundDecl, Context);
242    USRList.push_back(Finder.Find());
243    return true;
244  }
245
246  void HandleTranslationUnit(ASTContext &Context) override {
247    const SourceManager &SourceMgr = Context.getSourceManager();
248    for (unsigned Offset : SymbolOffsets) {
249      if (!FindSymbol(Context, SourceMgr, Offset, ""))
250        return;
251    }
252    for (const std::string &QualifiedName : QualifiedNames) {
253      if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
254        return;
255    }
256  }
257
258  ArrayRef<unsigned> SymbolOffsets;
259  ArrayRef<std::string> QualifiedNames;
260  std::vector<std::string> &SpellingNames;
261  std::vector<std::vector<std::string>> &USRList;
262  bool Force;
263  bool &ErrorOccurred;
264};
265
266std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
267  return llvm::make_unique<NamedDeclFindingConsumer>(
268      SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
269      ErrorOccurred);
270}
271
272} // end namespace tooling
273} // end namespace clang
274