1//===- ClangDiff.cpp - compare source files by AST nodes ------*- C++ -*- -===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a tool for syntax tree based comparison using
10// Tooling/ASTDiff.
11//
12//===----------------------------------------------------------------------===//
13
14#include "clang/Tooling/ASTDiff/ASTDiff.h"
15#include "clang/Tooling/CommonOptionsParser.h"
16#include "clang/Tooling/Tooling.h"
17#include "llvm/Support/CommandLine.h"
18
19using namespace llvm;
20using namespace clang;
21using namespace clang::tooling;
22
23static cl::OptionCategory ClangDiffCategory("clang-diff options");
24
25static cl::opt<bool>
26    ASTDump("ast-dump",
27            cl::desc("Print the internal representation of the AST."),
28            cl::init(false), cl::cat(ClangDiffCategory));
29
30static cl::opt<bool> ASTDumpJson(
31    "ast-dump-json",
32    cl::desc("Print the internal representation of the AST as JSON."),
33    cl::init(false), cl::cat(ClangDiffCategory));
34
35static cl::opt<bool> PrintMatches("dump-matches",
36                                  cl::desc("Print the matched nodes."),
37                                  cl::init(false), cl::cat(ClangDiffCategory));
38
39static cl::opt<bool> HtmlDiff("html",
40                              cl::desc("Output a side-by-side diff in HTML."),
41                              cl::init(false), cl::cat(ClangDiffCategory));
42
43static cl::opt<std::string> SourcePath(cl::Positional, cl::desc("<source>"),
44                                       cl::Required,
45                                       cl::cat(ClangDiffCategory));
46
47static cl::opt<std::string> DestinationPath(cl::Positional,
48                                            cl::desc("<destination>"),
49                                            cl::Optional,
50                                            cl::cat(ClangDiffCategory));
51
52static cl::opt<std::string> StopAfter("stop-diff-after",
53                                      cl::desc("<topdown|bottomup>"),
54                                      cl::Optional, cl::init(""),
55                                      cl::cat(ClangDiffCategory));
56
57static cl::opt<int> MaxSize("s", cl::desc("<maxsize>"), cl::Optional,
58                            cl::init(-1), cl::cat(ClangDiffCategory));
59
60static cl::opt<std::string> BuildPath("p", cl::desc("Build path"), cl::init(""),
61                                      cl::Optional, cl::cat(ClangDiffCategory));
62
63static cl::list<std::string> ArgsAfter(
64    "extra-arg",
65    cl::desc("Additional argument to append to the compiler command line"),
66    cl::cat(ClangDiffCategory));
67
68static cl::list<std::string> ArgsBefore(
69    "extra-arg-before",
70    cl::desc("Additional argument to prepend to the compiler command line"),
71    cl::cat(ClangDiffCategory));
72
73static void addExtraArgs(std::unique_ptr<CompilationDatabase> &Compilations) {
74  if (!Compilations)
75    return;
76  auto AdjustingCompilations =
77      std::make_unique<ArgumentsAdjustingCompilations>(
78          std::move(Compilations));
79  AdjustingCompilations->appendArgumentsAdjuster(
80      getInsertArgumentAdjuster(ArgsBefore, ArgumentInsertPosition::BEGIN));
81  AdjustingCompilations->appendArgumentsAdjuster(
82      getInsertArgumentAdjuster(ArgsAfter, ArgumentInsertPosition::END));
83  Compilations = std::move(AdjustingCompilations);
84}
85
86static std::unique_ptr<ASTUnit>
87getAST(const std::unique_ptr<CompilationDatabase> &CommonCompilations,
88       const StringRef Filename) {
89  std::string ErrorMessage;
90  std::unique_ptr<CompilationDatabase> Compilations;
91  if (!CommonCompilations) {
92    Compilations = CompilationDatabase::autoDetectFromSource(
93        BuildPath.empty() ? Filename : BuildPath, ErrorMessage);
94    if (!Compilations) {
95      llvm::errs()
96          << "Error while trying to load a compilation database, running "
97             "without flags.\n"
98          << ErrorMessage;
99      Compilations =
100          std::make_unique<clang::tooling::FixedCompilationDatabase>(
101              ".", std::vector<std::string>());
102    }
103  }
104  addExtraArgs(Compilations);
105  std::array<std::string, 1> Files = {{std::string(Filename)}};
106  ClangTool Tool(Compilations ? *Compilations : *CommonCompilations, Files);
107  std::vector<std::unique_ptr<ASTUnit>> ASTs;
108  Tool.buildASTs(ASTs);
109  if (ASTs.size() != Files.size())
110    return nullptr;
111  return std::move(ASTs[0]);
112}
113
/// Returns the lowercase hexadecimal digit ('0'-'9', 'a'-'f') for the low four
/// bits of \p N; all higher bits are ignored.
static char hexdigit(int N) {
  const int Nibble = N & 0xf;
  if (Nibble < 10)
    return static_cast<char>('0' + Nibble);
  return static_cast<char>('a' + (Nibble - 10));
}
115
// Preamble of the -html output: CSS for the change classes emitted by
// printHtmlForNode ('d', 'u', 'i', 'm') and the JavaScript that highlights a
// clicked node together with its counterpart ('tid'), and steps through nodes
// with the J/K keys. main() appends the two code panes and closes the tags.
//
// doHighlight() guards against a missing element: on load without a URL hash
// (or when clicking an element without an id) getElementById() yields null.
static const char HtmlDiffHeader[] = R"(
<html>
<head>
<meta charset='utf-8'/>
<style>
span.d { color: red; }
span.u { color: #cc00cc; }
span.i { color: green; }
span.m { font-weight: bold; }
span   { font-weight: normal; color: black; }
div.code {
  width: 48%;
  height: 98%;
  overflow: scroll;
  float: left;
  padding: 0 0 0.5% 0.5%;
  border: solid 2px LightGrey;
  border-radius: 5px;
}
</style>
</head>
<script type='text/javascript'>
highlightStack = []
function clearHighlight() {
  while (highlightStack.length) {
    var [l, r] = highlightStack.pop()
    document.getElementById(l).style.backgroundColor = 'inherit'
    if (r[1] != '-')
      document.getElementById(r).style.backgroundColor = 'inherit'
  }
}
function highlight(event) {
  var id = event.target['id']
  doHighlight(id)
}
function doHighlight(id) {
  clearHighlight()
  var source = document.getElementById(id)
  if (!source || !source.attributes['tid'])
    return
  var mapped = source
  while (mapped && mapped.parentElement && mapped.attributes['tid'].value.substr(1) === '-1')
    mapped = mapped.parentElement
  var tid = null, target = null
  if (mapped) {
    tid = mapped.attributes['tid'].value
    target = document.getElementById(tid)
  }
  if (source.parentElement && source.parentElement.classList.contains('code'))
    return
  source.style.backgroundColor = 'lightgrey'
  source.scrollIntoView()
  if (target) {
    if (mapped === source)
      target.style.backgroundColor = 'lightgrey'
    target.scrollIntoView()
  }
  highlightStack.push([id, tid])
  location.hash = '#' + id
}
function scrollToBoth() {
  doHighlight(location.hash.substr(1))
}
function changed(elem) {
  return elem.classList.length == 0
}
function nextChangedNode(prefix, increment, number) {
  do {
    number += increment
    var elem = document.getElementById(prefix + number)
  } while(elem && !changed(elem))
  return elem ? number : null
}
function handleKey(e) {
  var down = e.code === "KeyJ"
  var up = e.code === "KeyK"
  if (!down && !up)
    return
  var id = highlightStack[0] ? highlightStack[0][0] : 'R0'
  var oldelem = document.getElementById(id)
  var number = parseInt(id.substr(1))
  var increment = down ? 1 : -1
  var lastnumber = number
  var prefix = id[0]
  do {
    number = nextChangedNode(prefix, increment, number)
    var elem = document.getElementById(prefix + number)
    if (up && elem) {
      while (elem.parentElement && changed(elem.parentElement))
        elem = elem.parentElement
      number = elem.id.substr(1)
    }
  } while ((down && id !== 'R0' && oldelem.contains(elem)))
  if (!number)
    number = lastnumber
  elem = document.getElementById(prefix + number)
  doHighlight(prefix + number)
}
window.onload = scrollToBoth
window.onkeydown = handleKey
</script>
<body>
<div onclick='highlight(event)'>
)";
220
221static void printHtml(raw_ostream &OS, char C) {
222  switch (C) {
223  case '&':
224    OS << "&amp;";
225    break;
226  case '<':
227    OS << "&lt;";
228    break;
229  case '>':
230    OS << "&gt;";
231    break;
232  case '\'':
233    OS << "&#x27;";
234    break;
235  case '"':
236    OS << "&quot;";
237    break;
238  default:
239    OS << C;
240  }
241}
242
243static void printHtml(raw_ostream &OS, const StringRef Str) {
244  for (char C : Str)
245    printHtml(OS, C);
246}
247
248static std::string getChangeKindAbbr(diff::ChangeKind Kind) {
249  switch (Kind) {
250  case diff::None:
251    return "";
252  case diff::Delete:
253    return "d";
254  case diff::Update:
255    return "u";
256  case diff::Insert:
257    return "i";
258  case diff::Move:
259    return "m";
260  case diff::UpdateMove:
261    return "u m";
262  }
263  llvm_unreachable("Invalid enumeration value.");
264}
265
266static unsigned printHtmlForNode(raw_ostream &OS, const diff::ASTDiff &Diff,
267                                 diff::SyntaxTree &Tree, bool IsLeft,
268                                 diff::NodeId Id, unsigned Offset) {
269  const diff::Node &Node = Tree.getNode(Id);
270  char MyTag, OtherTag;
271  diff::NodeId LeftId, RightId;
272  diff::NodeId TargetId = Diff.getMapped(Tree, Id);
273  if (IsLeft) {
274    MyTag = 'L';
275    OtherTag = 'R';
276    LeftId = Id;
277    RightId = TargetId;
278  } else {
279    MyTag = 'R';
280    OtherTag = 'L';
281    LeftId = TargetId;
282    RightId = Id;
283  }
284  unsigned Begin, End;
285  std::tie(Begin, End) = Tree.getSourceRangeOffsets(Node);
286  const SourceManager &SrcMgr = Tree.getASTContext().getSourceManager();
287  auto Code = SrcMgr.getBufferOrFake(SrcMgr.getMainFileID()).getBuffer();
288  for (; Offset < Begin; ++Offset)
289    printHtml(OS, Code[Offset]);
290  OS << "<span id='" << MyTag << Id << "' "
291     << "tid='" << OtherTag << TargetId << "' ";
292  OS << "title='";
293  printHtml(OS, Node.getTypeLabel());
294  OS << "\n" << LeftId << " -> " << RightId;
295  std::string Value = Tree.getNodeValue(Node);
296  if (!Value.empty()) {
297    OS << "\n";
298    printHtml(OS, Value);
299  }
300  OS << "'";
301  if (Node.Change != diff::None)
302    OS << " class='" << getChangeKindAbbr(Node.Change) << "'";
303  OS << ">";
304
305  for (diff::NodeId Child : Node.Children)
306    Offset = printHtmlForNode(OS, Diff, Tree, IsLeft, Child, Offset);
307
308  for (; Offset < End; ++Offset)
309    printHtml(OS, Code[Offset]);
310  if (Id == Tree.getRootId()) {
311    End = Code.size();
312    for (; Offset < End; ++Offset)
313      printHtml(OS, Code[Offset]);
314  }
315  OS << "</span>";
316  return Offset;
317}
318
319static void printJsonString(raw_ostream &OS, const StringRef Str) {
320  for (signed char C : Str) {
321    switch (C) {
322    case '"':
323      OS << R"(\")";
324      break;
325    case '\\':
326      OS << R"(\\)";
327      break;
328    case '\n':
329      OS << R"(\n)";
330      break;
331    case '\t':
332      OS << R"(\t)";
333      break;
334    default:
335      if ('\x00' <= C && C <= '\x1f') {
336        OS << R"(\u00)" << hexdigit(C >> 4) << hexdigit(C);
337      } else {
338        OS << C;
339      }
340    }
341  }
342}
343
344static void printNodeAttributes(raw_ostream &OS, diff::SyntaxTree &Tree,
345                                diff::NodeId Id) {
346  const diff::Node &N = Tree.getNode(Id);
347  OS << R"("id":)" << int(Id);
348  OS << R"(,"type":")" << N.getTypeLabel() << '"';
349  auto Offsets = Tree.getSourceRangeOffsets(N);
350  OS << R"(,"begin":)" << Offsets.first;
351  OS << R"(,"end":)" << Offsets.second;
352  std::string Value = Tree.getNodeValue(N);
353  if (!Value.empty()) {
354    OS << R"(,"value":")";
355    printJsonString(OS, Value);
356    OS << '"';
357  }
358}
359
360static void printNodeAsJson(raw_ostream &OS, diff::SyntaxTree &Tree,
361                            diff::NodeId Id) {
362  const diff::Node &N = Tree.getNode(Id);
363  OS << "{";
364  printNodeAttributes(OS, Tree, Id);
365  auto Identifier = N.getIdentifier();
366  auto QualifiedIdentifier = N.getQualifiedIdentifier();
367  if (Identifier) {
368    OS << R"(,"identifier":")";
369    printJsonString(OS, *Identifier);
370    OS << R"(")";
371    if (QualifiedIdentifier && *Identifier != *QualifiedIdentifier) {
372      OS << R"(,"qualified_identifier":")";
373      printJsonString(OS, *QualifiedIdentifier);
374      OS << R"(")";
375    }
376  }
377  OS << R"(,"children":[)";
378  if (N.Children.size() > 0) {
379    printNodeAsJson(OS, Tree, N.Children[0]);
380    for (size_t I = 1, E = N.Children.size(); I < E; ++I) {
381      OS << ",";
382      printNodeAsJson(OS, Tree, N.Children[I]);
383    }
384  }
385  OS << "]}";
386}
387
388static void printNode(raw_ostream &OS, diff::SyntaxTree &Tree,
389                      diff::NodeId Id) {
390  if (Id.isInvalid()) {
391    OS << "None";
392    return;
393  }
394  OS << Tree.getNode(Id).getTypeLabel();
395  std::string Value = Tree.getNodeValue(Id);
396  if (!Value.empty())
397    OS << ": " << Value;
398  OS << "(" << Id << ")";
399}
400
401static void printTree(raw_ostream &OS, diff::SyntaxTree &Tree) {
402  for (diff::NodeId Id : Tree) {
403    for (int I = 0; I < Tree.getNode(Id).Depth; ++I)
404      OS << " ";
405    printNode(OS, Tree, Id);
406    OS << "\n";
407  }
408}
409
410static void printDstChange(raw_ostream &OS, diff::ASTDiff &Diff,
411                           diff::SyntaxTree &SrcTree, diff::SyntaxTree &DstTree,
412                           diff::NodeId Dst) {
413  const diff::Node &DstNode = DstTree.getNode(Dst);
414  diff::NodeId Src = Diff.getMapped(DstTree, Dst);
415  switch (DstNode.Change) {
416  case diff::None:
417    break;
418  case diff::Delete:
419    llvm_unreachable("The destination tree can't have deletions.");
420  case diff::Update:
421    OS << "Update ";
422    printNode(OS, SrcTree, Src);
423    OS << " to " << DstTree.getNodeValue(Dst) << "\n";
424    break;
425  case diff::Insert:
426  case diff::Move:
427  case diff::UpdateMove:
428    if (DstNode.Change == diff::Insert)
429      OS << "Insert";
430    else if (DstNode.Change == diff::Move)
431      OS << "Move";
432    else if (DstNode.Change == diff::UpdateMove)
433      OS << "Update and Move";
434    OS << " ";
435    printNode(OS, DstTree, Dst);
436    OS << " into ";
437    printNode(OS, DstTree, DstNode.Parent);
438    OS << " at " << DstTree.findPositionInParent(Dst) << "\n";
439    break;
440  }
441}
442
443int main(int argc, const char **argv) {
444  std::string ErrorMessage;
445  std::unique_ptr<CompilationDatabase> CommonCompilations =
446      FixedCompilationDatabase::loadFromCommandLine(argc, argv, ErrorMessage);
447  if (!CommonCompilations && !ErrorMessage.empty())
448    llvm::errs() << ErrorMessage;
449  cl::HideUnrelatedOptions(ClangDiffCategory);
450  if (!cl::ParseCommandLineOptions(argc, argv)) {
451    cl::PrintOptionValues();
452    return 1;
453  }
454
455  addExtraArgs(CommonCompilations);
456
457  if (ASTDump || ASTDumpJson) {
458    if (!DestinationPath.empty()) {
459      llvm::errs() << "Error: Please specify exactly one filename.\n";
460      return 1;
461    }
462    std::unique_ptr<ASTUnit> AST = getAST(CommonCompilations, SourcePath);
463    if (!AST)
464      return 1;
465    diff::SyntaxTree Tree(AST->getASTContext());
466    if (ASTDump) {
467      printTree(llvm::outs(), Tree);
468      return 0;
469    }
470    llvm::outs() << R"({"filename":")";
471    printJsonString(llvm::outs(), SourcePath);
472    llvm::outs() << R"(","root":)";
473    printNodeAsJson(llvm::outs(), Tree, Tree.getRootId());
474    llvm::outs() << "}\n";
475    return 0;
476  }
477
478  if (DestinationPath.empty()) {
479    llvm::errs() << "Error: Exactly two paths are required.\n";
480    return 1;
481  }
482
483  std::unique_ptr<ASTUnit> Src = getAST(CommonCompilations, SourcePath);
484  std::unique_ptr<ASTUnit> Dst = getAST(CommonCompilations, DestinationPath);
485  if (!Src || !Dst)
486    return 1;
487
488  diff::ComparisonOptions Options;
489  if (MaxSize != -1)
490    Options.MaxSize = MaxSize;
491  if (!StopAfter.empty()) {
492    if (StopAfter == "topdown")
493      Options.StopAfterTopDown = true;
494    else if (StopAfter != "bottomup") {
495      llvm::errs() << "Error: Invalid argument for -stop-after\n";
496      return 1;
497    }
498  }
499  diff::SyntaxTree SrcTree(Src->getASTContext());
500  diff::SyntaxTree DstTree(Dst->getASTContext());
501  diff::ASTDiff Diff(SrcTree, DstTree, Options);
502
503  if (HtmlDiff) {
504    llvm::outs() << HtmlDiffHeader << "<pre>";
505    llvm::outs() << "<div id='L' class='code'>";
506    printHtmlForNode(llvm::outs(), Diff, SrcTree, true, SrcTree.getRootId(), 0);
507    llvm::outs() << "</div>";
508    llvm::outs() << "<div id='R' class='code'>";
509    printHtmlForNode(llvm::outs(), Diff, DstTree, false, DstTree.getRootId(),
510                     0);
511    llvm::outs() << "</div>";
512    llvm::outs() << "</pre></div></body></html>\n";
513    return 0;
514  }
515
516  for (diff::NodeId Dst : DstTree) {
517    diff::NodeId Src = Diff.getMapped(DstTree, Dst);
518    if (PrintMatches && Src.isValid()) {
519      llvm::outs() << "Match ";
520      printNode(llvm::outs(), SrcTree, Src);
521      llvm::outs() << " to ";
522      printNode(llvm::outs(), DstTree, Dst);
523      llvm::outs() << "\n";
524    }
525    printDstChange(llvm::outs(), Diff, SrcTree, DstTree, Dst);
526  }
527  for (diff::NodeId Src : SrcTree) {
528    if (Diff.getMapped(SrcTree, Src).isInvalid()) {
529      llvm::outs() << "Delete ";
530      printNode(llvm::outs(), SrcTree, Src);
531      llvm::outs() << "\n";
532    }
533  }
534
535  return 0;
536}
537