1//===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This utility reduces the input module to a selected set of global values
// (functions, global variables, aliases or basic blocks), or deletes that set
// when -delete is given. It is primarily used for debugging transformations.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/ADT/SetVector.h"
15#include "llvm/ADT/SmallPtrSet.h"
16#include "llvm/Bitcode/BitcodeWriterPass.h"
17#include "llvm/IR/DataLayout.h"
18#include "llvm/IR/IRPrintingPasses.h"
19#include "llvm/IR/Instructions.h"
20#include "llvm/IR/LLVMContext.h"
21#include "llvm/IR/LegacyPassManager.h"
22#include "llvm/IR/Module.h"
23#include "llvm/IRReader/IRReader.h"
24#include "llvm/Support/CommandLine.h"
25#include "llvm/Support/Error.h"
26#include "llvm/Support/FileSystem.h"
27#include "llvm/Support/InitLLVM.h"
28#include "llvm/Support/Regex.h"
29#include "llvm/Support/SourceMgr.h"
30#include "llvm/Support/SystemUtils.h"
31#include "llvm/Support/ToolOutputFile.h"
32#include "llvm/Transforms/IPO.h"
33#include <memory>
34#include <utility>
35using namespace llvm;
36
37cl::OptionCategory ExtractCat("llvm-extract Options");
38
39// InputFilename - The filename to read from.
40static cl::opt<std::string> InputFilename(cl::Positional,
41                                          cl::desc("<input bitcode file>"),
42                                          cl::init("-"),
43                                          cl::value_desc("filename"));
44
45static cl::opt<std::string> OutputFilename("o",
46                                           cl::desc("Specify output filename"),
47                                           cl::value_desc("filename"),
48                                           cl::init("-"), cl::cat(ExtractCat));
49
50static cl::opt<bool> Force("f", cl::desc("Enable binary output on terminals"),
51                           cl::cat(ExtractCat));
52
53static cl::opt<bool> DeleteFn("delete",
54                              cl::desc("Delete specified Globals from Module"),
55                              cl::cat(ExtractCat));
56
57static cl::opt<bool> KeepConstInit("keep-const-init",
58                              cl::desc("Keep initializers of constants"),
59                              cl::cat(ExtractCat));
60
61static cl::opt<bool>
62    Recursive("recursive", cl::desc("Recursively extract all called functions"),
63              cl::cat(ExtractCat));
64
65// ExtractFuncs - The functions to extract from the module.
66static cl::list<std::string>
67    ExtractFuncs("func", cl::desc("Specify function to extract"),
68                 cl::ZeroOrMore, cl::value_desc("function"),
69                 cl::cat(ExtractCat));
70
71// ExtractRegExpFuncs - The functions, matched via regular expression, to
72// extract from the module.
73static cl::list<std::string>
74    ExtractRegExpFuncs("rfunc",
75                       cl::desc("Specify function(s) to extract using a "
76                                "regular expression"),
77                       cl::ZeroOrMore, cl::value_desc("rfunction"),
78                       cl::cat(ExtractCat));
79
80// ExtractBlocks - The blocks to extract from the module.
81static cl::list<std::string> ExtractBlocks(
82    "bb",
83    cl::desc(
84        "Specify <function, basic block1[;basic block2...]> pairs to extract.\n"
85        "Each pair will create a function.\n"
86        "If multiple basic blocks are specified in one pair,\n"
87        "the first block in the sequence should dominate the rest.\n"
88        "eg:\n"
89        "  --bb=f:bb1;bb2 will extract one function with both bb1 and bb2;\n"
90        "  --bb=f:bb1 --bb=f:bb2 will extract two functions, one with bb1, one "
91        "with bb2."),
92    cl::ZeroOrMore, cl::value_desc("function:bb1[;bb2...]"),
93    cl::cat(ExtractCat));
94
95// ExtractAlias - The alias to extract from the module.
96static cl::list<std::string>
97    ExtractAliases("alias", cl::desc("Specify alias to extract"),
98                   cl::ZeroOrMore, cl::value_desc("alias"),
99                   cl::cat(ExtractCat));
100
101// ExtractRegExpAliases - The aliases, matched via regular expression, to
102// extract from the module.
103static cl::list<std::string>
104    ExtractRegExpAliases("ralias",
105                         cl::desc("Specify alias(es) to extract using a "
106                                  "regular expression"),
107                         cl::ZeroOrMore, cl::value_desc("ralias"),
108                         cl::cat(ExtractCat));
109
110// ExtractGlobals - The globals to extract from the module.
111static cl::list<std::string>
112    ExtractGlobals("glob", cl::desc("Specify global to extract"),
113                   cl::ZeroOrMore, cl::value_desc("global"),
114                   cl::cat(ExtractCat));
115
116// ExtractRegExpGlobals - The globals, matched via regular expression, to
117// extract from the module...
118static cl::list<std::string>
119    ExtractRegExpGlobals("rglob",
120                         cl::desc("Specify global(s) to extract using a "
121                                  "regular expression"),
122                         cl::ZeroOrMore, cl::value_desc("rglobal"),
123                         cl::cat(ExtractCat));
124
125static cl::opt<bool> OutputAssembly("S",
126                                    cl::desc("Write output as LLVM assembly"),
127                                    cl::Hidden, cl::cat(ExtractCat));
128
129static cl::opt<bool> PreserveBitcodeUseListOrder(
130    "preserve-bc-uselistorder",
131    cl::desc("Preserve use-list order when writing LLVM bitcode."),
132    cl::init(true), cl::Hidden, cl::cat(ExtractCat));
133
134static cl::opt<bool> PreserveAssemblyUseListOrder(
135    "preserve-ll-uselistorder",
136    cl::desc("Preserve use-list order when writing LLVM assembly."),
137    cl::init(false), cl::Hidden, cl::cat(ExtractCat));
138
139int main(int argc, char **argv) {
140  InitLLVM X(argc, argv);
141
142  LLVMContext Context;
143  cl::HideUnrelatedOptions(ExtractCat);
144  cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n");
145
146  // Use lazy loading, since we only care about selected global values.
147  SMDiagnostic Err;
148  std::unique_ptr<Module> M = getLazyIRFileModule(InputFilename, Err, Context);
149
150  if (!M.get()) {
151    Err.print(argv[0], errs());
152    return 1;
153  }
154
155  // Use SetVector to avoid duplicates.
156  SetVector<GlobalValue *> GVs;
157
158  // Figure out which aliases we should extract.
159  for (size_t i = 0, e = ExtractAliases.size(); i != e; ++i) {
160    GlobalAlias *GA = M->getNamedAlias(ExtractAliases[i]);
161    if (!GA) {
162      errs() << argv[0] << ": program doesn't contain alias named '"
163             << ExtractAliases[i] << "'!\n";
164      return 1;
165    }
166    GVs.insert(GA);
167  }
168
169  // Extract aliases via regular expression matching.
170  for (size_t i = 0, e = ExtractRegExpAliases.size(); i != e; ++i) {
171    std::string Error;
172    Regex RegEx(ExtractRegExpAliases[i]);
173    if (!RegEx.isValid(Error)) {
174      errs() << argv[0] << ": '" << ExtractRegExpAliases[i] << "' "
175        "invalid regex: " << Error;
176    }
177    bool match = false;
178    for (Module::alias_iterator GA = M->alias_begin(), E = M->alias_end();
179         GA != E; GA++) {
180      if (RegEx.match(GA->getName())) {
181        GVs.insert(&*GA);
182        match = true;
183      }
184    }
185    if (!match) {
186      errs() << argv[0] << ": program doesn't contain global named '"
187             << ExtractRegExpAliases[i] << "'!\n";
188      return 1;
189    }
190  }
191
192  // Figure out which globals we should extract.
193  for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) {
194    GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]);
195    if (!GV) {
196      errs() << argv[0] << ": program doesn't contain global named '"
197             << ExtractGlobals[i] << "'!\n";
198      return 1;
199    }
200    GVs.insert(GV);
201  }
202
203  // Extract globals via regular expression matching.
204  for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) {
205    std::string Error;
206    Regex RegEx(ExtractRegExpGlobals[i]);
207    if (!RegEx.isValid(Error)) {
208      errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' "
209        "invalid regex: " << Error;
210    }
211    bool match = false;
212    for (auto &GV : M->globals()) {
213      if (RegEx.match(GV.getName())) {
214        GVs.insert(&GV);
215        match = true;
216      }
217    }
218    if (!match) {
219      errs() << argv[0] << ": program doesn't contain global named '"
220             << ExtractRegExpGlobals[i] << "'!\n";
221      return 1;
222    }
223  }
224
225  // Figure out which functions we should extract.
226  for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) {
227    GlobalValue *GV = M->getFunction(ExtractFuncs[i]);
228    if (!GV) {
229      errs() << argv[0] << ": program doesn't contain function named '"
230             << ExtractFuncs[i] << "'!\n";
231      return 1;
232    }
233    GVs.insert(GV);
234  }
235  // Extract functions via regular expression matching.
236  for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) {
237    std::string Error;
238    StringRef RegExStr = ExtractRegExpFuncs[i];
239    Regex RegEx(RegExStr);
240    if (!RegEx.isValid(Error)) {
241      errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' "
242        "invalid regex: " << Error;
243    }
244    bool match = false;
245    for (Module::iterator F = M->begin(), E = M->end(); F != E;
246         F++) {
247      if (RegEx.match(F->getName())) {
248        GVs.insert(&*F);
249        match = true;
250      }
251    }
252    if (!match) {
253      errs() << argv[0] << ": program doesn't contain global named '"
254             << ExtractRegExpFuncs[i] << "'!\n";
255      return 1;
256    }
257  }
258
259  // Figure out which BasicBlocks we should extract.
260  SmallVector<std::pair<Function *, SmallVector<StringRef, 16>>, 2> BBMap;
261  for (StringRef StrPair : ExtractBlocks) {
262    SmallVector<StringRef, 16> BBNames;
263    auto BBInfo = StrPair.split(':');
264    // Get the function.
265    Function *F = M->getFunction(BBInfo.first);
266    if (!F) {
267      errs() << argv[0] << ": program doesn't contain a function named '"
268             << BBInfo.first << "'!\n";
269      return 1;
270    }
271    // Add the function to the materialize list, and store the basic block names
272    // to check after materialization.
273    GVs.insert(F);
274    BBInfo.second.split(BBNames, ';', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
275    BBMap.push_back({F, std::move(BBNames)});
276  }
277
278  // Use *argv instead of argv[0] to work around a wrong GCC warning.
279  ExitOnError ExitOnErr(std::string(*argv) + ": error reading input: ");
280
281  if (Recursive) {
282    std::vector<llvm::Function *> Workqueue;
283    for (GlobalValue *GV : GVs) {
284      if (auto *F = dyn_cast<Function>(GV)) {
285        Workqueue.push_back(F);
286      }
287    }
288    while (!Workqueue.empty()) {
289      Function *F = &*Workqueue.back();
290      Workqueue.pop_back();
291      ExitOnErr(F->materialize());
292      for (auto &BB : *F) {
293        for (auto &I : BB) {
294          CallBase *CB = dyn_cast<CallBase>(&I);
295          if (!CB)
296            continue;
297          Function *CF = CB->getCalledFunction();
298          if (!CF)
299            continue;
300          if (CF->isDeclaration() || GVs.count(CF))
301            continue;
302          GVs.insert(CF);
303          Workqueue.push_back(CF);
304        }
305      }
306    }
307  }
308
309  auto Materialize = [&](GlobalValue &GV) { ExitOnErr(GV.materialize()); };
310
311  // Materialize requisite global values.
312  if (!DeleteFn) {
313    for (size_t i = 0, e = GVs.size(); i != e; ++i)
314      Materialize(*GVs[i]);
315  } else {
316    // Deleting. Materialize every GV that's *not* in GVs.
317    SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end());
318    for (auto &F : *M) {
319      if (!GVSet.count(&F))
320        Materialize(F);
321    }
322  }
323
324  {
325    std::vector<GlobalValue *> Gvs(GVs.begin(), GVs.end());
326    legacy::PassManager Extract;
327    Extract.add(createGVExtractionPass(Gvs, DeleteFn, KeepConstInit));
328    Extract.run(*M);
329
330    // Now that we have all the GVs we want, mark the module as fully
331    // materialized.
332    // FIXME: should the GVExtractionPass handle this?
333    ExitOnErr(M->materializeAll());
334  }
335
336  // Extract the specified basic blocks from the module and erase the existing
337  // functions.
338  if (!ExtractBlocks.empty()) {
339    // Figure out which BasicBlocks we should extract.
340    SmallVector<SmallVector<BasicBlock *, 16>, 4> GroupOfBBs;
341    for (auto &P : BBMap) {
342      SmallVector<BasicBlock *, 16> BBs;
343      for (StringRef BBName : P.second) {
344        // The function has been materialized, so add its matching basic blocks
345        // to the block extractor list, or fail if a name is not found.
346        auto Res = llvm::find_if(*P.first, [&](const BasicBlock &BB) {
347          return BB.getName().equals(BBName);
348        });
349        if (Res == P.first->end()) {
350          errs() << argv[0] << ": function " << P.first->getName()
351                 << " doesn't contain a basic block named '" << BBName
352                 << "'!\n";
353          return 1;
354        }
355        BBs.push_back(&*Res);
356      }
357      GroupOfBBs.push_back(BBs);
358    }
359
360    legacy::PassManager PM;
361    PM.add(createBlockExtractorPass(GroupOfBBs, true));
362    PM.run(*M);
363  }
364
365  // In addition to deleting all other functions, we also want to spiff it
366  // up a little bit.  Do this now.
367  legacy::PassManager Passes;
368
369  if (!DeleteFn)
370    Passes.add(createGlobalDCEPass());           // Delete unreachable globals
371  Passes.add(createStripDeadDebugInfoPass());    // Remove dead debug info
372  Passes.add(createStripDeadPrototypesPass());   // Remove dead func decls
373
374  std::error_code EC;
375  ToolOutputFile Out(OutputFilename, EC, sys::fs::OF_None);
376  if (EC) {
377    errs() << EC.message() << '\n';
378    return 1;
379  }
380
381  if (OutputAssembly)
382    Passes.add(
383        createPrintModulePass(Out.os(), "", PreserveAssemblyUseListOrder));
384  else if (Force || !CheckBitcodeOutputToConsole(Out.os()))
385    Passes.add(createBitcodeWriterPass(Out.os(), PreserveBitcodeUseListOrder));
386
387  Passes.run(*M.get());
388
389  // Declare success.
390  Out.keep();
391
392  return 0;
393}
394