1//===- llvm-extract.cpp - LLVM function extraction utility ----------------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
// This utility changes the input module to contain only the requested
// functions and globals (or, with -delete, everything except them).  It is
// primarily used for debugging transformations.
12//
13//===----------------------------------------------------------------------===//
14
15#include "llvm/LLVMContext.h"
16#include "llvm/Module.h"
17#include "llvm/PassManager.h"
18#include "llvm/Assembly/PrintModulePass.h"
19#include "llvm/Bitcode/ReaderWriter.h"
20#include "llvm/Transforms/IPO.h"
21#include "llvm/Target/TargetData.h"
22#include "llvm/Support/CommandLine.h"
23#include "llvm/Support/IRReader.h"
24#include "llvm/Support/ManagedStatic.h"
25#include "llvm/Support/PrettyStackTrace.h"
26#include "llvm/Support/ToolOutputFile.h"
27#include "llvm/Support/SystemUtils.h"
28#include "llvm/Support/Signals.h"
29#include "llvm/Support/Regex.h"
30#include "llvm/ADT/SmallPtrSet.h"
31#include "llvm/ADT/SetVector.h"
32#include <memory>
33using namespace llvm;
34
35// InputFilename - The filename to read from.
36static cl::opt<std::string>
37InputFilename(cl::Positional, cl::desc("<input bitcode file>"),
38              cl::init("-"), cl::value_desc("filename"));
39
40static cl::opt<std::string>
41OutputFilename("o", cl::desc("Specify output filename"),
42               cl::value_desc("filename"), cl::init("-"));
43
44static cl::opt<bool>
45Force("f", cl::desc("Enable binary output on terminals"));
46
47static cl::opt<bool>
48DeleteFn("delete", cl::desc("Delete specified Globals from Module"));
49
50// ExtractFuncs - The functions to extract from the module.
51static cl::list<std::string>
52ExtractFuncs("func", cl::desc("Specify function to extract"),
53             cl::ZeroOrMore, cl::value_desc("function"));
54
55// ExtractRegExpFuncs - The functions, matched via regular expression, to
56// extract from the module.
57static cl::list<std::string>
58ExtractRegExpFuncs("rfunc", cl::desc("Specify function(s) to extract using a "
59                                     "regular expression"),
60                   cl::ZeroOrMore, cl::value_desc("rfunction"));
61
62// ExtractGlobals - The globals to extract from the module.
63static cl::list<std::string>
64ExtractGlobals("glob", cl::desc("Specify global to extract"),
65               cl::ZeroOrMore, cl::value_desc("global"));
66
67// ExtractRegExpGlobals - The globals, matched via regular expression, to
68// extract from the module...
69static cl::list<std::string>
70ExtractRegExpGlobals("rglob", cl::desc("Specify global(s) to extract using a "
71                                       "regular expression"),
72                     cl::ZeroOrMore, cl::value_desc("rglobal"));
73
74static cl::opt<bool>
75OutputAssembly("S",
76               cl::desc("Write output as LLVM assembly"), cl::Hidden);
77
78int main(int argc, char **argv) {
79  // Print a stack trace if we signal out.
80  sys::PrintStackTraceOnErrorSignal();
81  PrettyStackTraceProgram X(argc, argv);
82
83  LLVMContext &Context = getGlobalContext();
84  llvm_shutdown_obj Y;  // Call llvm_shutdown() on exit.
85  cl::ParseCommandLineOptions(argc, argv, "llvm extractor\n");
86
87  // Use lazy loading, since we only care about selected global values.
88  SMDiagnostic Err;
89  std::auto_ptr<Module> M;
90  M.reset(getLazyIRFileModule(InputFilename, Err, Context));
91
92  if (M.get() == 0) {
93    Err.print(argv[0], errs());
94    return 1;
95  }
96
97  // Use SetVector to avoid duplicates.
98  SetVector<GlobalValue *> GVs;
99
100  // Figure out which globals we should extract.
101  for (size_t i = 0, e = ExtractGlobals.size(); i != e; ++i) {
102    GlobalValue *GV = M->getNamedGlobal(ExtractGlobals[i]);
103    if (!GV) {
104      errs() << argv[0] << ": program doesn't contain global named '"
105             << ExtractGlobals[i] << "'!\n";
106      return 1;
107    }
108    GVs.insert(GV);
109  }
110
111  // Extract globals via regular expression matching.
112  for (size_t i = 0, e = ExtractRegExpGlobals.size(); i != e; ++i) {
113    std::string Error;
114    Regex RegEx(ExtractRegExpGlobals[i]);
115    if (!RegEx.isValid(Error)) {
116      errs() << argv[0] << ": '" << ExtractRegExpGlobals[i] << "' "
117        "invalid regex: " << Error;
118    }
119    bool match = false;
120    for (Module::global_iterator GV = M->global_begin(),
121           E = M->global_end(); GV != E; GV++) {
122      if (RegEx.match(GV->getName())) {
123        GVs.insert(&*GV);
124        match = true;
125      }
126    }
127    if (!match) {
128      errs() << argv[0] << ": program doesn't contain global named '"
129             << ExtractRegExpGlobals[i] << "'!\n";
130      return 1;
131    }
132  }
133
134  // Figure out which functions we should extract.
135  for (size_t i = 0, e = ExtractFuncs.size(); i != e; ++i) {
136    GlobalValue *GV = M->getFunction(ExtractFuncs[i]);
137    if (!GV) {
138      errs() << argv[0] << ": program doesn't contain function named '"
139             << ExtractFuncs[i] << "'!\n";
140      return 1;
141    }
142    GVs.insert(GV);
143  }
144  // Extract functions via regular expression matching.
145  for (size_t i = 0, e = ExtractRegExpFuncs.size(); i != e; ++i) {
146    std::string Error;
147    StringRef RegExStr = ExtractRegExpFuncs[i];
148    Regex RegEx(RegExStr);
149    if (!RegEx.isValid(Error)) {
150      errs() << argv[0] << ": '" << ExtractRegExpFuncs[i] << "' "
151        "invalid regex: " << Error;
152    }
153    bool match = false;
154    for (Module::iterator F = M->begin(), E = M->end(); F != E;
155         F++) {
156      if (RegEx.match(F->getName())) {
157        GVs.insert(&*F);
158        match = true;
159      }
160    }
161    if (!match) {
162      errs() << argv[0] << ": program doesn't contain global named '"
163             << ExtractRegExpFuncs[i] << "'!\n";
164      return 1;
165    }
166  }
167
168  // Materialize requisite global values.
169  if (!DeleteFn)
170    for (size_t i = 0, e = GVs.size(); i != e; ++i) {
171      GlobalValue *GV = GVs[i];
172      if (GV->isMaterializable()) {
173        std::string ErrInfo;
174        if (GV->Materialize(&ErrInfo)) {
175          errs() << argv[0] << ": error reading input: " << ErrInfo << "\n";
176          return 1;
177        }
178      }
179    }
180  else {
181    // Deleting. Materialize every GV that's *not* in GVs.
182    SmallPtrSet<GlobalValue *, 8> GVSet(GVs.begin(), GVs.end());
183    for (Module::global_iterator I = M->global_begin(), E = M->global_end();
184         I != E; ++I) {
185      GlobalVariable *G = I;
186      if (!GVSet.count(G) && G->isMaterializable()) {
187        std::string ErrInfo;
188        if (G->Materialize(&ErrInfo)) {
189          errs() << argv[0] << ": error reading input: " << ErrInfo << "\n";
190          return 1;
191        }
192      }
193    }
194    for (Module::iterator I = M->begin(), E = M->end(); I != E; ++I) {
195      Function *F = I;
196      if (!GVSet.count(F) && F->isMaterializable()) {
197        std::string ErrInfo;
198        if (F->Materialize(&ErrInfo)) {
199          errs() << argv[0] << ": error reading input: " << ErrInfo << "\n";
200          return 1;
201        }
202      }
203    }
204  }
205
206  // In addition to deleting all other functions, we also want to spiff it
207  // up a little bit.  Do this now.
208  PassManager Passes;
209  Passes.add(new TargetData(M.get())); // Use correct TargetData
210
211  std::vector<GlobalValue*> Gvs(GVs.begin(), GVs.end());
212
213  Passes.add(createGVExtractionPass(Gvs, DeleteFn));
214  if (!DeleteFn)
215    Passes.add(createGlobalDCEPass());           // Delete unreachable globals
216  Passes.add(createStripDeadDebugInfoPass());    // Remove dead debug info
217  Passes.add(createStripDeadPrototypesPass());   // Remove dead func decls
218
219  std::string ErrorInfo;
220  tool_output_file Out(OutputFilename.c_str(), ErrorInfo,
221                       raw_fd_ostream::F_Binary);
222  if (!ErrorInfo.empty()) {
223    errs() << ErrorInfo << '\n';
224    return 1;
225  }
226
227  if (OutputAssembly)
228    Passes.add(createPrintModulePass(&Out.os()));
229  else if (Force || !CheckBitcodeOutputToConsole(Out.os(), true))
230    Passes.add(createBitcodeWriterPass(Out.os()));
231
232  Passes.run(*M.get());
233
234  // Declare success.
235  Out.keep();
236
237  return 0;
238}
239