CodeExtractor.cpp revision 309124
1//===- CodeExtractor.cpp - Pull code region into a new function -----------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements the interface to tear out a code region, such as an
11// individual loop or a parallel section, into a new function, replacing it with
12// a call to the new function.
13//
14//===----------------------------------------------------------------------===//
15
16#include "llvm/Transforms/Utils/CodeExtractor.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/ADT/SetVector.h"
19#include "llvm/ADT/StringExtras.h"
20#include "llvm/Analysis/LoopInfo.h"
21#include "llvm/Analysis/RegionInfo.h"
22#include "llvm/Analysis/RegionIterator.h"
23#include "llvm/IR/Constants.h"
24#include "llvm/IR/DerivedTypes.h"
25#include "llvm/IR/Dominators.h"
26#include "llvm/IR/Instructions.h"
27#include "llvm/IR/Intrinsics.h"
28#include "llvm/IR/LLVMContext.h"
29#include "llvm/IR/Module.h"
30#include "llvm/IR/Verifier.h"
31#include "llvm/Pass.h"
32#include "llvm/Support/CommandLine.h"
33#include "llvm/Support/Debug.h"
34#include "llvm/Support/ErrorHandling.h"
35#include "llvm/Support/raw_ostream.h"
36#include "llvm/Transforms/Utils/BasicBlockUtils.h"
37#include <algorithm>
38#include <set>
39using namespace llvm;
40
41#define DEBUG_TYPE "code-extractor"
42
43// Provide a command-line option to aggregate function arguments into a struct
44// for functions produced by the code extractor. This is useful when converting
45// extracted functions to pthread-based code, as only one argument (void*) can
46// be passed in to pthread_create().
47static cl::opt<bool>
48AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
49                 cl::desc("Aggregate arguments to code-extracted functions"));
50
51/// \brief Test whether a block is valid for extraction.
52static bool isBlockValidForExtraction(const BasicBlock &BB) {
53  // Landing pads must be in the function where they were inserted for cleanup.
54  if (BB.isEHPad())
55    return false;
56
57  // Don't hoist code containing allocas, invokes, or vastarts.
58  for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
59    if (isa<AllocaInst>(I) || isa<InvokeInst>(I))
60      return false;
61    if (const CallInst *CI = dyn_cast<CallInst>(I))
62      if (const Function *F = CI->getCalledFunction())
63        if (F->getIntrinsicID() == Intrinsic::vastart)
64          return false;
65  }
66
67  return true;
68}
69
70/// \brief Build a set of blocks to extract if the input blocks are viable.
71template <typename IteratorT>
72static SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin,
73                                                       IteratorT BBEnd) {
74  SetVector<BasicBlock *> Result;
75
76  assert(BBBegin != BBEnd);
77
78  // Loop over the blocks, adding them to our set-vector, and aborting with an
79  // empty set if we encounter invalid blocks.
80  do {
81    if (!Result.insert(*BBBegin))
82      llvm_unreachable("Repeated basic blocks in extraction input");
83
84    if (!isBlockValidForExtraction(**BBBegin)) {
85      Result.clear();
86      return Result;
87    }
88  } while (++BBBegin != BBEnd);
89
90#ifndef NDEBUG
91  for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()),
92                                         E = Result.end();
93       I != E; ++I)
94    for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I);
95         PI != PE; ++PI)
96      assert(Result.count(*PI) &&
97             "No blocks in this region may have entries from outside the region"
98             " except for the first block!");
99#endif
100
101  return Result;
102}
103
104/// \brief Helper to call buildExtractionBlockSet with an ArrayRef.
105static SetVector<BasicBlock *>
106buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) {
107  return buildExtractionBlockSet(BBs.begin(), BBs.end());
108}
109
110/// \brief Helper to call buildExtractionBlockSet with a RegionNode.
111static SetVector<BasicBlock *>
112buildExtractionBlockSet(const RegionNode &RN) {
113  if (!RN.isSubRegion())
114    // Just a single BasicBlock.
115    return buildExtractionBlockSet(RN.getNodeAs<BasicBlock>());
116
117  const Region &R = *RN.getNodeAs<Region>();
118
119  return buildExtractionBlockSet(R.block_begin(), R.block_end());
120}
121
122CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs)
123  : DT(nullptr), AggregateArgs(AggregateArgs||AggregateArgsOpt),
124    Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {}
125
126CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
127                             bool AggregateArgs)
128  : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
129    Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {}
130
131CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs)
132  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
133    Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {}
134
135CodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN,
136                             bool AggregateArgs)
137  : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt),
138    Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {}
139
140/// definedInRegion - Return true if the specified value is defined in the
141/// extracted region.
142static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
143  if (Instruction *I = dyn_cast<Instruction>(V))
144    if (Blocks.count(I->getParent()))
145      return true;
146  return false;
147}
148
149/// definedInCaller - Return true if the specified value is defined in the
150/// function being code extracted, but not in the region being extracted.
151/// These values must be passed in as live-ins to the function.
152static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
153  if (isa<Argument>(V)) return true;
154  if (Instruction *I = dyn_cast<Instruction>(V))
155    if (!Blocks.count(I->getParent()))
156      return true;
157  return false;
158}
159
160void CodeExtractor::findInputsOutputs(ValueSet &Inputs,
161                                      ValueSet &Outputs) const {
162  for (BasicBlock *BB : Blocks) {
163    // If a used value is defined outside the region, it's an input.  If an
164    // instruction is used outside the region, it's an output.
165    for (Instruction &II : *BB) {
166      for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
167           ++OI)
168        if (definedInCaller(Blocks, *OI))
169          Inputs.insert(*OI);
170
171      for (User *U : II.users())
172        if (!definedInRegion(Blocks, U)) {
173          Outputs.insert(&II);
174          break;
175        }
176    }
177  }
178}
179
180/// severSplitPHINodes - If a PHI node has multiple inputs from outside of the
181/// region, we need to split the entry block of the region so that the PHI node
182/// is easier to deal with.
183void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) {
184  unsigned NumPredsFromRegion = 0;
185  unsigned NumPredsOutsideRegion = 0;
186
187  if (Header != &Header->getParent()->getEntryBlock()) {
188    PHINode *PN = dyn_cast<PHINode>(Header->begin());
189    if (!PN) return;  // No PHI nodes.
190
191    // If the header node contains any PHI nodes, check to see if there is more
192    // than one entry from outside the region.  If so, we need to sever the
193    // header block into two.
194    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
195      if (Blocks.count(PN->getIncomingBlock(i)))
196        ++NumPredsFromRegion;
197      else
198        ++NumPredsOutsideRegion;
199
200    // If there is one (or fewer) predecessor from outside the region, we don't
201    // need to do anything special.
202    if (NumPredsOutsideRegion <= 1) return;
203  }
204
205  // Otherwise, we need to split the header block into two pieces: one
206  // containing PHI nodes merging values from outside of the region, and a
207  // second that contains all of the code for the block and merges back any
208  // incoming values from inside of the region.
209  BasicBlock::iterator AfterPHIs = Header->getFirstNonPHI()->getIterator();
210  BasicBlock *NewBB = Header->splitBasicBlock(AfterPHIs,
211                                              Header->getName()+".ce");
212
213  // We only want to code extract the second block now, and it becomes the new
214  // header of the region.
215  BasicBlock *OldPred = Header;
216  Blocks.remove(OldPred);
217  Blocks.insert(NewBB);
218  Header = NewBB;
219
220  // Okay, update dominator sets. The blocks that dominate the new one are the
221  // blocks that dominate TIBB plus the new block itself.
222  if (DT)
223    DT->splitBlock(NewBB);
224
225  // Okay, now we need to adjust the PHI nodes and any branches from within the
226  // region to go to the new header block instead of the old header block.
227  if (NumPredsFromRegion) {
228    PHINode *PN = cast<PHINode>(OldPred->begin());
229    // Loop over all of the predecessors of OldPred that are in the region,
230    // changing them to branch to NewBB instead.
231    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
232      if (Blocks.count(PN->getIncomingBlock(i))) {
233        TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator();
234        TI->replaceUsesOfWith(OldPred, NewBB);
235      }
236
237    // Okay, everything within the region is now branching to the right block, we
238    // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
239    for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
240      PHINode *PN = cast<PHINode>(AfterPHIs);
241      // Create a new PHI node in the new region, which has an incoming value
242      // from OldPred of PN.
243      PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
244                                       PN->getName() + ".ce", &NewBB->front());
245      NewPN->addIncoming(PN, OldPred);
246
247      // Loop over all of the incoming value in PN, moving them to NewPN if they
248      // are from the extracted region.
249      for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
250        if (Blocks.count(PN->getIncomingBlock(i))) {
251          NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
252          PN->removeIncomingValue(i);
253          --i;
254        }
255      }
256    }
257  }
258}
259
260void CodeExtractor::splitReturnBlocks() {
261  for (BasicBlock *Block : Blocks)
262    if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
263      BasicBlock *New =
264          Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
265      if (DT) {
266        // Old dominates New. New node dominates all other nodes dominated
267        // by Old.
268        DomTreeNode *OldNode = DT->getNode(Block);
269        SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
270                                               OldNode->end());
271
272        DomTreeNode *NewNode = DT->addNewBlock(New, Block);
273
274        for (DomTreeNode *I : Children)
275          DT->changeImmediateDominator(I, NewNode);
276      }
277    }
278}
279
280/// constructFunction - make a function based on inputs and outputs, as follows:
281/// f(in0, ..., inN, out0, ..., outN)
282///
283Function *CodeExtractor::constructFunction(const ValueSet &inputs,
284                                           const ValueSet &outputs,
285                                           BasicBlock *header,
286                                           BasicBlock *newRootNode,
287                                           BasicBlock *newHeader,
288                                           Function *oldFunction,
289                                           Module *M) {
290  DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
291  DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
292
293  // This function returns unsigned, outputs will go back by reference.
294  switch (NumExitBlocks) {
295  case 0:
296  case 1: RetTy = Type::getVoidTy(header->getContext()); break;
297  case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
298  default: RetTy = Type::getInt16Ty(header->getContext()); break;
299  }
300
301  std::vector<Type*> paramTy;
302
303  // Add the types of the input values to the function's argument list
304  for (Value *value : inputs) {
305    DEBUG(dbgs() << "value used in func: " << *value << "\n");
306    paramTy.push_back(value->getType());
307  }
308
309  // Add the types of the output values to the function's argument list.
310  for (Value *output : outputs) {
311    DEBUG(dbgs() << "instr used in func: " << *output << "\n");
312    if (AggregateArgs)
313      paramTy.push_back(output->getType());
314    else
315      paramTy.push_back(PointerType::getUnqual(output->getType()));
316  }
317
318  DEBUG({
319    dbgs() << "Function type: " << *RetTy << " f(";
320    for (Type *i : paramTy)
321      dbgs() << *i << ", ";
322    dbgs() << ")\n";
323  });
324
325  StructType *StructTy;
326  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
327    StructTy = StructType::get(M->getContext(), paramTy);
328    paramTy.clear();
329    paramTy.push_back(PointerType::getUnqual(StructTy));
330  }
331  FunctionType *funcType =
332                  FunctionType::get(RetTy, paramTy, false);
333
334  // Create the new function
335  Function *newFunction = Function::Create(funcType,
336                                           GlobalValue::InternalLinkage,
337                                           oldFunction->getName() + "_" +
338                                           header->getName(), M);
339  // If the old function is no-throw, so is the new one.
340  if (oldFunction->doesNotThrow())
341    newFunction->setDoesNotThrow();
342
343  newFunction->getBasicBlockList().push_back(newRootNode);
344
345  // Create an iterator to name all of the arguments we inserted.
346  Function::arg_iterator AI = newFunction->arg_begin();
347
348  // Rewrite all users of the inputs in the extracted region to use the
349  // arguments (or appropriate addressing into struct) instead.
350  for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
351    Value *RewriteVal;
352    if (AggregateArgs) {
353      Value *Idx[2];
354      Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
355      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
356      TerminatorInst *TI = newFunction->begin()->getTerminator();
357      GetElementPtrInst *GEP = GetElementPtrInst::Create(
358          StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
359      RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI);
360    } else
361      RewriteVal = &*AI++;
362
363    std::vector<User*> Users(inputs[i]->user_begin(), inputs[i]->user_end());
364    for (User *use : Users)
365      if (Instruction *inst = dyn_cast<Instruction>(use))
366        if (Blocks.count(inst->getParent()))
367          inst->replaceUsesOfWith(inputs[i], RewriteVal);
368  }
369
370  // Set names for input and output arguments.
371  if (!AggregateArgs) {
372    AI = newFunction->arg_begin();
373    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
374      AI->setName(inputs[i]->getName());
375    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
376      AI->setName(outputs[i]->getName()+".out");
377  }
378
379  // Rewrite branches to basic blocks outside of the loop to new dummy blocks
380  // within the new function. This must be done before we lose track of which
381  // blocks were originally in the code region.
382  std::vector<User*> Users(header->user_begin(), header->user_end());
383  for (unsigned i = 0, e = Users.size(); i != e; ++i)
384    // The BasicBlock which contains the branch is not in the region
385    // modify the branch target to a new block
386    if (TerminatorInst *TI = dyn_cast<TerminatorInst>(Users[i]))
387      if (!Blocks.count(TI->getParent()) &&
388          TI->getParent()->getParent() == oldFunction)
389        TI->replaceUsesOfWith(header, newHeader);
390
391  return newFunction;
392}
393
394/// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI
395/// that uses the value within the basic block, and return the predecessor
396/// block associated with that use, or return 0 if none is found.
397static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) {
398  for (Use &U : Used->uses()) {
399     PHINode *P = dyn_cast<PHINode>(U.getUser());
400     if (P && P->getParent() == BB)
401       return P->getIncomingBlock(U);
402  }
403
404  return nullptr;
405}
406
407/// emitCallAndSwitchStatement - This method sets up the caller side by adding
408/// the call instruction, splitting any PHI nodes in the header block as
409/// necessary.
410void CodeExtractor::
411emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer,
412                           ValueSet &inputs, ValueSet &outputs) {
413  // Emit a call to the new function, passing in: *pointer to struct (if
414  // aggregating parameters), or plan inputs and allocated memory for outputs
415  std::vector<Value*> params, StructValues, ReloadOutputs, Reloads;
416
417  LLVMContext &Context = newFunction->getContext();
418
419  // Add inputs as params, or to be filled into the struct
420  for (Value *input : inputs)
421    if (AggregateArgs)
422      StructValues.push_back(input);
423    else
424      params.push_back(input);
425
426  // Create allocas for the outputs
427  for (Value *output : outputs) {
428    if (AggregateArgs) {
429      StructValues.push_back(output);
430    } else {
431      AllocaInst *alloca =
432          new AllocaInst(output->getType(), nullptr, output->getName() + ".loc",
433                         &codeReplacer->getParent()->front().front());
434      ReloadOutputs.push_back(alloca);
435      params.push_back(alloca);
436    }
437  }
438
439  StructType *StructArgTy = nullptr;
440  AllocaInst *Struct = nullptr;
441  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
442    std::vector<Type*> ArgTypes;
443    for (ValueSet::iterator v = StructValues.begin(),
444           ve = StructValues.end(); v != ve; ++v)
445      ArgTypes.push_back((*v)->getType());
446
447    // Allocate a struct at the beginning of this function
448    StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
449    Struct = new AllocaInst(StructArgTy, nullptr, "structArg",
450                            &codeReplacer->getParent()->front().front());
451    params.push_back(Struct);
452
453    for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
454      Value *Idx[2];
455      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
456      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
457      GetElementPtrInst *GEP = GetElementPtrInst::Create(
458          StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
459      codeReplacer->getInstList().push_back(GEP);
460      StoreInst *SI = new StoreInst(StructValues[i], GEP);
461      codeReplacer->getInstList().push_back(SI);
462    }
463  }
464
465  // Emit the call to the function
466  CallInst *call = CallInst::Create(newFunction, params,
467                                    NumExitBlocks > 1 ? "targetBlock" : "");
468  codeReplacer->getInstList().push_back(call);
469
470  Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
471  unsigned FirstOut = inputs.size();
472  if (!AggregateArgs)
473    std::advance(OutputArgBegin, inputs.size());
474
475  // Reload the outputs passed in by reference
476  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
477    Value *Output = nullptr;
478    if (AggregateArgs) {
479      Value *Idx[2];
480      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
481      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
482      GetElementPtrInst *GEP = GetElementPtrInst::Create(
483          StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
484      codeReplacer->getInstList().push_back(GEP);
485      Output = GEP;
486    } else {
487      Output = ReloadOutputs[i];
488    }
489    LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload");
490    Reloads.push_back(load);
491    codeReplacer->getInstList().push_back(load);
492    std::vector<User*> Users(outputs[i]->user_begin(), outputs[i]->user_end());
493    for (unsigned u = 0, e = Users.size(); u != e; ++u) {
494      Instruction *inst = cast<Instruction>(Users[u]);
495      if (!Blocks.count(inst->getParent()))
496        inst->replaceUsesOfWith(outputs[i], load);
497    }
498  }
499
500  // Now we can emit a switch statement using the call as a value.
501  SwitchInst *TheSwitch =
502      SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
503                         codeReplacer, 0, codeReplacer);
504
505  // Since there may be multiple exits from the original region, make the new
506  // function return an unsigned, switch on that number.  This loop iterates
507  // over all of the blocks in the extracted region, updating any terminator
508  // instructions in the to-be-extracted region that branch to blocks that are
509  // not in the region to be extracted.
510  std::map<BasicBlock*, BasicBlock*> ExitBlockMap;
511
512  unsigned switchVal = 0;
513  for (BasicBlock *Block : Blocks) {
514    TerminatorInst *TI = Block->getTerminator();
515    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
516      if (!Blocks.count(TI->getSuccessor(i))) {
517        BasicBlock *OldTarget = TI->getSuccessor(i);
518        // add a new basic block which returns the appropriate value
519        BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
520        if (!NewTarget) {
521          // If we don't already have an exit stub for this non-extracted
522          // destination, create one now!
523          NewTarget = BasicBlock::Create(Context,
524                                         OldTarget->getName() + ".exitStub",
525                                         newFunction);
526          unsigned SuccNum = switchVal++;
527
528          Value *brVal = nullptr;
529          switch (NumExitBlocks) {
530          case 0:
531          case 1: break;  // No value needed.
532          case 2:         // Conditional branch, return a bool
533            brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
534            break;
535          default:
536            brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
537            break;
538          }
539
540          ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget);
541
542          // Update the switch instruction.
543          TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
544                                              SuccNum),
545                             OldTarget);
546
547          // Restore values just before we exit
548          Function::arg_iterator OAI = OutputArgBegin;
549          for (unsigned out = 0, e = outputs.size(); out != e; ++out) {
550            // For an invoke, the normal destination is the only one that is
551            // dominated by the result of the invocation
552            BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent();
553
554            bool DominatesDef = true;
555
556            BasicBlock *NormalDest = nullptr;
557            if (auto *Invoke = dyn_cast<InvokeInst>(outputs[out]))
558              NormalDest = Invoke->getNormalDest();
559
560            if (NormalDest) {
561              DefBlock = NormalDest;
562
563              // Make sure we are looking at the original successor block, not
564              // at a newly inserted exit block, which won't be in the dominator
565              // info.
566              for (const auto &I : ExitBlockMap)
567                if (DefBlock == I.second) {
568                  DefBlock = I.first;
569                  break;
570                }
571
572              // In the extract block case, if the block we are extracting ends
573              // with an invoke instruction, make sure that we don't emit a
574              // store of the invoke value for the unwind block.
575              if (!DT && DefBlock != OldTarget)
576                DominatesDef = false;
577            }
578
579            if (DT) {
580              DominatesDef = DT->dominates(DefBlock, OldTarget);
581
582              // If the output value is used by a phi in the target block,
583              // then we need to test for dominance of the phi's predecessor
584              // instead.  Unfortunately, this a little complicated since we
585              // have already rewritten uses of the value to uses of the reload.
586              BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out],
587                                                          OldTarget);
588              if (pred && DT && DT->dominates(DefBlock, pred))
589                DominatesDef = true;
590            }
591
592            if (DominatesDef) {
593              if (AggregateArgs) {
594                Value *Idx[2];
595                Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
596                Idx[1] = ConstantInt::get(Type::getInt32Ty(Context),
597                                          FirstOut+out);
598                GetElementPtrInst *GEP = GetElementPtrInst::Create(
599                    StructArgTy, &*OAI, Idx, "gep_" + outputs[out]->getName(),
600                    NTRet);
601                new StoreInst(outputs[out], GEP, NTRet);
602              } else {
603                new StoreInst(outputs[out], &*OAI, NTRet);
604              }
605            }
606            // Advance output iterator even if we don't emit a store
607            if (!AggregateArgs) ++OAI;
608          }
609        }
610
611        // rewrite the original branch instruction with this new target
612        TI->setSuccessor(i, NewTarget);
613      }
614  }
615
616  // Now that we've done the deed, simplify the switch instruction.
617  Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
618  switch (NumExitBlocks) {
619  case 0:
620    // There are no successors (the block containing the switch itself), which
621    // means that previously this was the last part of the function, and hence
622    // this should be rewritten as a `ret'
623
624    // Check if the function should return a value
625    if (OldFnRetTy->isVoidTy()) {
626      ReturnInst::Create(Context, nullptr, TheSwitch);  // Return void
627    } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
628      // return what we have
629      ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
630    } else {
631      // Otherwise we must have code extracted an unwind or something, just
632      // return whatever we want.
633      ReturnInst::Create(Context,
634                         Constant::getNullValue(OldFnRetTy), TheSwitch);
635    }
636
637    TheSwitch->eraseFromParent();
638    break;
639  case 1:
640    // Only a single destination, change the switch into an unconditional
641    // branch.
642    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch);
643    TheSwitch->eraseFromParent();
644    break;
645  case 2:
646    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
647                       call, TheSwitch);
648    TheSwitch->eraseFromParent();
649    break;
650  default:
651    // Otherwise, make the default destination of the switch instruction be one
652    // of the other successors.
653    TheSwitch->setCondition(call);
654    TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
655    // Remove redundant case
656    TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
657    break;
658  }
659}
660
661void CodeExtractor::moveCodeToFunction(Function *newFunction) {
662  Function *oldFunc = (*Blocks.begin())->getParent();
663  Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
664  Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
665
666  for (BasicBlock *Block : Blocks) {
667    // Delete the basic block from the old function, and the list of blocks
668    oldBlocks.remove(Block);
669
670    // Insert this basic block into the new function
671    newBlocks.push_back(Block);
672  }
673}
674
675Function *CodeExtractor::extractCodeRegion() {
676  if (!isEligible())
677    return nullptr;
678
679  ValueSet inputs, outputs;
680
681  // Assumption: this is a single-entry code region, and the header is the first
682  // block in the region.
683  BasicBlock *header = *Blocks.begin();
684
685  // If we have to split PHI nodes or the entry block, do so now.
686  severSplitPHINodes(header);
687
688  // If we have any return instructions in the region, split those blocks so
689  // that the return is not in the region.
690  splitReturnBlocks();
691
692  Function *oldFunction = header->getParent();
693
694  // This takes place of the original loop
695  BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
696                                                "codeRepl", oldFunction,
697                                                header);
698
699  // The new function needs a root node because other nodes can branch to the
700  // head of the region, but the entry node of a function cannot have preds.
701  BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
702                                               "newFuncRoot");
703  newFuncRoot->getInstList().push_back(BranchInst::Create(header));
704
705  // Find inputs to, outputs from the code region.
706  findInputsOutputs(inputs, outputs);
707
708  SmallPtrSet<BasicBlock *, 1> ExitBlocks;
709  for (BasicBlock *Block : Blocks)
710    for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
711         ++SI)
712      if (!Blocks.count(*SI))
713        ExitBlocks.insert(*SI);
714  NumExitBlocks = ExitBlocks.size();
715
716  // Construct new function based on inputs/outputs & add allocas for all defs.
717  Function *newFunction = constructFunction(inputs, outputs, header,
718                                            newFuncRoot,
719                                            codeReplacer, oldFunction,
720                                            oldFunction->getParent());
721
722  emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
723
724  moveCodeToFunction(newFunction);
725
726  // Loop over all of the PHI nodes in the header block, and change any
727  // references to the old incoming edge to be the new incoming edge.
728  for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
729    PHINode *PN = cast<PHINode>(I);
730    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
731      if (!Blocks.count(PN->getIncomingBlock(i)))
732        PN->setIncomingBlock(i, newFuncRoot);
733  }
734
735  // Look at all successors of the codeReplacer block.  If any of these blocks
736  // had PHI nodes in them, we need to update the "from" block to be the code
737  // replacer, not the original block in the extracted region.
738  std::vector<BasicBlock*> Succs(succ_begin(codeReplacer),
739                                 succ_end(codeReplacer));
740  for (unsigned i = 0, e = Succs.size(); i != e; ++i)
741    for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) {
742      PHINode *PN = cast<PHINode>(I);
743      std::set<BasicBlock*> ProcessedPreds;
744      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
745        if (Blocks.count(PN->getIncomingBlock(i))) {
746          if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second)
747            PN->setIncomingBlock(i, codeReplacer);
748          else {
749            // There were multiple entries in the PHI for this block, now there
750            // is only one, so remove the duplicated entries.
751            PN->removeIncomingValue(i, false);
752            --i; --e;
753          }
754        }
755    }
756
757  //cerr << "NEW FUNCTION: " << *newFunction;
758  //  verifyFunction(*newFunction);
759
760  //  cerr << "OLD FUNCTION: " << *oldFunction;
761  //  verifyFunction(*oldFunction);
762
763  DEBUG(if (verifyFunction(*newFunction))
764        report_fatal_error("verifyFunction failed!"));
765  return newFunction;
766}
767