1//===--- RefactoringCallbacks.cpp - Structural query framework ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9//
10//===----------------------------------------------------------------------===//
11#include "clang/Tooling/RefactoringCallbacks.h"
12#include "clang/ASTMatchers/ASTMatchFinder.h"
13#include "clang/Basic/SourceLocation.h"
14#include "clang/Lex/Lexer.h"
15
16using llvm::StringError;
17using llvm::make_error;
18
19namespace clang {
20namespace tooling {
21
22RefactoringCallback::RefactoringCallback() {}
23tooling::Replacements &RefactoringCallback::getReplacements() {
24  return Replace;
25}
26
27ASTMatchRefactorer::ASTMatchRefactorer(
28    std::map<std::string, Replacements> &FileToReplaces)
29    : FileToReplaces(FileToReplaces) {}
30
31void ASTMatchRefactorer::addDynamicMatcher(
32    const ast_matchers::internal::DynTypedMatcher &Matcher,
33    RefactoringCallback *Callback) {
34  MatchFinder.addDynamicMatcher(Matcher, Callback);
35  Callbacks.push_back(Callback);
36}
37
38class RefactoringASTConsumer : public ASTConsumer {
39public:
40  explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring)
41      : Refactoring(Refactoring) {}
42
43  void HandleTranslationUnit(ASTContext &Context) override {
44    // The ASTMatchRefactorer is re-used between translation units.
45    // Clear the matchers so that each Replacement is only emitted once.
46    for (const auto &Callback : Refactoring.Callbacks) {
47      Callback->getReplacements().clear();
48    }
49    Refactoring.MatchFinder.matchAST(Context);
50    for (const auto &Callback : Refactoring.Callbacks) {
51      for (const auto &Replacement : Callback->getReplacements()) {
52        llvm::Error Err =
53            Refactoring.FileToReplaces[Replacement.getFilePath()].add(
54                Replacement);
55        if (Err) {
56          llvm::errs() << "Skipping replacement " << Replacement.toString()
57                       << " due to this error:\n"
58                       << toString(std::move(Err)) << "\n";
59        }
60      }
61    }
62  }
63
64private:
65  ASTMatchRefactorer &Refactoring;
66};
67
68std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() {
69  return std::make_unique<RefactoringASTConsumer>(*this);
70}
71
72static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From,
73                                       StringRef Text) {
74  return tooling::Replacement(
75      Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text);
76}
77static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From,
78                                       const Stmt &To) {
79  return replaceStmtWithText(
80      Sources, From,
81      Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()),
82                           Sources, LangOptions()));
83}
84
85ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText)
86    : FromId(FromId), ToText(ToText) {}
87
88void ReplaceStmtWithText::run(
89    const ast_matchers::MatchFinder::MatchResult &Result) {
90  if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) {
91    auto Err = Replace.add(tooling::Replacement(
92        *Result.SourceManager,
93        CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText));
94    // FIXME: better error handling. For now, just print error message in the
95    // release version.
96    if (Err) {
97      llvm::errs() << llvm::toString(std::move(Err)) << "\n";
98      assert(false);
99    }
100  }
101}
102
103ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId)
104    : FromId(FromId), ToId(ToId) {}
105
106void ReplaceStmtWithStmt::run(
107    const ast_matchers::MatchFinder::MatchResult &Result) {
108  const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId);
109  const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId);
110  if (FromMatch && ToMatch) {
111    auto Err = Replace.add(
112        replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch));
113    // FIXME: better error handling. For now, just print error message in the
114    // release version.
115    if (Err) {
116      llvm::errs() << llvm::toString(std::move(Err)) << "\n";
117      assert(false);
118    }
119  }
120}
121
122ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id,
123                                                   bool PickTrueBranch)
124    : Id(Id), PickTrueBranch(PickTrueBranch) {}
125
126void ReplaceIfStmtWithItsBody::run(
127    const ast_matchers::MatchFinder::MatchResult &Result) {
128  if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) {
129    const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse();
130    if (Body) {
131      auto Err =
132          Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body));
133      // FIXME: better error handling. For now, just print error message in the
134      // release version.
135      if (Err) {
136        llvm::errs() << llvm::toString(std::move(Err)) << "\n";
137        assert(false);
138      }
139    } else if (!PickTrueBranch) {
140      // If we want to use the 'else'-branch, but it doesn't exist, delete
141      // the whole 'if'.
142      auto Err =
143          Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, ""));
144      // FIXME: better error handling. For now, just print error message in the
145      // release version.
146      if (Err) {
147        llvm::errs() << llvm::toString(std::move(Err)) << "\n";
148        assert(false);
149      }
150    }
151  }
152}
153
154ReplaceNodeWithTemplate::ReplaceNodeWithTemplate(
155    llvm::StringRef FromId, std::vector<TemplateElement> Template)
156    : FromId(FromId), Template(std::move(Template)) {}
157
158llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>>
159ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) {
160  std::vector<TemplateElement> ParsedTemplate;
161  for (size_t Index = 0; Index < ToTemplate.size();) {
162    if (ToTemplate[Index] == '$') {
163      if (ToTemplate.substr(Index, 2) == "$$") {
164        Index += 2;
165        ParsedTemplate.push_back(
166            TemplateElement{TemplateElement::Literal, "$"});
167      } else if (ToTemplate.substr(Index, 2) == "${") {
168        size_t EndOfIdentifier = ToTemplate.find("}", Index);
169        if (EndOfIdentifier == std::string::npos) {
170          return make_error<StringError>(
171              "Unterminated ${...} in replacement template near " +
172                  ToTemplate.substr(Index),
173              llvm::inconvertibleErrorCode());
174        }
175        std::string SourceNodeName =
176            ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2);
177        ParsedTemplate.push_back(
178            TemplateElement{TemplateElement::Identifier, SourceNodeName});
179        Index = EndOfIdentifier + 1;
180      } else {
181        return make_error<StringError>(
182            "Invalid $ in replacement template near " +
183                ToTemplate.substr(Index),
184            llvm::inconvertibleErrorCode());
185      }
186    } else {
187      size_t NextIndex = ToTemplate.find('$', Index + 1);
188      ParsedTemplate.push_back(
189          TemplateElement{TemplateElement::Literal,
190                          ToTemplate.substr(Index, NextIndex - Index)});
191      Index = NextIndex;
192    }
193  }
194  return std::unique_ptr<ReplaceNodeWithTemplate>(
195      new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate)));
196}
197
198void ReplaceNodeWithTemplate::run(
199    const ast_matchers::MatchFinder::MatchResult &Result) {
200  const auto &NodeMap = Result.Nodes.getMap();
201
202  std::string ToText;
203  for (const auto &Element : Template) {
204    switch (Element.Type) {
205    case TemplateElement::Literal:
206      ToText += Element.Value;
207      break;
208    case TemplateElement::Identifier: {
209      auto NodeIter = NodeMap.find(Element.Value);
210      if (NodeIter == NodeMap.end()) {
211        llvm::errs() << "Node " << Element.Value
212                     << " used in replacement template not bound in Matcher \n";
213        llvm::report_fatal_error("Unbound node in replacement template.");
214      }
215      CharSourceRange Source =
216          CharSourceRange::getTokenRange(NodeIter->second.getSourceRange());
217      ToText += Lexer::getSourceText(Source, *Result.SourceManager,
218                                     Result.Context->getLangOpts());
219      break;
220    }
221    }
222  }
223  if (NodeMap.count(FromId) == 0) {
224    llvm::errs() << "Node to be replaced " << FromId
225                 << " not bound in query.\n";
226    llvm::report_fatal_error("FromId node not bound in MatchResult");
227  }
228  auto Replacement =
229      tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText,
230                           Result.Context->getLangOpts());
231  llvm::Error Err = Replace.add(Replacement);
232  if (Err) {
233    llvm::errs() << "Query and replace failed in " << Replacement.getFilePath()
234                 << "! " << llvm::toString(std::move(Err)) << "\n";
235    llvm::report_fatal_error("Replacement failed");
236  }
237}
238
239} // end namespace tooling
240} // end namespace clang
241