1//===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// Provides an action to find USR for the symbol at <offset>, as well as
11/// all additional USRs.
12///
13//===----------------------------------------------------------------------===//
14
15#include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
16#include "clang/AST/AST.h"
17#include "clang/AST/ASTConsumer.h"
18#include "clang/AST/ASTContext.h"
19#include "clang/AST/Decl.h"
20#include "clang/AST/RecursiveASTVisitor.h"
21#include "clang/Basic/FileManager.h"
22#include "clang/Frontend/CompilerInstance.h"
23#include "clang/Frontend/FrontendAction.h"
24#include "clang/Lex/Lexer.h"
25#include "clang/Lex/Preprocessor.h"
26#include "clang/Tooling/CommonOptionsParser.h"
27#include "clang/Tooling/Refactoring.h"
28#include "clang/Tooling/Refactoring/Rename/USRFinder.h"
29#include "clang/Tooling/Tooling.h"
30
31#include <algorithm>
32#include <set>
33#include <string>
34#include <vector>
35
36using namespace llvm;
37
38namespace clang {
39namespace tooling {
40
41const NamedDecl *getCanonicalSymbolDeclaration(const NamedDecl *FoundDecl) {
42  if (!FoundDecl)
43    return nullptr;
44  // If FoundDecl is a constructor or destructor, we want to instead take
45  // the Decl of the corresponding class.
46  if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
47    FoundDecl = CtorDecl->getParent();
48  else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
49    FoundDecl = DtorDecl->getParent();
50  // FIXME: (Alex L): Canonicalize implicit template instantions, just like
51  // the indexer does it.
52
53  // Note: please update the declaration's doc comment every time the
54  // canonicalization rules are changed.
55  return FoundDecl;
56}
57
58namespace {
59// NamedDeclFindingConsumer should delegate finding USRs of given Decl to
60// AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
61// Decl refers to class and adds USRs of all overridden methods if Decl refers
62// to virtual method.
63class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
64public:
65  AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
66      : FoundDecl(FoundDecl), Context(Context) {}
67
68  std::vector<std::string> Find() {
69    // Fill OverriddenMethods and PartialSpecs storages.
70    TraverseAST(Context);
71    if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
72      addUSRsOfOverridenFunctions(MethodDecl);
73      for (const auto &OverriddenMethod : OverriddenMethods) {
74        if (checkIfOverriddenFunctionAscends(OverriddenMethod))
75          USRSet.insert(getUSRForDecl(OverriddenMethod));
76      }
77      addUSRsOfInstantiatedMethods(MethodDecl);
78    } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
79      handleCXXRecordDecl(RecordDecl);
80    } else if (const auto *TemplateDecl =
81                   dyn_cast<ClassTemplateDecl>(FoundDecl)) {
82      handleClassTemplateDecl(TemplateDecl);
83    } else if (const auto *FD = dyn_cast<FunctionDecl>(FoundDecl)) {
84      USRSet.insert(getUSRForDecl(FD));
85      if (const auto *FTD = FD->getPrimaryTemplate())
86        handleFunctionTemplateDecl(FTD);
87    } else if (const auto *FD = dyn_cast<FunctionTemplateDecl>(FoundDecl)) {
88      handleFunctionTemplateDecl(FD);
89    } else if (const auto *VTD = dyn_cast<VarTemplateDecl>(FoundDecl)) {
90      handleVarTemplateDecl(VTD);
91    } else if (const auto *VD =
92                   dyn_cast<VarTemplateSpecializationDecl>(FoundDecl)) {
93      // FIXME: figure out why FoundDecl can be a VarTemplateSpecializationDecl.
94      handleVarTemplateDecl(VD->getSpecializedTemplate());
95    } else if (const auto *VD = dyn_cast<VarDecl>(FoundDecl)) {
96      USRSet.insert(getUSRForDecl(VD));
97      if (const auto *VTD = VD->getDescribedVarTemplate())
98        handleVarTemplateDecl(VTD);
99    } else {
100      USRSet.insert(getUSRForDecl(FoundDecl));
101    }
102    return std::vector<std::string>(USRSet.begin(), USRSet.end());
103  }
104
105  bool shouldVisitTemplateInstantiations() const { return true; }
106
107  bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
108    if (MethodDecl->isVirtual())
109      OverriddenMethods.push_back(MethodDecl);
110    if (MethodDecl->getInstantiatedFromMemberFunction())
111      InstantiatedMethods.push_back(MethodDecl);
112    return true;
113  }
114
115private:
116  void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
117    if (!RecordDecl->getDefinition()) {
118      USRSet.insert(getUSRForDecl(RecordDecl));
119      return;
120    }
121    RecordDecl = RecordDecl->getDefinition();
122    if (const auto *ClassTemplateSpecDecl =
123            dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
124      handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
125    addUSRsOfCtorDtors(RecordDecl);
126  }
127
128  void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
129    for (const auto *Specialization : TemplateDecl->specializations())
130      addUSRsOfCtorDtors(Specialization);
131    SmallVector<ClassTemplatePartialSpecializationDecl *, 4> PartialSpecs;
132    TemplateDecl->getPartialSpecializations(PartialSpecs);
133    for (const auto *Spec : PartialSpecs)
134      addUSRsOfCtorDtors(Spec);
135    addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
136  }
137
138  void handleFunctionTemplateDecl(const FunctionTemplateDecl *FTD) {
139    USRSet.insert(getUSRForDecl(FTD));
140    USRSet.insert(getUSRForDecl(FTD->getTemplatedDecl()));
141    for (const auto *S : FTD->specializations())
142      USRSet.insert(getUSRForDecl(S));
143  }
144
145  void handleVarTemplateDecl(const VarTemplateDecl *VTD) {
146    USRSet.insert(getUSRForDecl(VTD));
147    USRSet.insert(getUSRForDecl(VTD->getTemplatedDecl()));
148    for (const auto *Spec : VTD->specializations())
149      USRSet.insert(getUSRForDecl(Spec));
150    SmallVector<VarTemplatePartialSpecializationDecl *, 4> PartialSpecs;
151    VTD->getPartialSpecializations(PartialSpecs);
152    for (const auto *Spec : PartialSpecs)
153      USRSet.insert(getUSRForDecl(Spec));
154  }
155
156  void addUSRsOfCtorDtors(const CXXRecordDecl *RD) {
157    const auto* RecordDecl = RD->getDefinition();
158
159    // Skip if the CXXRecordDecl doesn't have definition.
160    if (!RecordDecl) {
161      USRSet.insert(getUSRForDecl(RD));
162      return;
163    }
164
165    for (const auto *CtorDecl : RecordDecl->ctors())
166      USRSet.insert(getUSRForDecl(CtorDecl));
167    // Add template constructor decls, they are not in ctors() unfortunately.
168    if (RecordDecl->hasUserDeclaredConstructor())
169      for (const auto *D : RecordDecl->decls())
170        if (const auto *FTD = dyn_cast<FunctionTemplateDecl>(D))
171          if (const auto *Ctor =
172                  dyn_cast<CXXConstructorDecl>(FTD->getTemplatedDecl()))
173            USRSet.insert(getUSRForDecl(Ctor));
174
175    USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
176    USRSet.insert(getUSRForDecl(RecordDecl));
177  }
178
179  void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
180    USRSet.insert(getUSRForDecl(MethodDecl));
181    // Recursively visit each OverridenMethod.
182    for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
183      addUSRsOfOverridenFunctions(OverriddenMethod);
184  }
185
186  void addUSRsOfInstantiatedMethods(const CXXMethodDecl *MethodDecl) {
187    // For renaming a class template method, all references of the instantiated
188    // member methods should be renamed too, so add USRs of the instantiated
189    // methods to the USR set.
190    USRSet.insert(getUSRForDecl(MethodDecl));
191    if (const auto *FT = MethodDecl->getInstantiatedFromMemberFunction())
192      USRSet.insert(getUSRForDecl(FT));
193    for (const auto *Method : InstantiatedMethods) {
194      if (USRSet.find(getUSRForDecl(
195              Method->getInstantiatedFromMemberFunction())) != USRSet.end())
196        USRSet.insert(getUSRForDecl(Method));
197    }
198  }
199
200  bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
201    for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
202      if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
203        return true;
204      return checkIfOverriddenFunctionAscends(OverriddenMethod);
205    }
206    return false;
207  }
208
209  const Decl *FoundDecl;
210  ASTContext &Context;
211  std::set<std::string> USRSet;
212  std::vector<const CXXMethodDecl *> OverriddenMethods;
213  std::vector<const CXXMethodDecl *> InstantiatedMethods;
214};
215} // namespace
216
217std::vector<std::string> getUSRsForDeclaration(const NamedDecl *ND,
218                                               ASTContext &Context) {
219  AdditionalUSRFinder Finder(ND, Context);
220  return Finder.Find();
221}
222
223class NamedDeclFindingConsumer : public ASTConsumer {
224public:
225  NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
226                           ArrayRef<std::string> QualifiedNames,
227                           std::vector<std::string> &SpellingNames,
228                           std::vector<std::vector<std::string>> &USRList,
229                           bool Force, bool &ErrorOccurred)
230      : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
231        SpellingNames(SpellingNames), USRList(USRList), Force(Force),
232        ErrorOccurred(ErrorOccurred) {}
233
234private:
235  bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
236                  unsigned SymbolOffset, const std::string &QualifiedName) {
237    DiagnosticsEngine &Engine = Context.getDiagnostics();
238    const FileID MainFileID = SourceMgr.getMainFileID();
239
240    if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
241      ErrorOccurred = true;
242      unsigned InvalidOffset = Engine.getCustomDiagID(
243          DiagnosticsEngine::Error,
244          "SourceLocation in file %0 at offset %1 is invalid");
245      Engine.Report(SourceLocation(), InvalidOffset)
246          << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
247      return false;
248    }
249
250    const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
251                                     .getLocWithOffset(SymbolOffset);
252    const NamedDecl *FoundDecl = QualifiedName.empty()
253                                     ? getNamedDeclAt(Context, Point)
254                                     : getNamedDeclFor(Context, QualifiedName);
255
256    if (FoundDecl == nullptr) {
257      if (QualifiedName.empty()) {
258        FullSourceLoc FullLoc(Point, SourceMgr);
259        unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
260            DiagnosticsEngine::Error,
261            "clang-rename could not find symbol (offset %0)");
262        Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
263        ErrorOccurred = true;
264        return false;
265      }
266
267      if (Force) {
268        SpellingNames.push_back(std::string());
269        USRList.push_back(std::vector<std::string>());
270        return true;
271      }
272
273      unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
274          DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
275      Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
276      ErrorOccurred = true;
277      return false;
278    }
279
280    FoundDecl = getCanonicalSymbolDeclaration(FoundDecl);
281    SpellingNames.push_back(FoundDecl->getNameAsString());
282    AdditionalUSRFinder Finder(FoundDecl, Context);
283    USRList.push_back(Finder.Find());
284    return true;
285  }
286
287  void HandleTranslationUnit(ASTContext &Context) override {
288    const SourceManager &SourceMgr = Context.getSourceManager();
289    for (unsigned Offset : SymbolOffsets) {
290      if (!FindSymbol(Context, SourceMgr, Offset, ""))
291        return;
292    }
293    for (const std::string &QualifiedName : QualifiedNames) {
294      if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
295        return;
296    }
297  }
298
299  ArrayRef<unsigned> SymbolOffsets;
300  ArrayRef<std::string> QualifiedNames;
301  std::vector<std::string> &SpellingNames;
302  std::vector<std::vector<std::string>> &USRList;
303  bool Force;
304  bool &ErrorOccurred;
305};
306
307std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
308  return std::make_unique<NamedDeclFindingConsumer>(
309      SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
310      ErrorOccurred);
311}
312
313} // end namespace tooling
314} // end namespace clang
315