1//===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This utility changes the input module to only contain a single function,
10// which is primarily used for debugging transformations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/ADT/SetVector.h"
15#include "llvm/ADT/SmallPtrSet.h"
16#include "llvm/Bitcode/BitcodeWriterPass.h"
17#include "llvm/IR/DataLayout.h"
18#include "llvm/IR/IRPrintingPasses.h"
19#include "llvm/IR/Instructions.h"
20#include "llvm/IR/LLVMContext.h"
21#include "llvm/IR/Module.h"
22#include "llvm/IRPrinter/IRPrintingPasses.h"
23#include "llvm/IRReader/IRReader.h"
24#include "llvm/Passes/PassBuilder.h"
25#include "llvm/Support/CommandLine.h"
26#include "llvm/Support/Error.h"
27#include "llvm/Support/FileSystem.h"
28#include "llvm/Support/InitLLVM.h"
29#include "llvm/Support/Regex.h"
30#include "llvm/Support/SourceMgr.h"
31#include "llvm/Support/SystemUtils.h"
32#include "llvm/Support/ToolOutputFile.h"
33#include "llvm/Transforms/IPO.h"
34#include "llvm/Transforms/IPO/BlockExtractor.h"
35#include "llvm/Transforms/IPO/ExtractGV.h"
36#include "llvm/Transforms/IPO/GlobalDCE.h"
37#include "llvm/Transforms/IPO/StripDeadPrototypes.h"
38#include "llvm/Transforms/IPO/StripSymbols.h"
39#include <memory>
40#include <utility>
41
42using namespace llvm;
43
44cl::OptionCategory ExtractCat("llvm-extract Options");
45
46// InputFilename - The filename to read from.
47static cl::opt<std::string> InputFilename(cl::Positional,
48                                          cl::desc("<input bitcode file>"),
49                                          cl::init("-"),
50                                          cl::value_desc("filename"));
51
52static cl::opt<std::string> OutputFilename("o",
53                                           cl::desc("Specify output filename"),
54                                           cl::value_desc("filename"),
55                                           cl::init("-"), cl::cat(ExtractCat));
56
57static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"),
58                           cl::cat(ExtractCat));
59
60static cl::opt<bool> DeleteFn("delete",
61                              cl::desc("Delete specified Globals from Module"),
62                              cl::cat(ExtractCat));
63
64static cl::opt<bool> KeepConstInit("keep-const-init",
65                              cl::desc("Keep initializers of constants"),
66                              cl::cat(ExtractCat));
67
68static cl::opt<bool>
69    Recursive("recursive", cl::desc("Recursively extract all called functions"),
70              cl::cat(ExtractCat));
71
72// ExtractFuncs - The functions to extract from the module.
73static cl::list<std::string>
74    ExtractFuncs("func", cl::desc("Specify function to extract"),
75                 cl::value_desc("function"), cl::cat(ExtractCat));
76
77// ExtractRegExpFuncs - The functions, matched via regular expression, to
78// extract from the module.
79static cl::list<std::string>
80    ExtractRegExpFuncs("rfunc",
81                       cl::desc("Specify function(s) to extract using a "
82                                "regular expression"),
83                       cl::value_desc("rfunction"), cl::cat(ExtractCat));
84
85// ExtractBlocks - The blocks to extract from the module.
86static cl::list<std::string> ExtractBlocks(
87    "bb",
88    cl::desc(
89        "Specify <function, basic block1[;basic block2...]> pairs to extract.\n"
90        "Each pair will create a function.\n"
91        "If multiple basic blocks are specified in one pair,\n"
92        "the first block in the sequence should dominate the rest.\n"
93        "eg:\n"
94        "  --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n"
95        "  --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one "
96        "with bb2."),
97    cl::value_desc("function:bb1[;bb2...]"), cl::cat(ExtractCat));
98
99// ExtractAlias - The alias to extract from the module.
100static cl::list<std::string>
101    ExtractAliases("alias", cl::desc("Specify alias to extract"),
102                   cl::value_desc("alias"), cl::cat(ExtractCat));
103
104// ExtractRegExpAliases - The aliases, matched via regular expression, to
105// extract from the module.
106static cl::list<std::string>
107    ExtractRegExpAliases("ralias",
108                         cl::desc("Specify alias(es) to extract using a "
109                                  "regular expression"),
110                         cl::value_desc("ralias"), cl::cat(ExtractCat));
111
112// ExtractGlobals - The globals to extract from the module.
113static cl::list<std::string>
114    ExtractGlobals("glob", cl::desc("Specify global to extract"),
115                   cl::value_desc("global"), cl::cat(ExtractCat));
116
117// ExtractRegExpGlobals - The globals, matched via regular expression, to
118// extract from the module...
119static cl::list<std::string>
120    ExtractRegExpGlobals("rglob",
121                         cl::desc("Specify global(s) to extract using a "
122                                  "regular expression"),
123                         cl::value_desc("rglobal"), cl::cat(ExtractCat));
124
125static cl::opt<bool> OutputAssembly("S",
126                                    cl::desc("Write output as LLVM assembly"),
127                                    cl::Hidden, cl::cat(ExtractCat));
128
129static cl::opt<bool> PreserveBitcodeUseListOrder(
130    "preserve-bc-uselistorder",
131    cl::desc("Preserve use-list order when writing LLVM bitcode."),
132    cl::init(true), cl::Hidden, cl::cat(ExtractCat));
133
134static cl::opt<bool> PreserveAssemblyUseListOrder(
135    "preserve-ll-uselistorder",
136    cl::desc("Preserve use-list order when writing LLVM assembly."),
137    cl::init(false), cl::Hidden, cl::cat(ExtractCat));
138
139int main(int argc, char **argv) {
140  InitLLVM X(argc, argv);
141
142  LLVMContext Context;
143  cl::HideUnrelatedOptions(ExtractCat);
144  cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n");
145
146  // Use lazy loading, since we only care about selected global values.
147  SMDiagnostic Err;
148  std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context);
149
150  if (!M.get()) {
151    Err.print(argv[0], errs());
152    return 1;
153  }
154
155  // Use SetVector to avoid duplicates.
156  SetVector<GlobalValue *> GVs;
157
158  // Figure out which aliases we should extract.
159  for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) {
160    GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]);
161    if (!GA) {
162      errs() << argv[0] << ": program doesn't contain alias named '"
163             << ExtractAliases[i] << "'!\n";
164      return 1;
165    }
166    GVs.insert(GA);
167  }
168
169  // Extract aliases via regular expression matching.
170  for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) {
171    std::string Error;
172    Regex RegEx(ExtractRegExpAliases[i]);
173    if (!RegEx.isValid(Error)) {
174      errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' "
175        "invalid regex: " << Error;
176    }
177    bool match = false;
178    for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end();
179         GA != E; GA++) {
180      if (RegEx.match(GA->getName())) {
181        GVs.insert(&*GA);
182        match = true;
183      }
184    }
185    if (!match) {
186      errs() << argv[0] << ": program doesn't contain global named '"
187             << ExtractRegExpAliases[i] << "'!\n";
188      return 1;
189    }
190  }
191
192  // Figure out which globals we should extract.
193  for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) {
194    GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]);
195    if (!GV) {
196      errs() << argv[0] << ": program doesn't contain global named '"
197             << ExtractGlobals[i] << "'!\n";
198      return 1;
199    }
200    GVs.insert(GV);
201  }
202
203  // Extract globals via regular expression matching.
204  for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) {
205    std::string Error;
206    Regex RegEx(ExtractRegExpGlobals[i]);
207    if (!RegEx.isValid(Error)) {
208      errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' "
209        "invalid regex: " << Error;
210    }
211    bool match = false;
212    for (auto &GV : M->globals()) {
213      if (RegEx.match(GV.getName())) {
214        GVs.insert(&GV);
215        match = true;
216      }
217    }
218    if (!match) {
219      errs() << argv[0] << ": program doesn't contain global named '"
220             << ExtractRegExpGlobals[i] << "'!\n";
221      return 1;
222    }
223  }
224
225  // Figure out which functions we should extract.
226  for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) {
227    GlobalValue *GV = M->getFunction(ExtractFuncs[i]);
228    if (!GV) {
229      errs() << argv[0] << ": program doesn't contain function named '"
230             << ExtractFuncs[i] << "'!\n";
231      return 1;
232    }
233    GVs.insert(GV);
234  }
235  // Extract functions via regular expression matching.
236  for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) {
237    std::string Error;
238    StringRef RegExStr = ExtractRegExpFuncs[i];
239    Regex RegEx(RegExStr);
240    if (!RegEx.isValid(Error)) {
241      errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' "
242        "invalid regex: " << Error;
243    }
244    bool match = false;
245    for (Module::iterator F = M->begin(), E = M->end(); F != E;
246         F++) {
247      if (RegEx.match(F->getName())) {
248        GVs.insert(&*F);
249        match = true;
250      }
251    }
252    if (!match) {
253      errs() << argv[0] << ": program doesn't contain global named '"
254             << ExtractRegExpFuncs[i] << "'!\n";
255      return 1;
256    }
257  }
258
259  // Figure out which BasicBlocks we should extract.
260  SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap;
261  for (StringRef StrPair : ExtractBlocks) {
262    SmallVector<StringRef, 16> BBNames;
263    auto BBInfo = StrPair.split(':');
264    // Get the function.
265    Function *F = M->getFunction(BBInfo.first);
266    if (!F) {
267      errs() << argv[0] << ": program doesn't contain a function named '"
268             << BBInfo.first << "'!\n";
269      return 1;
270    }
271    // Add the function to the materialize list, and store the basic block names
272    // to check after materialization.
273    GVs.insert(F);
274    BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
275    BBMap.push_back({F, std::move(BBNames)});
276  }
277
278  // Use *argv instead of argv[0] to work around a wrong GCC warning.
279  ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: ");
280
281  if (Recursive) {
282    std::vector<llvm::Function *> Workqueue;
283    for (GlobalValue *GV : GVs) {
284      if (auto *F = dyn_cast<Function>(GV)) {
285        Workqueue.push_back(F);
286      }
287    }
288    while (!Workqueue.empty()) {
289      Function *F = &*Workqueue.back();
290      Workqueue.pop_back();
291      ExitOnErr(F->materialize());
292      for (auto &BB : *F) {
293        for (auto &I : BB) {
294          CallBase *CB = dyn_cast<CallBase>(&I);
295          if (!CB)
296            continue;
297          Function *CF = CB->getCalledFunction();
298          if (!CF)
299            continue;
300          if (CF->isDeclaration() || GVs.count(CF))
301            continue;
302          GVs.insert(CF);
303          Workqueue.push_back(CF);
304        }
305      }
306    }
307  }
308
309  auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); };
310
311  // Materialize requisite global values.
312  if (!DeleteFn) {
313    for (size_t i = 0, e = GVs.size(); i != e; ++i)
314      Materialize(*GVs[i]);
315  } else {
316    // Deleting. Materialize every GV that's *not* in GVs.
317    SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end());
318    for (auto &F : *M) {
319      if (!GVSet.count(&F))
320        Materialize(F);
321    }
322  }
323
324  {
325    std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end());
326    LoopAnalysisManager LAM;
327    FunctionAnalysisManager FAM;
328    CGSCCAnalysisManager CGAM;
329    ModuleAnalysisManager MAM;
330
331    PassBuilder PB;
332
333    PB.registerModuleAnalyses(MAM);
334    PB.registerCGSCCAnalyses(CGAM);
335    PB.registerFunctionAnalyses(FAM);
336    PB.registerLoopAnalyses(LAM);
337    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
338
339    ModulePassManager PM;
340    PM.addPass(ExtractGVPass(Gvs, DeleteFn, KeepConstInit));
341    PM.run(*M, MAM);
342
343    // Now that we have all the GVs we want, mark the module as fully
344    // materialized.
345    // FIXME: should the GVExtractionPass handle this?
346    ExitOnErr(M->materializeAll());
347  }
348
349  // Extract the specified basic blocks from the module and erase the existing
350  // functions.
351  if (!ExtractBlocks.empty()) {
352    // Figure out which BasicBlocks we should extract.
353    std::vector<std::vector<BasicBlock *>> GroupOfBBs;
354    for (auto &P : BBMap) {
355      std::vector<BasicBlock *> BBs;
356      for (StringRef BBName : P.second) {
357        // The function has been materialized, so add its matching basic blocks
358        // to the block extractor list, or fail if a name is not found.
359        auto Res = llvm::find_if(*P.first, [&](const BasicBlock &BB) {
360          return BB.getName().equals(BBName);
361        });
362        if (Res == P.first->end()) {
363          errs() << argv[0] << ": function " << P.first->getName()
364                 << " doesn't contain a basic block named '" << BBName
365                 << "'!\n";
366          return 1;
367        }
368        BBs.push_back(&*Res);
369      }
370      GroupOfBBs.push_back(BBs);
371    }
372
373    LoopAnalysisManager LAM;
374    FunctionAnalysisManager FAM;
375    CGSCCAnalysisManager CGAM;
376    ModuleAnalysisManager MAM;
377
378    PassBuilder PB;
379
380    PB.registerModuleAnalyses(MAM);
381    PB.registerCGSCCAnalyses(CGAM);
382    PB.registerFunctionAnalyses(FAM);
383    PB.registerLoopAnalyses(LAM);
384    PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
385
386    ModulePassManager PM;
387    PM.addPass(BlockExtractorPass(std::move(GroupOfBBs), true));
388    PM.run(*M, MAM);
389  }
390
391  // In addition to deleting all other functions, we also want to spiff it
392  // up a little bit.  Do this now.
393
394  LoopAnalysisManager LAM;
395  FunctionAnalysisManager FAM;
396  CGSCCAnalysisManager CGAM;
397  ModuleAnalysisManager MAM;
398
399  PassBuilder PB;
400
401  PB.registerModuleAnalyses(MAM);
402  PB.registerCGSCCAnalyses(CGAM);
403  PB.registerFunctionAnalyses(FAM);
404  PB.registerLoopAnalyses(LAM);
405  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
406
407  ModulePassManager PM;
408  if (!DeleteFn)
409    PM.addPass(GlobalDCEPass());
410  PM.addPass(StripDeadDebugInfoPass());
411  PM.addPass(StripDeadPrototypesPass());
412
413  std::error_code EC;
414  ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None);
415  if (EC) {
416    errs() << EC.message() << '\n';
417    return 1;
418  }
419
420  if (OutputAssembly)
421    PM.addPass(PrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder));
422  else if (Force || !CheckBitcodeOutputToConsole(Out.os()))
423    PM.addPass(BitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder));
424
425  PM.run(*M, MAM);
426
427  // Declare success.
428  Out.keep();
429
430  return 0;
431}
432