1//===- ASTSrcLocProcessor.cpp --------------------------------*- C++ -*----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "ASTSrcLocProcessor.h"
10
11#include "clang/Frontend/CompilerInstance.h"
12#include "llvm/Support/JSON.h"
13#include "llvm/Support/MemoryBuffer.h"
14
15using namespace clang::tooling;
16using namespace llvm;
17using namespace clang::ast_matchers;
18
19ASTSrcLocProcessor::ASTSrcLocProcessor(StringRef JsonPath)
20    : JsonPath(JsonPath) {
21
22  MatchFinder::MatchFinderOptions FinderOptions;
23
24  Finder = std::make_unique<MatchFinder>(std::move(FinderOptions));
25  Finder->addMatcher(
26      cxxRecordDecl(
27          isDefinition(),
28          isSameOrDerivedFrom(
29              namedDecl(
30                  hasAnyName(
31                      "clang::Stmt", "clang::Decl", "clang::CXXCtorInitializer",
32                      "clang::NestedNameSpecifierLoc",
33                      "clang::TemplateArgumentLoc", "clang::CXXBaseSpecifier",
34                      "clang::DeclarationNameInfo", "clang::TypeLoc"))
35                  .bind("nodeClade")),
36          optionally(isDerivedFrom(cxxRecordDecl().bind("derivedFrom"))))
37          .bind("className"),
38      this);
39  Finder->addMatcher(
40          cxxRecordDecl(isDefinition(), hasAnyName("clang::PointerLikeTypeLoc",
41                                                   "clang::TypeofLikeTypeLoc"))
42              .bind("templateName"),
43      this);
44}
45
46std::unique_ptr<clang::ASTConsumer>
47ASTSrcLocProcessor::createASTConsumer(clang::CompilerInstance &Compiler,
48                                      StringRef File) {
49  return Finder->newASTConsumer();
50}
51
52llvm::json::Object toJSON(llvm::StringMap<std::vector<StringRef>> const &Obj) {
53  using llvm::json::toJSON;
54
55  llvm::json::Object JsonObj;
56  for (const auto &Item : Obj) {
57    JsonObj[Item.first()] = Item.second;
58  }
59  return JsonObj;
60}
61
62llvm::json::Object toJSON(llvm::StringMap<std::string> const &Obj) {
63  using llvm::json::toJSON;
64
65  llvm::json::Object JsonObj;
66  for (const auto &Item : Obj) {
67    JsonObj[Item.first()] = Item.second;
68  }
69  return JsonObj;
70}
71
72llvm::json::Object toJSON(ClassData const &Obj) {
73  llvm::json::Object JsonObj;
74
75  if (!Obj.ASTClassLocations.empty())
76    JsonObj["sourceLocations"] = Obj.ASTClassLocations;
77  if (!Obj.ASTClassRanges.empty())
78    JsonObj["sourceRanges"] = Obj.ASTClassRanges;
79  if (!Obj.TemplateParms.empty())
80    JsonObj["templateParms"] = Obj.TemplateParms;
81  if (!Obj.TypeSourceInfos.empty())
82    JsonObj["typeSourceInfos"] = Obj.TypeSourceInfos;
83  if (!Obj.TypeLocs.empty())
84    JsonObj["typeLocs"] = Obj.TypeLocs;
85  if (!Obj.NestedNameLocs.empty())
86    JsonObj["nestedNameLocs"] = Obj.NestedNameLocs;
87  if (!Obj.DeclNameInfos.empty())
88    JsonObj["declNameInfos"] = Obj.DeclNameInfos;
89  return JsonObj;
90}
91
92llvm::json::Object toJSON(llvm::StringMap<ClassData> const &Obj) {
93  using llvm::json::toJSON;
94
95  llvm::json::Object JsonObj;
96  for (const auto &Item : Obj)
97    JsonObj[Item.first()] = ::toJSON(Item.second);
98  return JsonObj;
99}
100
101void WriteJSON(StringRef JsonPath, llvm::json::Object &&ClassInheritance,
102               llvm::json::Object &&ClassesInClade,
103               llvm::json::Object &&ClassEntries) {
104  llvm::json::Object JsonObj;
105
106  using llvm::json::toJSON;
107
108  JsonObj["classInheritance"] = std::move(ClassInheritance);
109  JsonObj["classesInClade"] = std::move(ClassesInClade);
110  JsonObj["classEntries"] = std::move(ClassEntries);
111
112  llvm::json::Value JsonVal(std::move(JsonObj));
113
114  bool WriteChange = false;
115  std::string OutString;
116  if (auto ExistingOrErr = MemoryBuffer::getFile(JsonPath, /*IsText=*/true)) {
117    raw_string_ostream Out(OutString);
118    Out << formatv("{0:2}", JsonVal);
119    if (ExistingOrErr.get()->getBuffer() == Out.str())
120      return;
121    WriteChange = true;
122  }
123
124  std::error_code EC;
125  llvm::raw_fd_ostream JsonOut(JsonPath, EC, llvm::sys::fs::OF_Text);
126  if (EC)
127    return;
128
129  if (WriteChange)
130    JsonOut << OutString;
131  else
132    JsonOut << formatv("{0:2}", JsonVal);
133}
134
135void ASTSrcLocProcessor::generate() {
136  WriteJSON(JsonPath, ::toJSON(ClassInheritance), ::toJSON(ClassesInClade),
137            ::toJSON(ClassEntries));
138}
139
140void ASTSrcLocProcessor::generateEmpty() { WriteJSON(JsonPath, {}, {}, {}); }
141
142std::vector<std::string>
143CaptureMethods(std::string TypeString, const clang::CXXRecordDecl *ASTClass,
144               const MatchFinder::MatchResult &Result) {
145
146  auto publicAccessor = [](auto... InnerMatcher) {
147    return cxxMethodDecl(isPublic(), parameterCountIs(0), isConst(),
148                         InnerMatcher...);
149  };
150
151  auto BoundNodesVec = match(
152      findAll(
153          publicAccessor(
154              ofClass(cxxRecordDecl(
155                  equalsNode(ASTClass),
156                  optionally(isDerivedFrom(
157                      cxxRecordDecl(hasAnyName("clang::Stmt", "clang::Decl"))
158                          .bind("stmtOrDeclBase"))),
159                  optionally(isDerivedFrom(
160                      cxxRecordDecl(hasName("clang::Expr")).bind("exprBase"))),
161                  optionally(
162                      isDerivedFrom(cxxRecordDecl(hasName("clang::TypeLoc"))
163                                        .bind("typeLocBase"))))),
164              returns(asString(TypeString)))
165              .bind("classMethod")),
166      *ASTClass, *Result.Context);
167
168  std::vector<std::string> Methods;
169  for (const auto &BN : BoundNodesVec) {
170    if (const auto *Node = BN.getNodeAs<clang::NamedDecl>("classMethod")) {
171      const auto *StmtOrDeclBase =
172          BN.getNodeAs<clang::CXXRecordDecl>("stmtOrDeclBase");
173      const auto *TypeLocBase =
174          BN.getNodeAs<clang::CXXRecordDecl>("typeLocBase");
175      const auto *ExprBase = BN.getNodeAs<clang::CXXRecordDecl>("exprBase");
176      // The clang AST has several methods on base classes which are overriden
177      // pseudo-virtually by derived classes.
178      // We record only the pseudo-virtual methods on the base classes to
179      // avoid duplication.
180      if (StmtOrDeclBase &&
181          (Node->getName() == "getBeginLoc" || Node->getName() == "getEndLoc" ||
182           Node->getName() == "getSourceRange"))
183        continue;
184      if (ExprBase && Node->getName() == "getExprLoc")
185        continue;
186      if (TypeLocBase && Node->getName() == "getLocalSourceRange")
187        continue;
188      if ((ASTClass->getName() == "PointerLikeTypeLoc" ||
189           ASTClass->getName() == "TypeofLikeTypeLoc") &&
190          Node->getName() == "getLocalSourceRange")
191        continue;
192      Methods.push_back(Node->getName().str());
193    }
194  }
195  return Methods;
196}
197
198void ASTSrcLocProcessor::run(const MatchFinder::MatchResult &Result) {
199
200  const auto *ASTClass =
201      Result.Nodes.getNodeAs<clang::CXXRecordDecl>("className");
202
203  StringRef CladeName;
204  if (ASTClass) {
205    if (const auto *NodeClade =
206            Result.Nodes.getNodeAs<clang::CXXRecordDecl>("nodeClade"))
207      CladeName = NodeClade->getName();
208  } else {
209    ASTClass = Result.Nodes.getNodeAs<clang::CXXRecordDecl>("templateName");
210    CladeName = "TypeLoc";
211  }
212
213  StringRef ClassName = ASTClass->getName();
214
215  ClassData CD;
216
217  CD.ASTClassLocations =
218      CaptureMethods("class clang::SourceLocation", ASTClass, Result);
219  CD.ASTClassRanges =
220      CaptureMethods("class clang::SourceRange", ASTClass, Result);
221  CD.TypeSourceInfos =
222      CaptureMethods("class clang::TypeSourceInfo *", ASTClass, Result);
223  CD.TypeLocs = CaptureMethods("class clang::TypeLoc", ASTClass, Result);
224  CD.NestedNameLocs =
225      CaptureMethods("class clang::NestedNameSpecifierLoc", ASTClass, Result);
226  CD.DeclNameInfos =
227      CaptureMethods("struct clang::DeclarationNameInfo", ASTClass, Result);
228  auto DI = CaptureMethods("const struct clang::DeclarationNameInfo &",
229                           ASTClass, Result);
230  CD.DeclNameInfos.insert(CD.DeclNameInfos.end(), DI.begin(), DI.end());
231
232  if (const auto *DerivedFrom =
233          Result.Nodes.getNodeAs<clang::CXXRecordDecl>("derivedFrom")) {
234
235    if (const auto *Templ =
236            llvm::dyn_cast<clang::ClassTemplateSpecializationDecl>(
237                DerivedFrom)) {
238
239      const auto &TArgs = Templ->getTemplateArgs();
240
241      SmallString<256> TArgsString;
242      llvm::raw_svector_ostream OS(TArgsString);
243      OS << DerivedFrom->getName() << '<';
244
245      clang::PrintingPolicy PPol(Result.Context->getLangOpts());
246      PPol.TerseOutput = true;
247
248      for (unsigned I = 0; I < TArgs.size(); ++I) {
249        if (I > 0)
250          OS << ", ";
251        TArgs.get(I).getAsType().print(OS, PPol);
252      }
253      OS << '>';
254
255      ClassInheritance[ClassName] = TArgsString.str().str();
256    } else {
257      ClassInheritance[ClassName] = DerivedFrom->getName().str();
258    }
259  }
260
261  if (const auto *Templ = ASTClass->getDescribedClassTemplate()) {
262    if (auto *TParams = Templ->getTemplateParameters()) {
263      for (const auto &TParam : *TParams) {
264        CD.TemplateParms.push_back(TParam->getName().str());
265      }
266    }
267  }
268
269  ClassEntries[ClassName] = CD;
270  ClassesInClade[CladeName].push_back(ClassName);
271}
272