1//===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// Provides an action to find USR for the symbol at <offset>, as well as
11/// all additional USRs.
12///
13//===----------------------------------------------------------------------===//
14
15#include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
16#include "clang/AST/AST.h"
17#include "clang/AST/ASTConsumer.h"
18#include "clang/AST/ASTContext.h"
19#include "clang/AST/Decl.h"
20#include "clang/AST/RecursiveASTVisitor.h"
21#include "clang/Basic/FileManager.h"
22#include "clang/Frontend/CompilerInstance.h"
23#include "clang/Frontend/FrontendAction.h"
24#include "clang/Lex/Lexer.h"
25#include "clang/Lex/Preprocessor.h"
26#include "clang/Tooling/CommonOptionsParser.h"
27#include "clang/Tooling/Refactoring.h"
28#include "clang/Tooling/Refactoring/Rename/USRFinder.h"
29#include "clang/Tooling/Tooling.h"
30
31#include <algorithm>
32#include <set>
33#include <string>
34#include <vector>
35
36using namespace llvm;
37
38namespace clang {
39namespace tooling {
40
41const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
42  if (!FoundDecl)
43    return nullptr;
44  // If FoundDecl is a constructor or destructor, we want to instead take
45  // the Decl of the corresponding class.
46  if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
47    FoundDecl = CtorDecl->getParent();
48  else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
49    FoundDecl = DtorDecl->getParent();
50  // FIXME: (Alex L): Canonicalize implicit template instantions, just like
51  // the indexer does it.
52
53  // Note: please update the declaration's doc comment every time the
54  // canonicalization rules are changed.
55  return FoundDecl;
56}
57
58namespace {
59// NamedDeclFindingConsumer should delegate finding USRs of given Decl to
60// AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
61// Decl refers to class and adds USRs of all overridden methods if Decl refers
62// to virtual method.
63class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
64public:
65  AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
66      : FoundDecl(FoundDecl), Context(Context) {}
67
68  std::vector<std::string> Find() {
69    // Fill OverriddenMethods and PartialSpecs storages.
70    TraverseAST(Context);
71    if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
72      addUSRsOfOverridenFunctions(MethodDecl);
73      for (const auto &OverriddenMethod : OverriddenMethods) {
74        if (checkIfOverriddenFunctionAscends(OverriddenMethod))
75          USRSet.insert(getUSRForDecl(OverriddenMethod));
76      }
77      addUSRsOfInstantiatedMethods(MethodDecl);
78    } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
79      handleCXXRecordDecl(RecordDecl);
80    } else if (const auto *TemplateDecl =
81                   dyn_cast<ClassTemplateDecl>(FoundDecl)) {
82      handleClassTemplateDecl(TemplateDecl);
83    } else {
84      USRSet.insert(getUSRForDecl(FoundDecl));
85    }
86    return std::vector<std::string>(USRSet.begin(), USRSet.end());
87  }
88
89  bool shouldVisitTemplateInstantiations() const { return true; }
90
91  bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
92    if (MethodDecl->isVirtual())
93      OverriddenMethods.push_back(MethodDecl);
94    if (MethodDecl->getInstantiatedFromMemberFunction())
95      InstantiatedMethods.push_back(MethodDecl);
96    return true;
97  }
98
99  bool VisitClassTemplatePartialSpecializationDecl(
100      const ClassTemplatePartialSpecializationDecl *PartialSpec) {
101    PartialSpecs.push_back(PartialSpec);
102    return true;
103  }
104
105private:
106  void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
107    if (!RecordDecl->getDefinition()) {
108      USRSet.insert(getUSRForDecl(RecordDecl));
109      return;
110    }
111    RecordDecl = RecordDecl->getDefinition();
112    if (const auto *ClassTemplateSpecDecl =
113            dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
114      handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
115    addUSRsOfCtorDtors(RecordDecl);
116  }
117
118  void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
119    for (const auto *Specialization : TemplateDecl->specializations())
120      addUSRsOfCtorDtors(Specialization);
121
122    for (const auto *PartialSpec : PartialSpecs) {
123      if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
124        addUSRsOfCtorDtors(PartialSpec);
125    }
126    addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
127  }
128
129  void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl) {
130    RecordDecl = RecordDecl->getDefinition();
131
132    // Skip if the CXXRecordDecl doesn't have definition.
133    if (!RecordDecl)
134      return;
135
136    for (const auto *CtorDecl : RecordDecl->ctors())
137      USRSet.insert(getUSRForDecl(CtorDecl));
138
139    USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
140    USRSet.insert(getUSRForDecl(RecordDecl));
141  }
142
143  void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
144    USRSet.insert(getUSRForDecl(MethodDecl));
145    // Recursively visit each OverridenMethod.
146    for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
147      addUSRsOfOverridenFunctions(OverriddenMethod);
148  }
149
150  void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
151    // For renaming a class template method, all references of the instantiated
152    // member methods should be renamed too, so add USRs of the instantiated
153    // methods to the USR set.
154    USRSet.insert(getUSRForDecl(MethodDecl));
155    if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
156      USRSet.insert(getUSRForDecl(FT));
157    for (const auto *Method : InstantiatedMethods) {
158      if (USRSet.find(getUSRForDecl(
159              Method->getInstantiatedFromMemberFunction())) != USRSet.end())
160        USRSet.insert(getUSRForDecl(Method));
161    }
162  }
163
164  bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
165    for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
166      if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
167        return true;
168      return checkIfOverriddenFunctionAscends(OverriddenMethod);
169    }
170    return false;
171  }
172
173  const Decl *FoundDecl;
174  ASTContext &Context;
175  std::set<std::string> USRSet;
176  std::vector<const CXXMethodDecl *> OverriddenMethods;
177  std::vector<const CXXMethodDecl *> InstantiatedMethods;
178  std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
179};
180} // namespace
181
182std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
183                                               ASTContext &Context) {
184  AdditionalUSRFinder Finder(ND, Context);
185  return Finder.Find();
186}
187
188class NamedDeclFindingConsumer : public ASTConsumer {
189public:
190  NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
191                           ArrayRef<std::string> QualifiedNames,
192                           std::vector<std::string> &SpellingNames,
193                           std::vector<std::vector<std::string>> &USRList,
194                           bool Force, bool &ErrorOccurred)
195      : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
196        SpellingNames(SpellingNames), USRList(USRList), Force(Force),
197        ErrorOccurred(ErrorOccurred) {}
198
199private:
200  bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
201                  unsigned SymbolOffset, const std::string &QualifiedName) {
202    DiagnosticsEngine &Engine = Context.getDiagnostics();
203    const FileID MainFileID = SourceMgr.getMainFileID();
204
205    if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
206      ErrorOccurred = true;
207      unsigned InvalidOffset = Engine.getCustomDiagID(
208          DiagnosticsEngine::Error,
209          "SourceLocation in file %0 at offset %1 is invalid");
210      Engine.Report(SourceLocation(), InvalidOffset)
211          << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
212      return false;
213    }
214
215    const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
216                                     .getLocWithOffset(SymbolOffset);
217    const NamedDecl *FoundDecl = QualifiedName.empty()
218                                     ? getNamedDeclAt(Context, Point)
219                                     : getNamedDeclFor(Context, QualifiedName);
220
221    if (FoundDecl == nullptr) {
222      if (QualifiedName.empty()) {
223        FullSourceLoc FullLoc(Point, SourceMgr);
224        unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
225            DiagnosticsEngine::Error,
226            "clang-rename could not find symbol (offset %0)");
227        Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
228        ErrorOccurred = true;
229        return false;
230      }
231
232      if (Force) {
233        SpellingNames.push_back(std::string());
234        USRList.push_back(std::vector<std::string>());
235        return true;
236      }
237
238      unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
239          DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
240      Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
241      ErrorOccurred = true;
242      return false;
243    }
244
245    FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
246    SpellingNames.push_back(FoundDecl->getNameAsString());
247    AdditionalUSRFinder Finder(FoundDecl, Context);
248    USRList.push_back(Finder.Find());
249    return true;
250  }
251
252  void HandleTranslationUnit(ASTContext &Context) override {
253    const SourceManager &SourceMgr = Context.getSourceManager();
254    for (unsigned Offset : SymbolOffsets) {
255      if (!FindSymbol(Context, SourceMgr, Offset, ""))
256        return;
257    }
258    for (const std::string &QualifiedName : QualifiedNames) {
259      if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
260        return;
261    }
262  }
263
264  ArrayRef<unsigned> SymbolOffsets;
265  ArrayRef<std::string> QualifiedNames;
266  std::vector<std::string> &SpellingNames;
267  std::vector<std::vector<std::string>> &USRList;
268  bool Force;
269  bool &ErrorOccurred;
270};
271
272std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
273  return std::make_unique<NamedDeclFindingConsumer>(
274      SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
275      ErrorOccurred);
276}
277
278} // end namespace tooling
279} // end namespace clang
280