USRFindingAction.cpp revision 321369
1//===--- USRFindingAction.cpp - Clang refactoring library -----------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9///
10/// \file
11/// \brief Provides an action to find USR for the symbol at <offset>, as well as
12/// all additional USRs.
13///
14//===----------------------------------------------------------------------===//
15
16#include "clang/Tooling/Refactoring/Rename/USRFindingAction.h"
17#include "clang/AST/AST.h"
18#include "clang/AST/ASTConsumer.h"
19#include "clang/AST/ASTContext.h"
20#include "clang/AST/Decl.h"
21#include "clang/AST/RecursiveASTVisitor.h"
22#include "clang/Basic/FileManager.h"
23#include "clang/Frontend/CompilerInstance.h"
24#include "clang/Frontend/FrontendAction.h"
25#include "clang/Lex/Lexer.h"
26#include "clang/Lex/Preprocessor.h"
27#include "clang/Tooling/CommonOptionsParser.h"
28#include "clang/Tooling/Refactoring.h"
29#include "clang/Tooling/Refactoring/Rename/USRFinder.h"
30#include "clang/Tooling/Tooling.h"
31
32#include <algorithm>
33#include <set>
34#include <string>
35#include <vector>
36
37using namespace llvm;
38
39namespace clang {
40namespace tooling {
41
42namespace {
43// \brief NamedDeclFindingConsumer should delegate finding USRs of given Decl to
44// AdditionalUSRFinder. AdditionalUSRFinder adds USRs of ctor and dtor if given
45// Decl refers to class and adds USRs of all overridden methods if Decl refers
46// to virtual method.
47class AdditionalUSRFinder : public RecursiveASTVisitor<AdditionalUSRFinder> {
48public:
49  AdditionalUSRFinder(const Decl *FoundDecl, ASTContext &Context)
50      : FoundDecl(FoundDecl), Context(Context) {}
51
52  std::vector<std::string> Find() {
53    // Fill OverriddenMethods and PartialSpecs storages.
54    TraverseDecl(Context.getTranslationUnitDecl());
55    if (const auto *MethodDecl = dyn_cast<CXXMethodDecl>(FoundDecl)) {
56      addUSRsOfOverridenFunctions(MethodDecl);
57      for (const auto &OverriddenMethod : OverriddenMethods) {
58        if (checkIfOverriddenFunctionAscends(OverriddenMethod))
59          USRSet.insert(getUSRForDecl(OverriddenMethod));
60      }
61    } else if (const auto *RecordDecl = dyn_cast<CXXRecordDecl>(FoundDecl)) {
62      handleCXXRecordDecl(RecordDecl);
63    } else if (const auto *TemplateDecl =
64                   dyn_cast<ClassTemplateDecl>(FoundDecl)) {
65      handleClassTemplateDecl(TemplateDecl);
66    } else {
67      USRSet.insert(getUSRForDecl(FoundDecl));
68    }
69    return std::vector<std::string>(USRSet.begin(), USRSet.end());
70  }
71
72  bool VisitCXXMethodDecl(const CXXMethodDecl *MethodDecl) {
73    if (MethodDecl->isVirtual())
74      OverriddenMethods.push_back(MethodDecl);
75    return true;
76  }
77
78  bool VisitClassTemplatePartialSpecializationDecl(
79      const ClassTemplatePartialSpecializationDecl *PartialSpec) {
80    PartialSpecs.push_back(PartialSpec);
81    return true;
82  }
83
84private:
85  void handleCXXRecordDecl(const CXXRecordDecl *RecordDecl) {
86    RecordDecl = RecordDecl->getDefinition();
87    if (const auto *ClassTemplateSpecDecl =
88            dyn_cast<ClassTemplateSpecializationDecl>(RecordDecl))
89      handleClassTemplateDecl(ClassTemplateSpecDecl->getSpecializedTemplate());
90    addUSRsOfCtorDtors(RecordDecl);
91  }
92
93  void handleClassTemplateDecl(const ClassTemplateDecl *TemplateDecl) {
94    for (const auto *Specialization : TemplateDecl->specializations())
95      addUSRsOfCtorDtors(Specialization);
96
97    for (const auto *PartialSpec : PartialSpecs) {
98      if (PartialSpec->getSpecializedTemplate() == TemplateDecl)
99        addUSRsOfCtorDtors(PartialSpec);
100    }
101    addUSRsOfCtorDtors(TemplateDecl->getTemplatedDecl());
102  }
103
104  void addUSRsOfCtorDtors(const CXXRecordDecl *RecordDecl) {
105    RecordDecl = RecordDecl->getDefinition();
106
107    // Skip if the CXXRecordDecl doesn't have definition.
108    if (!RecordDecl)
109      return;
110
111    for (const auto *CtorDecl : RecordDecl->ctors())
112      USRSet.insert(getUSRForDecl(CtorDecl));
113
114    USRSet.insert(getUSRForDecl(RecordDecl->getDestructor()));
115    USRSet.insert(getUSRForDecl(RecordDecl));
116  }
117
118  void addUSRsOfOverridenFunctions(const CXXMethodDecl *MethodDecl) {
119    USRSet.insert(getUSRForDecl(MethodDecl));
120    // Recursively visit each OverridenMethod.
121    for (const auto &OverriddenMethod : MethodDecl->overridden_methods())
122      addUSRsOfOverridenFunctions(OverriddenMethod);
123  }
124
125  bool checkIfOverriddenFunctionAscends(const CXXMethodDecl *MethodDecl) {
126    for (const auto &OverriddenMethod : MethodDecl->overridden_methods()) {
127      if (USRSet.find(getUSRForDecl(OverriddenMethod)) != USRSet.end())
128        return true;
129      return checkIfOverriddenFunctionAscends(OverriddenMethod);
130    }
131    return false;
132  }
133
134  const Decl *FoundDecl;
135  ASTContext &Context;
136  std::set<std::string> USRSet;
137  std::vector<const CXXMethodDecl *> OverriddenMethods;
138  std::vector<const ClassTemplatePartialSpecializationDecl *> PartialSpecs;
139};
140} // namespace
141
142class NamedDeclFindingConsumer : public ASTConsumer {
143public:
144  NamedDeclFindingConsumer(ArrayRef<unsigned> SymbolOffsets,
145                           ArrayRef<std::string> QualifiedNames,
146                           std::vector<std::string> &SpellingNames,
147                           std::vector<std::vector<std::string>> &USRList,
148                           bool Force, bool &ErrorOccurred)
149      : SymbolOffsets(SymbolOffsets), QualifiedNames(QualifiedNames),
150        SpellingNames(SpellingNames), USRList(USRList), Force(Force),
151        ErrorOccurred(ErrorOccurred) {}
152
153private:
154  bool FindSymbol(ASTContext &Context, const SourceManager &SourceMgr,
155                  unsigned SymbolOffset, const std::string &QualifiedName) {
156    DiagnosticsEngine &Engine = Context.getDiagnostics();
157    const FileID MainFileID = SourceMgr.getMainFileID();
158
159    if (SymbolOffset >= SourceMgr.getFileIDSize(MainFileID)) {
160      ErrorOccurred = true;
161      unsigned InvalidOffset = Engine.getCustomDiagID(
162          DiagnosticsEngine::Error,
163          "SourceLocation in file %0 at offset %1 is invalid");
164      Engine.Report(SourceLocation(), InvalidOffset)
165          << SourceMgr.getFileEntryForID(MainFileID)->getName() << SymbolOffset;
166      return false;
167    }
168
169    const SourceLocation Point = SourceMgr.getLocForStartOfFile(MainFileID)
170                                     .getLocWithOffset(SymbolOffset);
171    const NamedDecl *FoundDecl = QualifiedName.empty()
172                                     ? getNamedDeclAt(Context, Point)
173                                     : getNamedDeclFor(Context, QualifiedName);
174
175    if (FoundDecl == nullptr) {
176      if (QualifiedName.empty()) {
177        FullSourceLoc FullLoc(Point, SourceMgr);
178        unsigned CouldNotFindSymbolAt = Engine.getCustomDiagID(
179            DiagnosticsEngine::Error,
180            "clang-rename could not find symbol (offset %0)");
181        Engine.Report(Point, CouldNotFindSymbolAt) << SymbolOffset;
182        ErrorOccurred = true;
183        return false;
184      }
185
186      if (Force)
187        return true;
188
189      unsigned CouldNotFindSymbolNamed = Engine.getCustomDiagID(
190          DiagnosticsEngine::Error, "clang-rename could not find symbol %0");
191      Engine.Report(CouldNotFindSymbolNamed) << QualifiedName;
192      ErrorOccurred = true;
193      return false;
194    }
195
196    // If FoundDecl is a constructor or destructor, we want to instead take
197    // the Decl of the corresponding class.
198    if (const auto *CtorDecl = dyn_cast<CXXConstructorDecl>(FoundDecl))
199      FoundDecl = CtorDecl->getParent();
200    else if (const auto *DtorDecl = dyn_cast<CXXDestructorDecl>(FoundDecl))
201      FoundDecl = DtorDecl->getParent();
202
203    SpellingNames.push_back(FoundDecl->getNameAsString());
204    AdditionalUSRFinder Finder(FoundDecl, Context);
205    USRList.push_back(Finder.Find());
206    return true;
207  }
208
209  void HandleTranslationUnit(ASTContext &Context) override {
210    const SourceManager &SourceMgr = Context.getSourceManager();
211    for (unsigned Offset : SymbolOffsets) {
212      if (!FindSymbol(Context, SourceMgr, Offset, ""))
213        return;
214    }
215    for (const std::string &QualifiedName : QualifiedNames) {
216      if (!FindSymbol(Context, SourceMgr, 0, QualifiedName))
217        return;
218    }
219  }
220
221  ArrayRef<unsigned> SymbolOffsets;
222  ArrayRef<std::string> QualifiedNames;
223  std::vector<std::string> &SpellingNames;
224  std::vector<std::vector<std::string>> &USRList;
225  bool Force;
226  bool &ErrorOccurred;
227};
228
229std::unique_ptr<ASTConsumer> USRFindingAction::newASTConsumer() {
230  return llvm::make_unique<NamedDeclFindingConsumer>(
231      SymbolOffsets, QualifiedNames, SpellingNames, USRList, Force,
232      ErrorOccurred);
233}
234
235} // end namespace tooling
236} // end namespace clang
237