1//===-- clang-import-test.cpp - ASTImporter/ExternalASTSource testbed -----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "clang/AST/ASTContext.h"
10#include "clang/AST/ASTImporter.h"
11#include "clang/AST/DeclObjC.h"
12#include "clang/AST/ExternalASTMerger.h"
13#include "clang/Basic/Builtins.h"
14#include "clang/Basic/FileManager.h"
15#include "clang/Basic/IdentifierTable.h"
16#include "clang/Basic/SourceLocation.h"
17#include "clang/Basic/TargetInfo.h"
18#include "clang/Basic/TargetOptions.h"
19#include "clang/CodeGen/ModuleBuilder.h"
20#include "clang/Driver/Types.h"
21#include "clang/Frontend/ASTConsumers.h"
22#include "clang/Frontend/CompilerInstance.h"
23#include "clang/Frontend/MultiplexConsumer.h"
24#include "clang/Frontend/TextDiagnosticBuffer.h"
25#include "clang/Lex/Lexer.h"
26#include "clang/Lex/Preprocessor.h"
27#include "clang/Parse/ParseAST.h"
28
29#include "llvm/IR/LLVMContext.h"
30#include "llvm/IR/Module.h"
31#include "llvm/Support/CommandLine.h"
32#include "llvm/Support/Error.h"
33#include "llvm/Support/Host.h"
34#include "llvm/Support/Signals.h"
35
36#include <memory>
37#include <string>
38
39using namespace clang;
40
41static llvm::cl::opt<std::string> Expression(
42    "expression", llvm::cl::Required,
43    llvm::cl::desc("Path to a file containing the expression to parse"));
44
45static llvm::cl::list<std::string>
46    Imports("import",
47            llvm::cl::desc("Path to a file containing declarations to import"));
48
49static llvm::cl::opt<bool>
50    Direct("direct", llvm::cl::Optional,
51           llvm::cl::desc("Use the parsed declarations without indirection"));
52
53static llvm::cl::opt<bool> UseOrigins(
54    "use-origins", llvm::cl::Optional,
55    llvm::cl::desc(
56        "Use DeclContext origin information for more accurate lookups"));
57
58static llvm::cl::list<std::string>
59    ClangArgs("Xcc",
60              llvm::cl::desc("Argument to pass to the CompilerInvocation"),
61              llvm::cl::CommaSeparated);
62
63static llvm::cl::opt<std::string>
64    Input("x", llvm::cl::Optional,
65          llvm::cl::desc("The language to parse (default: c++)"),
66          llvm::cl::init("c++"));
67
68static llvm::cl::opt<bool> ObjCARC("objc-arc", llvm::cl::init(false),
69                                   llvm::cl::desc("Emable ObjC ARC"));
70
71static llvm::cl::opt<bool> DumpAST("dump-ast", llvm::cl::init(false),
72                                   llvm::cl::desc("Dump combined AST"));
73
74static llvm::cl::opt<bool> DumpIR("dump-ir", llvm::cl::init(false),
75                                  llvm::cl::desc("Dump IR from final parse"));
76
77namespace init_convenience {
78class TestDiagnosticConsumer : public DiagnosticConsumer {
79private:
80  std::unique_ptr<TextDiagnosticBuffer> Passthrough;
81  const LangOptions *LangOpts = nullptr;
82
83public:
84  TestDiagnosticConsumer()
85      : Passthrough(std::make_unique<TextDiagnosticBuffer>()) {}
86
87  void BeginSourceFile(const LangOptions &LangOpts,
88                       const Preprocessor *PP = nullptr) override {
89    this->LangOpts = &LangOpts;
90    return Passthrough->BeginSourceFile(LangOpts, PP);
91  }
92
93  void EndSourceFile() override {
94    this->LangOpts = nullptr;
95    Passthrough->EndSourceFile();
96  }
97
98  bool IncludeInDiagnosticCounts() const override {
99    return Passthrough->IncludeInDiagnosticCounts();
100  }
101
102private:
103  static void PrintSourceForLocation(const SourceLocation &Loc,
104                                     SourceManager &SM) {
105    const char *LocData = SM.getCharacterData(Loc, /*Invalid=*/nullptr);
106    unsigned LocColumn =
107        SM.getSpellingColumnNumber(Loc, /*Invalid=*/nullptr) - 1;
108    FileID FID = SM.getFileID(Loc);
109    llvm::MemoryBufferRef Buffer = SM.getBufferOrFake(FID, Loc);
110
111    assert(LocData >= Buffer.getBufferStart() &&
112           LocData < Buffer.getBufferEnd());
113
114    const char *LineBegin = LocData - LocColumn;
115
116    assert(LineBegin >= Buffer.getBufferStart());
117
118    const char *LineEnd = nullptr;
119
120    for (LineEnd = LineBegin; *LineEnd != '\n' && *LineEnd != '\r' &&
121                              LineEnd < Buffer.getBufferEnd();
122         ++LineEnd)
123      ;
124
125    llvm::StringRef LineString(LineBegin, LineEnd - LineBegin);
126
127    llvm::errs() << LineString << '\n';
128    llvm::errs().indent(LocColumn);
129    llvm::errs() << '^';
130    llvm::errs() << '\n';
131  }
132
133  void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel,
134                        const Diagnostic &Info) override {
135    if (Info.hasSourceManager() && LangOpts) {
136      SourceManager &SM = Info.getSourceManager();
137
138      if (Info.getLocation().isValid()) {
139        Info.getLocation().print(llvm::errs(), SM);
140        llvm::errs() << ": ";
141      }
142
143      SmallString<16> DiagText;
144      Info.FormatDiagnostic(DiagText);
145      llvm::errs() << DiagText << '\n';
146
147      if (Info.getLocation().isValid()) {
148        PrintSourceForLocation(Info.getLocation(), SM);
149      }
150
151      for (const CharSourceRange &Range : Info.getRanges()) {
152        bool Invalid = true;
153        StringRef Ref = Lexer::getSourceText(Range, SM, *LangOpts, &Invalid);
154        if (!Invalid) {
155          llvm::errs() << Ref << '\n';
156        }
157      }
158    }
159    DiagnosticConsumer::HandleDiagnostic(DiagLevel, Info);
160  }
161};
162
163std::unique_ptr<CompilerInstance> BuildCompilerInstance() {
164  auto Ins = std::make_unique<CompilerInstance>();
165  auto DC = std::make_unique<TestDiagnosticConsumer>();
166  const bool ShouldOwnClient = true;
167  Ins->createDiagnostics(DC.release(), ShouldOwnClient);
168
169  auto Inv = std::make_unique<CompilerInvocation>();
170
171  std::vector<const char *> ClangArgv(ClangArgs.size());
172  std::transform(ClangArgs.begin(), ClangArgs.end(), ClangArgv.begin(),
173                 [](const std::string &s) -> const char * { return s.data(); });
174  CompilerInvocation::CreateFromArgs(*Inv, ClangArgv, Ins->getDiagnostics());
175
176  {
177    using namespace driver::types;
178    ID Id = lookupTypeForTypeSpecifier(Input.c_str());
179    assert(Id != TY_INVALID);
180    if (isCXX(Id)) {
181      Inv->getLangOpts()->CPlusPlus = true;
182      Inv->getLangOpts()->CPlusPlus11 = true;
183      Inv->getHeaderSearchOpts().UseLibcxx = true;
184    }
185    if (isObjC(Id)) {
186      Inv->getLangOpts()->ObjC = 1;
187    }
188  }
189  Inv->getLangOpts()->ObjCAutoRefCount = ObjCARC;
190
191  Inv->getLangOpts()->Bool = true;
192  Inv->getLangOpts()->WChar = true;
193  Inv->getLangOpts()->Blocks = true;
194  Inv->getLangOpts()->DebuggerSupport = true;
195  Inv->getLangOpts()->SpellChecking = false;
196  Inv->getLangOpts()->ThreadsafeStatics = false;
197  Inv->getLangOpts()->AccessControl = false;
198  Inv->getLangOpts()->DollarIdents = true;
199  Inv->getLangOpts()->Exceptions = true;
200  Inv->getLangOpts()->CXXExceptions = true;
201  // Needed for testing dynamic_cast.
202  Inv->getLangOpts()->RTTI = true;
203  Inv->getCodeGenOpts().setDebugInfo(codegenoptions::FullDebugInfo);
204  Inv->getTargetOpts().Triple = llvm::sys::getDefaultTargetTriple();
205
206  Ins->setInvocation(std::move(Inv));
207
208  TargetInfo *TI = TargetInfo::CreateTargetInfo(
209      Ins->getDiagnostics(), Ins->getInvocation().TargetOpts);
210  Ins->setTarget(TI);
211  Ins->getTarget().adjust(Ins->getDiagnostics(), Ins->getLangOpts());
212  Ins->createFileManager();
213  Ins->createSourceManager(Ins->getFileManager());
214  Ins->createPreprocessor(TU_Complete);
215
216  return Ins;
217}
218
219std::unique_ptr<ASTContext>
220BuildASTContext(CompilerInstance &CI, SelectorTable &ST, Builtin::Context &BC) {
221  auto &PP = CI.getPreprocessor();
222  auto AST = std::make_unique<ASTContext>(
223      CI.getLangOpts(), CI.getSourceManager(),
224      PP.getIdentifierTable(), ST, BC, PP.TUKind);
225  AST->InitBuiltinTypes(CI.getTarget());
226  return AST;
227}
228
229std::unique_ptr<CodeGenerator> BuildCodeGen(CompilerInstance &CI,
230                                            llvm::LLVMContext &LLVMCtx) {
231  StringRef ModuleName("$__module");
232  return std::unique_ptr<CodeGenerator>(CreateLLVMCodeGen(
233      CI.getDiagnostics(), ModuleName, &CI.getVirtualFileSystem(),
234      CI.getHeaderSearchOpts(), CI.getPreprocessorOpts(), CI.getCodeGenOpts(),
235      LLVMCtx));
236}
237} // namespace init_convenience
238
239namespace {
240
241/// A container for a CompilerInstance (possibly with an ExternalASTMerger
242/// attached to its ASTContext).
243///
244/// Provides an accessor for the DeclContext origins associated with the
245/// ExternalASTMerger (or an empty list of origins if no ExternalASTMerger is
246/// attached).
247///
248/// This is the main unit of parsed source code maintained by clang-import-test.
249struct CIAndOrigins {
250  using OriginMap = clang::ExternalASTMerger::OriginMap;
251  std::unique_ptr<CompilerInstance> CI;
252
253  ASTContext &getASTContext() { return CI->getASTContext(); }
254  FileManager &getFileManager() { return CI->getFileManager(); }
255  const OriginMap &getOriginMap() {
256    static const OriginMap EmptyOriginMap{};
257    if (ExternalASTSource *Source = CI->getASTContext().getExternalSource())
258      return static_cast<ExternalASTMerger *>(Source)->GetOrigins();
259    return EmptyOriginMap;
260  }
261  DiagnosticConsumer &getDiagnosticClient() {
262    return CI->getDiagnosticClient();
263  }
264  CompilerInstance &getCompilerInstance() { return *CI; }
265};
266
267void AddExternalSource(CIAndOrigins &CI,
268                       llvm::MutableArrayRef<CIAndOrigins> Imports) {
269  ExternalASTMerger::ImporterTarget Target(
270      {CI.getASTContext(), CI.getFileManager()});
271  llvm::SmallVector<ExternalASTMerger::ImporterSource, 3> Sources;
272  for (CIAndOrigins &Import : Imports)
273    Sources.emplace_back(Import.getASTContext(), Import.getFileManager(),
274                         Import.getOriginMap());
275  auto ES = std::make_unique<ExternalASTMerger>(Target, Sources);
276  CI.getASTContext().setExternalSource(ES.release());
277  CI.getASTContext().getTranslationUnitDecl()->setHasExternalVisibleStorage();
278}
279
280CIAndOrigins BuildIndirect(CIAndOrigins &CI) {
281  CIAndOrigins IndirectCI{init_convenience::BuildCompilerInstance()};
282  auto ST = std::make_unique<SelectorTable>();
283  auto BC = std::make_unique<Builtin::Context>();
284  std::unique_ptr<ASTContext> AST = init_convenience::BuildASTContext(
285      IndirectCI.getCompilerInstance(), *ST, *BC);
286  IndirectCI.getCompilerInstance().setASTContext(AST.release());
287  AddExternalSource(IndirectCI, CI);
288  return IndirectCI;
289}
290
291llvm::Error ParseSource(const std::string &Path, CompilerInstance &CI,
292                        ASTConsumer &Consumer) {
293  SourceManager &SM = CI.getSourceManager();
294  auto FE = CI.getFileManager().getFileRef(Path);
295  if (!FE) {
296    llvm::consumeError(FE.takeError());
297    return llvm::make_error<llvm::StringError>(
298        llvm::Twine("No such file or directory: ", Path), std::error_code());
299  }
300  SM.setMainFileID(SM.createFileID(*FE, SourceLocation(), SrcMgr::C_User));
301  ParseAST(CI.getPreprocessor(), &Consumer, CI.getASTContext());
302  return llvm::Error::success();
303}
304
305llvm::Expected<CIAndOrigins> Parse(const std::string &Path,
306                                   llvm::MutableArrayRef<CIAndOrigins> Imports,
307                                   bool ShouldDumpAST, bool ShouldDumpIR) {
308  CIAndOrigins CI{init_convenience::BuildCompilerInstance()};
309  auto ST = std::make_unique<SelectorTable>();
310  auto BC = std::make_unique<Builtin::Context>();
311  std::unique_ptr<ASTContext> AST =
312      init_convenience::BuildASTContext(CI.getCompilerInstance(), *ST, *BC);
313  CI.getCompilerInstance().setASTContext(AST.release());
314  if (Imports.size())
315    AddExternalSource(CI, Imports);
316
317  std::vector<std::unique_ptr<ASTConsumer>> ASTConsumers;
318
319  auto LLVMCtx = std::make_unique<llvm::LLVMContext>();
320  ASTConsumers.push_back(
321      init_convenience::BuildCodeGen(CI.getCompilerInstance(), *LLVMCtx));
322  auto &CG = *static_cast<CodeGenerator *>(ASTConsumers.back().get());
323
324  if (ShouldDumpAST)
325    ASTConsumers.push_back(CreateASTDumper(nullptr /*Dump to stdout.*/, "",
326                                           true, false, false, false,
327                                           clang::ADOF_Default));
328
329  CI.getDiagnosticClient().BeginSourceFile(
330      CI.getCompilerInstance().getLangOpts(),
331      &CI.getCompilerInstance().getPreprocessor());
332  MultiplexConsumer Consumers(std::move(ASTConsumers));
333  Consumers.Initialize(CI.getASTContext());
334
335  if (llvm::Error PE = ParseSource(Path, CI.getCompilerInstance(), Consumers))
336    return std::move(PE);
337  CI.getDiagnosticClient().EndSourceFile();
338  if (ShouldDumpIR)
339    CG.GetModule()->print(llvm::outs(), nullptr);
340  if (CI.getDiagnosticClient().getNumErrors())
341    return llvm::make_error<llvm::StringError>(
342        "Errors occurred while parsing the expression.", std::error_code());
343  return std::move(CI);
344}
345
346void Forget(CIAndOrigins &CI, llvm::MutableArrayRef<CIAndOrigins> Imports) {
347  llvm::SmallVector<ExternalASTMerger::ImporterSource, 3> Sources;
348  for (CIAndOrigins &Import : Imports)
349    Sources.push_back({Import.getASTContext(), Import.getFileManager(),
350                       Import.getOriginMap()});
351  ExternalASTSource *Source = CI.CI->getASTContext().getExternalSource();
352  auto *Merger = static_cast<ExternalASTMerger *>(Source);
353  Merger->RemoveSources(Sources);
354}
355
356} // end namespace
357
358int main(int argc, const char **argv) {
359  const bool DisableCrashReporting = true;
360  llvm::sys::PrintStackTraceOnErrorSignal(argv[0], DisableCrashReporting);
361  llvm::cl::ParseCommandLineOptions(argc, argv);
362  std::vector<CIAndOrigins> ImportCIs;
363  for (auto I : Imports) {
364    llvm::Expected<CIAndOrigins> ImportCI = Parse(I, {}, false, false);
365    if (auto E = ImportCI.takeError()) {
366      llvm::errs() << "error: " << llvm::toString(std::move(E)) << "\n";
367      exit(-1);
368    }
369    ImportCIs.push_back(std::move(*ImportCI));
370  }
371  std::vector<CIAndOrigins> IndirectCIs;
372  if (!Direct || UseOrigins) {
373    for (auto &ImportCI : ImportCIs) {
374      CIAndOrigins IndirectCI = BuildIndirect(ImportCI);
375      IndirectCIs.push_back(std::move(IndirectCI));
376    }
377  }
378  if (UseOrigins)
379    for (auto &ImportCI : ImportCIs)
380      IndirectCIs.push_back(std::move(ImportCI));
381  llvm::Expected<CIAndOrigins> ExpressionCI =
382      Parse(Expression, (Direct && !UseOrigins) ? ImportCIs : IndirectCIs,
383            DumpAST, DumpIR);
384  if (auto E = ExpressionCI.takeError()) {
385    llvm::errs() << "error: " << llvm::toString(std::move(E)) << "\n";
386    exit(-1);
387  }
388  Forget(*ExpressionCI, (Direct && !UseOrigins) ? ImportCIs : IndirectCIs);
389  return 0;
390}
391