1//===--- Refactoring.cpp - Framework for clang refactoring tools ----------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10//  Implements tools to support refactorings.
11//
12//===----------------------------------------------------------------------===//
13
14#include "clang/Basic/DiagnosticOptions.h"
15#include "clang/Basic/FileManager.h"
16#include "clang/Basic/SourceManager.h"
17#include "clang/Frontend/TextDiagnosticPrinter.h"
18#include "clang/Lex/Lexer.h"
19#include "clang/Rewrite/Core/Rewriter.h"
20#include "clang/Tooling/Refactoring.h"
21#include "llvm/Support/raw_os_ostream.h"
22#include "llvm/Support/FileSystem.h"
23#include "llvm/Support/Path.h"
24
25namespace clang {
26namespace tooling {
27
// Sentinel file path for replacements that do not refer to any file.
// isApplicable() compares a replacement's path against this value.
static const char *const InvalidLocation = "";
29
30Replacement::Replacement()
31  : FilePath(InvalidLocation) {}
32
33Replacement::Replacement(StringRef FilePath, unsigned Offset, unsigned Length,
34                         StringRef ReplacementText)
35    : FilePath(FilePath), ReplacementRange(Offset, Length),
36      ReplacementText(ReplacementText) {}
37
38Replacement::Replacement(SourceManager &Sources, SourceLocation Start,
39                         unsigned Length, StringRef ReplacementText) {
40  setFromSourceLocation(Sources, Start, Length, ReplacementText);
41}
42
43Replacement::Replacement(SourceManager &Sources, const CharSourceRange &Range,
44                         StringRef ReplacementText) {
45  setFromSourceRange(Sources, Range, ReplacementText);
46}
47
48bool Replacement::isApplicable() const {
49  return FilePath != InvalidLocation;
50}
51
52bool Replacement::apply(Rewriter &Rewrite) const {
53  SourceManager &SM = Rewrite.getSourceMgr();
54  const FileEntry *Entry = SM.getFileManager().getFile(FilePath);
55  if (Entry == NULL)
56    return false;
57  FileID ID;
58  // FIXME: Use SM.translateFile directly.
59  SourceLocation Location = SM.translateFileLineCol(Entry, 1, 1);
60  ID = Location.isValid() ?
61    SM.getFileID(Location) :
62    SM.createFileID(Entry, SourceLocation(), SrcMgr::C_User);
63  // FIXME: We cannot check whether Offset + Length is in the file, as
64  // the remapping API is not public in the RewriteBuffer.
65  const SourceLocation Start =
66    SM.getLocForStartOfFile(ID).
67    getLocWithOffset(ReplacementRange.getOffset());
68  // ReplaceText returns false on success.
69  // ReplaceText only fails if the source location is not a file location, in
70  // which case we already returned false earlier.
71  bool RewriteSucceeded = !Rewrite.ReplaceText(
72      Start, ReplacementRange.getLength(), ReplacementText);
73  assert(RewriteSucceeded);
74  return RewriteSucceeded;
75}
76
77std::string Replacement::toString() const {
78  std::string result;
79  llvm::raw_string_ostream stream(result);
80  stream << FilePath << ": " << ReplacementRange.getOffset() << ":+"
81         << ReplacementRange.getLength() << ":\"" << ReplacementText << "\"";
82  return result;
83}
84
85bool operator<(const Replacement &LHS, const Replacement &RHS) {
86  if (LHS.getOffset() != RHS.getOffset())
87    return LHS.getOffset() < RHS.getOffset();
88  if (LHS.getLength() != RHS.getLength())
89    return LHS.getLength() < RHS.getLength();
90  if (LHS.getFilePath() != RHS.getFilePath())
91    return LHS.getFilePath() < RHS.getFilePath();
92  return LHS.getReplacementText() < RHS.getReplacementText();
93}
94
95bool operator==(const Replacement &LHS, const Replacement &RHS) {
96  return LHS.getOffset() == RHS.getOffset() &&
97         LHS.getLength() == RHS.getLength() &&
98         LHS.getFilePath() == RHS.getFilePath() &&
99         LHS.getReplacementText() == RHS.getReplacementText();
100}
101
102void Replacement::setFromSourceLocation(SourceManager &Sources,
103                                        SourceLocation Start, unsigned Length,
104                                        StringRef ReplacementText) {
105  const std::pair<FileID, unsigned> DecomposedLocation =
106      Sources.getDecomposedLoc(Start);
107  const FileEntry *Entry = Sources.getFileEntryForID(DecomposedLocation.first);
108  if (Entry != NULL) {
109    // Make FilePath absolute so replacements can be applied correctly when
110    // relative paths for files are used.
111    llvm::SmallString<256> FilePath(Entry->getName());
112    llvm::error_code EC = llvm::sys::fs::make_absolute(FilePath);
113    this->FilePath = EC ? FilePath.c_str() : Entry->getName();
114  } else {
115    this->FilePath = InvalidLocation;
116  }
117  this->ReplacementRange = Range(DecomposedLocation.second, Length);
118  this->ReplacementText = ReplacementText;
119}
120
121// FIXME: This should go into the Lexer, but we need to figure out how
122// to handle ranges for refactoring in general first - there is no obvious
123// good way how to integrate this into the Lexer yet.
124static int getRangeSize(SourceManager &Sources, const CharSourceRange &Range) {
125  SourceLocation SpellingBegin = Sources.getSpellingLoc(Range.getBegin());
126  SourceLocation SpellingEnd = Sources.getSpellingLoc(Range.getEnd());
127  std::pair<FileID, unsigned> Start = Sources.getDecomposedLoc(SpellingBegin);
128  std::pair<FileID, unsigned> End = Sources.getDecomposedLoc(SpellingEnd);
129  if (Start.first != End.first) return -1;
130  if (Range.isTokenRange())
131    End.second += Lexer::MeasureTokenLength(SpellingEnd, Sources,
132                                            LangOptions());
133  return End.second - Start.second;
134}
135
136void Replacement::setFromSourceRange(SourceManager &Sources,
137                                     const CharSourceRange &Range,
138                                     StringRef ReplacementText) {
139  setFromSourceLocation(Sources, Sources.getSpellingLoc(Range.getBegin()),
140                        getRangeSize(Sources, Range), ReplacementText);
141}
142
143bool applyAllReplacements(const Replacements &Replaces, Rewriter &Rewrite) {
144  bool Result = true;
145  for (Replacements::const_iterator I = Replaces.begin(),
146                                    E = Replaces.end();
147       I != E; ++I) {
148    if (I->isApplicable()) {
149      Result = I->apply(Rewrite) && Result;
150    } else {
151      Result = false;
152    }
153  }
154  return Result;
155}
156
157// FIXME: Remove this function when Replacements is implemented as std::vector
158// instead of std::set.
159bool applyAllReplacements(const std::vector<Replacement> &Replaces,
160                          Rewriter &Rewrite) {
161  bool Result = true;
162  for (std::vector<Replacement>::const_iterator I = Replaces.begin(),
163                                                E = Replaces.end();
164       I != E; ++I) {
165    if (I->isApplicable()) {
166      Result = I->apply(Rewrite) && Result;
167    } else {
168      Result = false;
169    }
170  }
171  return Result;
172}
173
174std::string applyAllReplacements(StringRef Code, const Replacements &Replaces) {
175  FileManager Files((FileSystemOptions()));
176  DiagnosticsEngine Diagnostics(
177      IntrusiveRefCntPtr<DiagnosticIDs>(new DiagnosticIDs),
178      new DiagnosticOptions);
179  Diagnostics.setClient(new TextDiagnosticPrinter(
180      llvm::outs(), &Diagnostics.getDiagnosticOptions()));
181  SourceManager SourceMgr(Diagnostics, Files);
182  Rewriter Rewrite(SourceMgr, LangOptions());
183  llvm::MemoryBuffer *Buf = llvm::MemoryBuffer::getMemBuffer(Code, "<stdin>");
184  const clang::FileEntry *Entry =
185      Files.getVirtualFile("<stdin>", Buf->getBufferSize(), 0);
186  SourceMgr.overrideFileContents(Entry, Buf);
187  FileID ID =
188      SourceMgr.createFileID(Entry, SourceLocation(), clang::SrcMgr::C_User);
189  for (Replacements::const_iterator I = Replaces.begin(), E = Replaces.end();
190       I != E; ++I) {
191    Replacement Replace("<stdin>", I->getOffset(), I->getLength(),
192                        I->getReplacementText());
193    if (!Replace.apply(Rewrite))
194      return "";
195  }
196  std::string Result;
197  llvm::raw_string_ostream OS(Result);
198  Rewrite.getEditBuffer(ID).write(OS);
199  OS.flush();
200  return Result;
201}
202
203unsigned shiftedCodePosition(const Replacements &Replaces, unsigned Position) {
204  unsigned NewPosition = Position;
205  for (Replacements::iterator I = Replaces.begin(), E = Replaces.end(); I != E;
206       ++I) {
207    if (I->getOffset() >= Position)
208      break;
209    if (I->getOffset() + I->getLength() > Position)
210      NewPosition += I->getOffset() + I->getLength() - Position;
211    NewPosition += I->getReplacementText().size() - I->getLength();
212  }
213  return NewPosition;
214}
215
216// FIXME: Remove this function when Replacements is implemented as std::vector
217// instead of std::set.
218unsigned shiftedCodePosition(const std::vector<Replacement> &Replaces,
219                             unsigned Position) {
220  unsigned NewPosition = Position;
221  for (std::vector<Replacement>::const_iterator I = Replaces.begin(),
222                                                E = Replaces.end();
223       I != E; ++I) {
224    if (I->getOffset() >= Position)
225      break;
226    if (I->getOffset() + I->getLength() > Position)
227      NewPosition += I->getOffset() + I->getLength() - Position;
228    NewPosition += I->getReplacementText().size() - I->getLength();
229  }
230  return NewPosition;
231}
232
/// Sorts \p Replaces with operator< and removes adjacent duplicates (per
/// operator==). It then appends to \p Conflicts one entry for each run of two
/// or more consecutive replacements whose ranges overlap.
///
/// Each Range pushed into \p Conflicts is not a source range. Its offset is
/// the index of the run's first replacement in the deduplicated \p Replaces,
/// and its length is the number of replacements in the run. \p Conflicts is
/// only appended to, never cleared.
///
/// NOTE(review): the overlap walk ignores file paths, so all replacements are
/// presumably expected to refer to the same file; confirm against callers.
void deduplicate(std::vector<Replacement> &Replaces,
                 std::vector<Range> &Conflicts) {
  if (Replaces.empty())
    return;

  // Deduplicate
  std::sort(Replaces.begin(), Replaces.end());
  std::vector<Replacement>::iterator End =
      std::unique(Replaces.begin(), Replaces.end());
  Replaces.erase(End, Replaces.end());

  // Detect conflicts
  // ConflictRange is the source span covered by the current run.
  // ConflictStart is the index of the run's first replacement, and
  // ConflictLength is the number of replacements in the run.
  Range ConflictRange(Replaces.front().getOffset(),
                      Replaces.front().getLength());
  unsigned ConflictStart = 0;
  unsigned ConflictLength = 1;
  for (unsigned i = 1; i < Replaces.size(); ++i) {
    Range Current(Replaces[i].getOffset(), Replaces[i].getLength());
    if (ConflictRange.overlapsWith(Current)) {
      // Extend conflicted range
      // (grow the span only if Current ends beyond it).
      ConflictRange = Range(ConflictRange.getOffset(),
                            std::max(ConflictRange.getLength(),
                                     Current.getOffset() + Current.getLength() -
                                         ConflictRange.getOffset()));
      ++ConflictLength;
    } else {
      // The run ended. Record it only if it contained a real conflict, then
      // start a new run at Current.
      if (ConflictLength > 1)
        Conflicts.push_back(Range(ConflictStart, ConflictLength));
      ConflictRange = Current;
      ConflictStart = i;
      ConflictLength = 1;
    }
  }

  // Flush the final run.
  if (ConflictLength > 1)
    Conflicts.push_back(Range(ConflictStart, ConflictLength));
}
270
271
272RefactoringTool::RefactoringTool(const CompilationDatabase &Compilations,
273                                 ArrayRef<std::string> SourcePaths)
274  : ClangTool(Compilations, SourcePaths) {}
275
276Replacements &RefactoringTool::getReplacements() { return Replace; }
277
278int RefactoringTool::runAndSave(FrontendActionFactory *ActionFactory) {
279  if (int Result = run(ActionFactory)) {
280    return Result;
281  }
282
283  LangOptions DefaultLangOptions;
284  IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts = new DiagnosticOptions();
285  TextDiagnosticPrinter DiagnosticPrinter(llvm::errs(), &*DiagOpts);
286  DiagnosticsEngine Diagnostics(
287      IntrusiveRefCntPtr<DiagnosticIDs>(new DiagnosticIDs()),
288      &*DiagOpts, &DiagnosticPrinter, false);
289  SourceManager Sources(Diagnostics, getFiles());
290  Rewriter Rewrite(Sources, DefaultLangOptions);
291
292  if (!applyAllReplacements(Rewrite)) {
293    llvm::errs() << "Skipped some replacements.\n";
294  }
295
296  return saveRewrittenFiles(Rewrite);
297}
298
299bool RefactoringTool::applyAllReplacements(Rewriter &Rewrite) {
300  return tooling::applyAllReplacements(Replace, Rewrite);
301}
302
303int RefactoringTool::saveRewrittenFiles(Rewriter &Rewrite) {
304  return Rewrite.overwriteChangedFiles() ? 1 : 0;
305}
306
307} // end namespace tooling
308} // end namespace clang
309