CodeExtractor.cpp revision 344779
1//===- CodeExtractor.cpp - Pull code region into a new function -----------===//
2//
3//                     The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file implements the interface to tear out a code region, such as an
11// individual loop or a parallel section, into a new function, replacing it with
12// a call to the new function.
13//
14//===----------------------------------------------------------------------===//
15
16#include "llvm/Transforms/Utils/CodeExtractor.h"
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/DenseMap.h"
19#include "llvm/ADT/Optional.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/ADT/SetVector.h"
22#include "llvm/ADT/SmallPtrSet.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/Analysis/BlockFrequencyInfo.h"
25#include "llvm/Analysis/BlockFrequencyInfoImpl.h"
26#include "llvm/Analysis/BranchProbabilityInfo.h"
27#include "llvm/Analysis/LoopInfo.h"
28#include "llvm/IR/Argument.h"
29#include "llvm/IR/Attributes.h"
30#include "llvm/IR/BasicBlock.h"
31#include "llvm/IR/CFG.h"
32#include "llvm/IR/Constant.h"
33#include "llvm/IR/Constants.h"
34#include "llvm/IR/DataLayout.h"
35#include "llvm/IR/DerivedTypes.h"
36#include "llvm/IR/Dominators.h"
37#include "llvm/IR/Function.h"
38#include "llvm/IR/GlobalValue.h"
39#include "llvm/IR/InstrTypes.h"
40#include "llvm/IR/Instruction.h"
41#include "llvm/IR/Instructions.h"
42#include "llvm/IR/IntrinsicInst.h"
43#include "llvm/IR/Intrinsics.h"
44#include "llvm/IR/LLVMContext.h"
45#include "llvm/IR/MDBuilder.h"
46#include "llvm/IR/Module.h"
47#include "llvm/IR/Type.h"
48#include "llvm/IR/User.h"
49#include "llvm/IR/Value.h"
50#include "llvm/IR/Verifier.h"
51#include "llvm/Pass.h"
52#include "llvm/Support/BlockFrequency.h"
53#include "llvm/Support/BranchProbability.h"
54#include "llvm/Support/Casting.h"
55#include "llvm/Support/CommandLine.h"
56#include "llvm/Support/Debug.h"
57#include "llvm/Support/ErrorHandling.h"
58#include "llvm/Support/raw_ostream.h"
59#include "llvm/Transforms/Utils/BasicBlockUtils.h"
60#include "llvm/Transforms/Utils/Local.h"
61#include <cassert>
62#include <cstdint>
63#include <iterator>
64#include <map>
65#include <set>
66#include <utility>
67#include <vector>
68
69using namespace llvm;
70using ProfileCount = Function::ProfileCount;
71
72#define DEBUG_TYPE "code-extractor"
73
74// Provide a command-line option to aggregate function arguments into a struct
75// for functions produced by the code extractor. This is useful when converting
76// extracted functions to pthread-based code, as only one argument (void*) can
77// be passed in to pthread_create().
78static cl::opt<bool>
79AggregateArgsOpt("aggregate-extracted-args", cl::Hidden,
80                 cl::desc("Aggregate arguments to code-extracted functions"));
81
82/// Test whether a block is valid for extraction.
83static bool isBlockValidForExtraction(const BasicBlock &BB,
84                                      const SetVector<BasicBlock *> &Result,
85                                      bool AllowVarArgs, bool AllowAlloca) {
86  // taking the address of a basic block moved to another function is illegal
87  if (BB.hasAddressTaken())
88    return false;
89
90  // don't hoist code that uses another basicblock address, as it's likely to
91  // lead to unexpected behavior, like cross-function jumps
92  SmallPtrSet<User const *, 16> Visited;
93  SmallVector<User const *, 16> ToVisit;
94
95  for (Instruction const &Inst : BB)
96    ToVisit.push_back(&Inst);
97
98  while (!ToVisit.empty()) {
99    User const *Curr = ToVisit.pop_back_val();
100    if (!Visited.insert(Curr).second)
101      continue;
102    if (isa<BlockAddress const>(Curr))
103      return false; // even a reference to self is likely to be not compatible
104
105    if (isa<Instruction>(Curr) && cast<Instruction>(Curr)->getParent() != &BB)
106      continue;
107
108    for (auto const &U : Curr->operands()) {
109      if (auto *UU = dyn_cast<User>(U))
110        ToVisit.push_back(UU);
111    }
112  }
113
114  // If explicitly requested, allow vastart and alloca. For invoke instructions
115  // verify that extraction is valid.
116  for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) {
117    if (isa<AllocaInst>(I)) {
118       if (!AllowAlloca)
119         return false;
120       continue;
121    }
122
123    if (const auto *II = dyn_cast<InvokeInst>(I)) {
124      // Unwind destination (either a landingpad, catchswitch, or cleanuppad)
125      // must be a part of the subgraph which is being extracted.
126      if (auto *UBB = II->getUnwindDest())
127        if (!Result.count(UBB))
128          return false;
129      continue;
130    }
131
132    // All catch handlers of a catchswitch instruction as well as the unwind
133    // destination must be in the subgraph.
134    if (const auto *CSI = dyn_cast<CatchSwitchInst>(I)) {
135      if (auto *UBB = CSI->getUnwindDest())
136        if (!Result.count(UBB))
137          return false;
138      for (auto *HBB : CSI->handlers())
139        if (!Result.count(const_cast<BasicBlock*>(HBB)))
140          return false;
141      continue;
142    }
143
144    // Make sure that entire catch handler is within subgraph. It is sufficient
145    // to check that catch return's block is in the list.
146    if (const auto *CPI = dyn_cast<CatchPadInst>(I)) {
147      for (const auto *U : CPI->users())
148        if (const auto *CRI = dyn_cast<CatchReturnInst>(U))
149          if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
150            return false;
151      continue;
152    }
153
154    // And do similar checks for cleanup handler - the entire handler must be
155    // in subgraph which is going to be extracted. For cleanup return should
156    // additionally check that the unwind destination is also in the subgraph.
157    if (const auto *CPI = dyn_cast<CleanupPadInst>(I)) {
158      for (const auto *U : CPI->users())
159        if (const auto *CRI = dyn_cast<CleanupReturnInst>(U))
160          if (!Result.count(const_cast<BasicBlock*>(CRI->getParent())))
161            return false;
162      continue;
163    }
164    if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) {
165      if (auto *UBB = CRI->getUnwindDest())
166        if (!Result.count(UBB))
167          return false;
168      continue;
169    }
170
171    if (const CallInst *CI = dyn_cast<CallInst>(I)) {
172      if (const Function *F = CI->getCalledFunction()) {
173        auto IID = F->getIntrinsicID();
174        if (IID == Intrinsic::vastart) {
175          if (AllowVarArgs)
176            continue;
177          else
178            return false;
179        }
180
181        // Currently, we miscompile outlined copies of eh_typid_for. There are
182        // proposals for fixing this in llvm.org/PR39545.
183        if (IID == Intrinsic::eh_typeid_for)
184          return false;
185      }
186    }
187  }
188
189  return true;
190}
191
192/// Build a set of blocks to extract if the input blocks are viable.
193static SetVector<BasicBlock *>
194buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
195                        bool AllowVarArgs, bool AllowAlloca) {
196  assert(!BBs.empty() && "The set of blocks to extract must be non-empty");
197  SetVector<BasicBlock *> Result;
198
199  // Loop over the blocks, adding them to our set-vector, and aborting with an
200  // empty set if we encounter invalid blocks.
201  for (BasicBlock *BB : BBs) {
202    // If this block is dead, don't process it.
203    if (DT && !DT->isReachableFromEntry(BB))
204      continue;
205
206    if (!Result.insert(BB))
207      llvm_unreachable("Repeated basic blocks in extraction input");
208  }
209
210  for (auto *BB : Result) {
211    if (!isBlockValidForExtraction(*BB, Result, AllowVarArgs, AllowAlloca))
212      return {};
213
214    // Make sure that the first block is not a landing pad.
215    if (BB == Result.front()) {
216      if (BB->isEHPad()) {
217        LLVM_DEBUG(dbgs() << "The first block cannot be an unwind block\n");
218        return {};
219      }
220      continue;
221    }
222
223    // All blocks other than the first must not have predecessors outside of
224    // the subgraph which is being extracted.
225    for (auto *PBB : predecessors(BB))
226      if (!Result.count(PBB)) {
227        LLVM_DEBUG(
228            dbgs() << "No blocks in this region may have entries from "
229                      "outside the region except for the first block!\n");
230        return {};
231      }
232  }
233
234  return Result;
235}
236
237CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT,
238                             bool AggregateArgs, BlockFrequencyInfo *BFI,
239                             BranchProbabilityInfo *BPI, bool AllowVarArgs,
240                             bool AllowAlloca, std::string Suffix)
241    : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
242      BPI(BPI), AllowVarArgs(AllowVarArgs),
243      Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs, AllowAlloca)),
244      Suffix(Suffix) {}
245
246CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs,
247                             BlockFrequencyInfo *BFI,
248                             BranchProbabilityInfo *BPI, std::string Suffix)
249    : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI),
250      BPI(BPI), AllowVarArgs(false),
251      Blocks(buildExtractionBlockSet(L.getBlocks(), &DT,
252                                     /* AllowVarArgs */ false,
253                                     /* AllowAlloca */ false)),
254      Suffix(Suffix) {}
255
256/// definedInRegion - Return true if the specified value is defined in the
257/// extracted region.
258static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) {
259  if (Instruction *I = dyn_cast<Instruction>(V))
260    if (Blocks.count(I->getParent()))
261      return true;
262  return false;
263}
264
265/// definedInCaller - Return true if the specified value is defined in the
266/// function being code extracted, but not in the region being extracted.
267/// These values must be passed in as live-ins to the function.
268static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) {
269  if (isa<Argument>(V)) return true;
270  if (Instruction *I = dyn_cast<Instruction>(V))
271    if (!Blocks.count(I->getParent()))
272      return true;
273  return false;
274}
275
276static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
277  BasicBlock *CommonExitBlock = nullptr;
278  auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
279    for (auto *Succ : successors(Block)) {
280      // Internal edges, ok.
281      if (Blocks.count(Succ))
282        continue;
283      if (!CommonExitBlock) {
284        CommonExitBlock = Succ;
285        continue;
286      }
287      if (CommonExitBlock == Succ)
288        continue;
289
290      return true;
291    }
292    return false;
293  };
294
295  if (any_of(Blocks, hasNonCommonExitSucc))
296    return nullptr;
297
298  return CommonExitBlock;
299}
300
301bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
302    Instruction *Addr) const {
303  AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
304  Function *Func = (*Blocks.begin())->getParent();
305  for (BasicBlock &BB : *Func) {
306    if (Blocks.count(&BB))
307      continue;
308    for (Instruction &II : BB) {
309      if (isa<DbgInfoIntrinsic>(II))
310        continue;
311
312      unsigned Opcode = II.getOpcode();
313      Value *MemAddr = nullptr;
314      switch (Opcode) {
315      case Instruction::Store:
316      case Instruction::Load: {
317        if (Opcode == Instruction::Store) {
318          StoreInst *SI = cast<StoreInst>(&II);
319          MemAddr = SI->getPointerOperand();
320        } else {
321          LoadInst *LI = cast<LoadInst>(&II);
322          MemAddr = LI->getPointerOperand();
323        }
324        // Global variable can not be aliased with locals.
325        if (dyn_cast<Constant>(MemAddr))
326          break;
327        Value *Base = MemAddr->stripInBoundsConstantOffsets();
328        if (!dyn_cast<AllocaInst>(Base) || Base == AI)
329          return false;
330        break;
331      }
332      default: {
333        IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
334        if (IntrInst) {
335          if (IntrInst->isLifetimeStartOrEnd())
336            break;
337          return false;
338        }
339        // Treat all the other cases conservatively if it has side effects.
340        if (II.mayHaveSideEffects())
341          return false;
342      }
343      }
344    }
345  }
346
347  return true;
348}
349
350BasicBlock *
351CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
352  BasicBlock *SinglePredFromOutlineRegion = nullptr;
353  assert(!Blocks.count(CommonExitBlock) &&
354         "Expect a block outside the region!");
355  for (auto *Pred : predecessors(CommonExitBlock)) {
356    if (!Blocks.count(Pred))
357      continue;
358    if (!SinglePredFromOutlineRegion) {
359      SinglePredFromOutlineRegion = Pred;
360    } else if (SinglePredFromOutlineRegion != Pred) {
361      SinglePredFromOutlineRegion = nullptr;
362      break;
363    }
364  }
365
366  if (SinglePredFromOutlineRegion)
367    return SinglePredFromOutlineRegion;
368
369#ifndef NDEBUG
370  auto getFirstPHI = [](BasicBlock *BB) {
371    BasicBlock::iterator I = BB->begin();
372    PHINode *FirstPhi = nullptr;
373    while (I != BB->end()) {
374      PHINode *Phi = dyn_cast<PHINode>(I);
375      if (!Phi)
376        break;
377      if (!FirstPhi) {
378        FirstPhi = Phi;
379        break;
380      }
381    }
382    return FirstPhi;
383  };
384  // If there are any phi nodes, the single pred either exists or has already
385  // be created before code extraction.
386  assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
387#endif
388
389  BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
390      CommonExitBlock->getFirstNonPHI()->getIterator());
391
392  for (auto PI = pred_begin(CommonExitBlock), PE = pred_end(CommonExitBlock);
393       PI != PE;) {
394    BasicBlock *Pred = *PI++;
395    if (Blocks.count(Pred))
396      continue;
397    Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
398  }
399  // Now add the old exit block to the outline region.
400  Blocks.insert(CommonExitBlock);
401  return CommonExitBlock;
402}
403
404void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
405                                BasicBlock *&ExitBlock) const {
406  Function *Func = (*Blocks.begin())->getParent();
407  ExitBlock = getCommonExitBlock(Blocks);
408
409  for (BasicBlock &BB : *Func) {
410    if (Blocks.count(&BB))
411      continue;
412    for (Instruction &II : BB) {
413      auto *AI = dyn_cast<AllocaInst>(&II);
414      if (!AI)
415        continue;
416
417      // Find the pair of life time markers for address 'Addr' that are either
418      // defined inside the outline region or can legally be shrinkwrapped into
419      // the outline region. If there are not other untracked uses of the
420      // address, return the pair of markers if found; otherwise return a pair
421      // of nullptr.
422      auto GetLifeTimeMarkers =
423          [&](Instruction *Addr, bool &SinkLifeStart,
424              bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> {
425        Instruction *LifeStart = nullptr, *LifeEnd = nullptr;
426
427        for (User *U : Addr->users()) {
428          IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
429          if (IntrInst) {
430            if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
431              // Do not handle the case where AI has multiple start markers.
432              if (LifeStart)
433                return std::make_pair<Instruction *>(nullptr, nullptr);
434              LifeStart = IntrInst;
435            }
436            if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
437              if (LifeEnd)
438                return std::make_pair<Instruction *>(nullptr, nullptr);
439              LifeEnd = IntrInst;
440            }
441            continue;
442          }
443          // Find untracked uses of the address, bail.
444          if (!definedInRegion(Blocks, U))
445            return std::make_pair<Instruction *>(nullptr, nullptr);
446        }
447
448        if (!LifeStart || !LifeEnd)
449          return std::make_pair<Instruction *>(nullptr, nullptr);
450
451        SinkLifeStart = !definedInRegion(Blocks, LifeStart);
452        HoistLifeEnd = !definedInRegion(Blocks, LifeEnd);
453        // Do legality Check.
454        if ((SinkLifeStart || HoistLifeEnd) &&
455            !isLegalToShrinkwrapLifetimeMarkers(Addr))
456          return std::make_pair<Instruction *>(nullptr, nullptr);
457
458        // Check to see if we have a place to do hoisting, if not, bail.
459        if (HoistLifeEnd && !ExitBlock)
460          return std::make_pair<Instruction *>(nullptr, nullptr);
461
462        return std::make_pair(LifeStart, LifeEnd);
463      };
464
465      bool SinkLifeStart = false, HoistLifeEnd = false;
466      auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd);
467
468      if (Markers.first) {
469        if (SinkLifeStart)
470          SinkCands.insert(Markers.first);
471        SinkCands.insert(AI);
472        if (HoistLifeEnd)
473          HoistCands.insert(Markers.second);
474        continue;
475      }
476
477      // Follow the bitcast.
478      Instruction *MarkerAddr = nullptr;
479      for (User *U : AI->users()) {
480        if (U->stripInBoundsConstantOffsets() == AI) {
481          SinkLifeStart = false;
482          HoistLifeEnd = false;
483          Instruction *Bitcast = cast<Instruction>(U);
484          Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd);
485          if (Markers.first) {
486            MarkerAddr = Bitcast;
487            continue;
488          }
489        }
490
491        // Found unknown use of AI.
492        if (!definedInRegion(Blocks, U)) {
493          MarkerAddr = nullptr;
494          break;
495        }
496      }
497
498      if (MarkerAddr) {
499        if (SinkLifeStart)
500          SinkCands.insert(Markers.first);
501        if (!definedInRegion(Blocks, MarkerAddr))
502          SinkCands.insert(MarkerAddr);
503        SinkCands.insert(AI);
504        if (HoistLifeEnd)
505          HoistCands.insert(Markers.second);
506      }
507    }
508  }
509}
510
511void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
512                                      const ValueSet &SinkCands) const {
513  for (BasicBlock *BB : Blocks) {
514    // If a used value is defined outside the region, it's an input.  If an
515    // instruction is used outside the region, it's an output.
516    for (Instruction &II : *BB) {
517      for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
518           ++OI) {
519        Value *V = *OI;
520        if (!SinkCands.count(V) && definedInCaller(Blocks, V))
521          Inputs.insert(V);
522      }
523
524      for (User *U : II.users())
525        if (!definedInRegion(Blocks, U)) {
526          Outputs.insert(&II);
527          break;
528        }
529    }
530  }
531}
532
533/// severSplitPHINodesOfEntry - If a PHI node has multiple inputs from outside
534/// of the region, we need to split the entry block of the region so that the
535/// PHI node is easier to deal with.
536void CodeExtractor::severSplitPHINodesOfEntry(BasicBlock *&Header) {
537  unsigned NumPredsFromRegion = 0;
538  unsigned NumPredsOutsideRegion = 0;
539
540  if (Header != &Header->getParent()->getEntryBlock()) {
541    PHINode *PN = dyn_cast<PHINode>(Header->begin());
542    if (!PN) return;  // No PHI nodes.
543
544    // If the header node contains any PHI nodes, check to see if there is more
545    // than one entry from outside the region.  If so, we need to sever the
546    // header block into two.
547    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
548      if (Blocks.count(PN->getIncomingBlock(i)))
549        ++NumPredsFromRegion;
550      else
551        ++NumPredsOutsideRegion;
552
553    // If there is one (or fewer) predecessor from outside the region, we don't
554    // need to do anything special.
555    if (NumPredsOutsideRegion <= 1) return;
556  }
557
558  // Otherwise, we need to split the header block into two pieces: one
559  // containing PHI nodes merging values from outside of the region, and a
560  // second that contains all of the code for the block and merges back any
561  // incoming values from inside of the region.
562  BasicBlock *NewBB = SplitBlock(Header, Header->getFirstNonPHI(), DT);
563
564  // We only want to code extract the second block now, and it becomes the new
565  // header of the region.
566  BasicBlock *OldPred = Header;
567  Blocks.remove(OldPred);
568  Blocks.insert(NewBB);
569  Header = NewBB;
570
571  // Okay, now we need to adjust the PHI nodes and any branches from within the
572  // region to go to the new header block instead of the old header block.
573  if (NumPredsFromRegion) {
574    PHINode *PN = cast<PHINode>(OldPred->begin());
575    // Loop over all of the predecessors of OldPred that are in the region,
576    // changing them to branch to NewBB instead.
577    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
578      if (Blocks.count(PN->getIncomingBlock(i))) {
579        Instruction *TI = PN->getIncomingBlock(i)->getTerminator();
580        TI->replaceUsesOfWith(OldPred, NewBB);
581      }
582
583    // Okay, everything within the region is now branching to the right block, we
584    // just have to update the PHI nodes now, inserting PHI nodes into NewBB.
585    BasicBlock::iterator AfterPHIs;
586    for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) {
587      PHINode *PN = cast<PHINode>(AfterPHIs);
588      // Create a new PHI node in the new region, which has an incoming value
589      // from OldPred of PN.
590      PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion,
591                                       PN->getName() + ".ce", &NewBB->front());
592      PN->replaceAllUsesWith(NewPN);
593      NewPN->addIncoming(PN, OldPred);
594
595      // Loop over all of the incoming value in PN, moving them to NewPN if they
596      // are from the extracted region.
597      for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) {
598        if (Blocks.count(PN->getIncomingBlock(i))) {
599          NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i));
600          PN->removeIncomingValue(i);
601          --i;
602        }
603      }
604    }
605  }
606}
607
608/// severSplitPHINodesOfExits - if PHI nodes in exit blocks have inputs from
609/// outlined region, we split these PHIs on two: one with inputs from region
610/// and other with remaining incoming blocks; then first PHIs are placed in
611/// outlined region.
612void CodeExtractor::severSplitPHINodesOfExits(
613    const SmallPtrSetImpl<BasicBlock *> &Exits) {
614  for (BasicBlock *ExitBB : Exits) {
615    BasicBlock *NewBB = nullptr;
616
617    for (PHINode &PN : ExitBB->phis()) {
618      // Find all incoming values from the outlining region.
619      SmallVector<unsigned, 2> IncomingVals;
620      for (unsigned i = 0; i < PN.getNumIncomingValues(); ++i)
621        if (Blocks.count(PN.getIncomingBlock(i)))
622          IncomingVals.push_back(i);
623
624      // Do not process PHI if there is one (or fewer) predecessor from region.
625      // If PHI has exactly one predecessor from region, only this one incoming
626      // will be replaced on codeRepl block, so it should be safe to skip PHI.
627      if (IncomingVals.size() <= 1)
628        continue;
629
630      // Create block for new PHIs and add it to the list of outlined if it
631      // wasn't done before.
632      if (!NewBB) {
633        NewBB = BasicBlock::Create(ExitBB->getContext(),
634                                   ExitBB->getName() + ".split",
635                                   ExitBB->getParent(), ExitBB);
636        SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBB),
637                                           pred_end(ExitBB));
638        for (BasicBlock *PredBB : Preds)
639          if (Blocks.count(PredBB))
640            PredBB->getTerminator()->replaceUsesOfWith(ExitBB, NewBB);
641        BranchInst::Create(ExitBB, NewBB);
642        Blocks.insert(NewBB);
643      }
644
645      // Split this PHI.
646      PHINode *NewPN =
647          PHINode::Create(PN.getType(), IncomingVals.size(),
648                          PN.getName() + ".ce", NewBB->getFirstNonPHI());
649      for (unsigned i : IncomingVals)
650        NewPN->addIncoming(PN.getIncomingValue(i), PN.getIncomingBlock(i));
651      for (unsigned i : reverse(IncomingVals))
652        PN.removeIncomingValue(i, false);
653      PN.addIncoming(NewPN, NewBB);
654    }
655  }
656}
657
658void CodeExtractor::splitReturnBlocks() {
659  for (BasicBlock *Block : Blocks)
660    if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) {
661      BasicBlock *New =
662          Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret");
663      if (DT) {
664        // Old dominates New. New node dominates all other nodes dominated
665        // by Old.
666        DomTreeNode *OldNode = DT->getNode(Block);
667        SmallVector<DomTreeNode *, 8> Children(OldNode->begin(),
668                                               OldNode->end());
669
670        DomTreeNode *NewNode = DT->addNewBlock(New, Block);
671
672        for (DomTreeNode *I : Children)
673          DT->changeImmediateDominator(I, NewNode);
674      }
675    }
676}
677
678/// constructFunction - make a function based on inputs and outputs, as follows:
679/// f(in0, ..., inN, out0, ..., outN)
680Function *CodeExtractor::constructFunction(const ValueSet &inputs,
681                                           const ValueSet &outputs,
682                                           BasicBlock *header,
683                                           BasicBlock *newRootNode,
684                                           BasicBlock *newHeader,
685                                           Function *oldFunction,
686                                           Module *M) {
687  LLVM_DEBUG(dbgs() << "inputs: " << inputs.size() << "\n");
688  LLVM_DEBUG(dbgs() << "outputs: " << outputs.size() << "\n");
689
690  // This function returns unsigned, outputs will go back by reference.
691  switch (NumExitBlocks) {
692  case 0:
693  case 1: RetTy = Type::getVoidTy(header->getContext()); break;
694  case 2: RetTy = Type::getInt1Ty(header->getContext()); break;
695  default: RetTy = Type::getInt16Ty(header->getContext()); break;
696  }
697
698  std::vector<Type *> paramTy;
699
700  // Add the types of the input values to the function's argument list
701  for (Value *value : inputs) {
702    LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
703    paramTy.push_back(value->getType());
704  }
705
706  // Add the types of the output values to the function's argument list.
707  for (Value *output : outputs) {
708    LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
709    if (AggregateArgs)
710      paramTy.push_back(output->getType());
711    else
712      paramTy.push_back(PointerType::getUnqual(output->getType()));
713  }
714
715  LLVM_DEBUG({
716    dbgs() << "Function type: " << *RetTy << " f(";
717    for (Type *i : paramTy)
718      dbgs() << *i << ", ";
719    dbgs() << ")\n";
720  });
721
722  StructType *StructTy;
723  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
724    StructTy = StructType::get(M->getContext(), paramTy);
725    paramTy.clear();
726    paramTy.push_back(PointerType::getUnqual(StructTy));
727  }
728  FunctionType *funcType =
729                  FunctionType::get(RetTy, paramTy,
730                                    AllowVarArgs && oldFunction->isVarArg());
731
732  std::string SuffixToUse =
733      Suffix.empty()
734          ? (header->getName().empty() ? "extracted" : header->getName().str())
735          : Suffix;
736  // Create the new function
737  Function *newFunction = Function::Create(
738      funcType, GlobalValue::InternalLinkage, oldFunction->getAddressSpace(),
739      oldFunction->getName() + "." + SuffixToUse, M);
740  // If the old function is no-throw, so is the new one.
741  if (oldFunction->doesNotThrow())
742    newFunction->setDoesNotThrow();
743
744  // Inherit the uwtable attribute if we need to.
745  if (oldFunction->hasUWTable())
746    newFunction->setHasUWTable();
747
748  // Inherit all of the target dependent attributes and white-listed
749  // target independent attributes.
750  //  (e.g. If the extracted region contains a call to an x86.sse
751  //  instruction we need to make sure that the extracted region has the
752  //  "target-features" attribute allowing it to be lowered.
753  // FIXME: This should be changed to check to see if a specific
754  //           attribute can not be inherited.
755  for (const auto &Attr : oldFunction->getAttributes().getFnAttributes()) {
756    if (Attr.isStringAttribute()) {
757      if (Attr.getKindAsString() == "thunk")
758        continue;
759    } else
760      switch (Attr.getKindAsEnum()) {
761      // Those attributes cannot be propagated safely. Explicitly list them
762      // here so we get a warning if new attributes are added. This list also
763      // includes non-function attributes.
764      case Attribute::Alignment:
765      case Attribute::AllocSize:
766      case Attribute::ArgMemOnly:
767      case Attribute::Builtin:
768      case Attribute::ByVal:
769      case Attribute::Convergent:
770      case Attribute::Dereferenceable:
771      case Attribute::DereferenceableOrNull:
772      case Attribute::InAlloca:
773      case Attribute::InReg:
774      case Attribute::InaccessibleMemOnly:
775      case Attribute::InaccessibleMemOrArgMemOnly:
776      case Attribute::JumpTable:
777      case Attribute::Naked:
778      case Attribute::Nest:
779      case Attribute::NoAlias:
780      case Attribute::NoBuiltin:
781      case Attribute::NoCapture:
782      case Attribute::NoReturn:
783      case Attribute::None:
784      case Attribute::NonNull:
785      case Attribute::ReadNone:
786      case Attribute::ReadOnly:
787      case Attribute::Returned:
788      case Attribute::ReturnsTwice:
789      case Attribute::SExt:
790      case Attribute::Speculatable:
791      case Attribute::StackAlignment:
792      case Attribute::StructRet:
793      case Attribute::SwiftError:
794      case Attribute::SwiftSelf:
795      case Attribute::WriteOnly:
796      case Attribute::ZExt:
797      case Attribute::EndAttrKinds:
798        continue;
799      // Those attributes should be safe to propagate to the extracted function.
800      case Attribute::AlwaysInline:
801      case Attribute::Cold:
802      case Attribute::NoRecurse:
803      case Attribute::InlineHint:
804      case Attribute::MinSize:
805      case Attribute::NoDuplicate:
806      case Attribute::NoImplicitFloat:
807      case Attribute::NoInline:
808      case Attribute::NonLazyBind:
809      case Attribute::NoRedZone:
810      case Attribute::NoUnwind:
811      case Attribute::OptForFuzzing:
812      case Attribute::OptimizeNone:
813      case Attribute::OptimizeForSize:
814      case Attribute::SafeStack:
815      case Attribute::ShadowCallStack:
816      case Attribute::SanitizeAddress:
817      case Attribute::SanitizeMemory:
818      case Attribute::SanitizeThread:
819      case Attribute::SanitizeHWAddress:
820      case Attribute::SpeculativeLoadHardening:
821      case Attribute::StackProtect:
822      case Attribute::StackProtectReq:
823      case Attribute::StackProtectStrong:
824      case Attribute::StrictFP:
825      case Attribute::UWTable:
826      case Attribute::NoCfCheck:
827        break;
828      }
829
830    newFunction->addFnAttr(Attr);
831  }
832  newFunction->getBasicBlockList().push_back(newRootNode);
833
834  // Create an iterator to name all of the arguments we inserted.
835  Function::arg_iterator AI = newFunction->arg_begin();
836
837  // Rewrite all users of the inputs in the extracted region to use the
838  // arguments (or appropriate addressing into struct) instead.
839  for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
840    Value *RewriteVal;
841    if (AggregateArgs) {
842      Value *Idx[2];
843      Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
844      Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
845      Instruction *TI = newFunction->begin()->getTerminator();
846      GetElementPtrInst *GEP = GetElementPtrInst::Create(
847          StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
848      RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI);
849    } else
850      RewriteVal = &*AI++;
851
852    std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
853    for (User *use : Users)
854      if (Instruction *inst = dyn_cast<Instruction>(use))
855        if (Blocks.count(inst->getParent()))
856          inst->replaceUsesOfWith(inputs[i], RewriteVal);
857  }
858
859  // Set names for input and output arguments.
860  if (!AggregateArgs) {
861    AI = newFunction->arg_begin();
862    for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
863      AI->setName(inputs[i]->getName());
864    for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
865      AI->setName(outputs[i]->getName()+".out");
866  }
867
868  // Rewrite branches to basic blocks outside of the loop to new dummy blocks
869  // within the new function. This must be done before we lose track of which
870  // blocks were originally in the code region.
871  std::vector<User *> Users(header->user_begin(), header->user_end());
872  for (unsigned i = 0, e = Users.size(); i != e; ++i)
873    // The BasicBlock which contains the branch is not in the region
874    // modify the branch target to a new block
875    if (Instruction *I = dyn_cast<Instruction>(Users[i]))
876      if (I->isTerminator() && !Blocks.count(I->getParent()) &&
877          I->getParent()->getParent() == oldFunction)
878        I->replaceUsesOfWith(header, newHeader);
879
880  return newFunction;
881}
882
883/// emitCallAndSwitchStatement - This method sets up the caller side by adding
884/// the call instruction, splitting any PHI nodes in the header block as
885/// necessary.
886CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
887                                                    BasicBlock *codeReplacer,
888                                                    ValueSet &inputs,
889                                                    ValueSet &outputs) {
890  // Emit a call to the new function, passing in: *pointer to struct (if
891  // aggregating parameters), or plan inputs and allocated memory for outputs
892  std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
893
894  Module *M = newFunction->getParent();
895  LLVMContext &Context = M->getContext();
896  const DataLayout &DL = M->getDataLayout();
897  CallInst *call = nullptr;
898
899  // Add inputs as params, or to be filled into the struct
900  for (Value *input : inputs)
901    if (AggregateArgs)
902      StructValues.push_back(input);
903    else
904      params.push_back(input);
905
906  // Create allocas for the outputs
907  for (Value *output : outputs) {
908    if (AggregateArgs) {
909      StructValues.push_back(output);
910    } else {
911      AllocaInst *alloca =
912        new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
913                       nullptr, output->getName() + ".loc",
914                       &codeReplacer->getParent()->front().front());
915      ReloadOutputs.push_back(alloca);
916      params.push_back(alloca);
917    }
918  }
919
920  StructType *StructArgTy = nullptr;
921  AllocaInst *Struct = nullptr;
922  if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
923    std::vector<Type *> ArgTypes;
924    for (ValueSet::iterator v = StructValues.begin(),
925           ve = StructValues.end(); v != ve; ++v)
926      ArgTypes.push_back((*v)->getType());
927
928    // Allocate a struct at the beginning of this function
929    StructArgTy = StructType::get(newFunction->getContext(), ArgTypes);
930    Struct = new AllocaInst(StructArgTy, DL.getAllocaAddrSpace(), nullptr,
931                            "structArg",
932                            &codeReplacer->getParent()->front().front());
933    params.push_back(Struct);
934
935    for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
936      Value *Idx[2];
937      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
938      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
939      GetElementPtrInst *GEP = GetElementPtrInst::Create(
940          StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
941      codeReplacer->getInstList().push_back(GEP);
942      StoreInst *SI = new StoreInst(StructValues[i], GEP);
943      codeReplacer->getInstList().push_back(SI);
944    }
945  }
946
947  // Emit the call to the function
948  call = CallInst::Create(newFunction, params,
949                          NumExitBlocks > 1 ? "targetBlock" : "");
950  // Add debug location to the new call, if the original function has debug
951  // info. In that case, the terminator of the entry block of the extracted
952  // function contains the first debug location of the extracted function,
953  // set in extractCodeRegion.
954  if (codeReplacer->getParent()->getSubprogram()) {
955    if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc())
956      call->setDebugLoc(DL);
957  }
958  codeReplacer->getInstList().push_back(call);
959
960  Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
961  unsigned FirstOut = inputs.size();
962  if (!AggregateArgs)
963    std::advance(OutputArgBegin, inputs.size());
964
965  // Reload the outputs passed in by reference.
966  Function::arg_iterator OAI = OutputArgBegin;
967  for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
968    Value *Output = nullptr;
969    if (AggregateArgs) {
970      Value *Idx[2];
971      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
972      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
973      GetElementPtrInst *GEP = GetElementPtrInst::Create(
974          StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
975      codeReplacer->getInstList().push_back(GEP);
976      Output = GEP;
977    } else {
978      Output = ReloadOutputs[i];
979    }
980    LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload");
981    Reloads.push_back(load);
982    codeReplacer->getInstList().push_back(load);
983    std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end());
984    for (unsigned u = 0, e = Users.size(); u != e; ++u) {
985      Instruction *inst = cast<Instruction>(Users[u]);
986      if (!Blocks.count(inst->getParent()))
987        inst->replaceUsesOfWith(outputs[i], load);
988    }
989
990    // Store to argument right after the definition of output value.
991    auto *OutI = dyn_cast<Instruction>(outputs[i]);
992    if (!OutI)
993      continue;
994
995    // Find proper insertion point.
996    BasicBlock::iterator InsertPt;
997    // In case OutI is an invoke, we insert the store at the beginning in the
998    // 'normal destination' BB. Otherwise we insert the store right after OutI.
999    if (auto *InvokeI = dyn_cast<InvokeInst>(OutI))
1000      InsertPt = InvokeI->getNormalDest()->getFirstInsertionPt();
1001    else if (auto *Phi = dyn_cast<PHINode>(OutI))
1002      InsertPt = Phi->getParent()->getFirstInsertionPt();
1003    else
1004      InsertPt = std::next(OutI->getIterator());
1005
1006    assert(OAI != newFunction->arg_end() &&
1007           "Number of output arguments should match "
1008           "the amount of defined values");
1009    if (AggregateArgs) {
1010      Value *Idx[2];
1011      Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
1012      Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
1013      GetElementPtrInst *GEP = GetElementPtrInst::Create(
1014          StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), &*InsertPt);
1015      new StoreInst(outputs[i], GEP, &*InsertPt);
1016      // Since there should be only one struct argument aggregating
1017      // all the output values, we shouldn't increment OAI, which always
1018      // points to the struct argument, in this case.
1019    } else {
1020      new StoreInst(outputs[i], &*OAI, &*InsertPt);
1021      ++OAI;
1022    }
1023  }
1024
1025  // Now we can emit a switch statement using the call as a value.
1026  SwitchInst *TheSwitch =
1027      SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)),
1028                         codeReplacer, 0, codeReplacer);
1029
1030  // Since there may be multiple exits from the original region, make the new
1031  // function return an unsigned, switch on that number.  This loop iterates
1032  // over all of the blocks in the extracted region, updating any terminator
1033  // instructions in the to-be-extracted region that branch to blocks that are
1034  // not in the region to be extracted.
1035  std::map<BasicBlock *, BasicBlock *> ExitBlockMap;
1036
1037  unsigned switchVal = 0;
1038  for (BasicBlock *Block : Blocks) {
1039    Instruction *TI = Block->getTerminator();
1040    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
1041      if (!Blocks.count(TI->getSuccessor(i))) {
1042        BasicBlock *OldTarget = TI->getSuccessor(i);
1043        // add a new basic block which returns the appropriate value
1044        BasicBlock *&NewTarget = ExitBlockMap[OldTarget];
1045        if (!NewTarget) {
1046          // If we don't already have an exit stub for this non-extracted
1047          // destination, create one now!
1048          NewTarget = BasicBlock::Create(Context,
1049                                         OldTarget->getName() + ".exitStub",
1050                                         newFunction);
1051          unsigned SuccNum = switchVal++;
1052
1053          Value *brVal = nullptr;
1054          switch (NumExitBlocks) {
1055          case 0:
1056          case 1: break;  // No value needed.
1057          case 2:         // Conditional branch, return a bool
1058            brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum);
1059            break;
1060          default:
1061            brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum);
1062            break;
1063          }
1064
1065          ReturnInst::Create(Context, brVal, NewTarget);
1066
1067          // Update the switch instruction.
1068          TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context),
1069                                              SuccNum),
1070                             OldTarget);
1071        }
1072
1073        // rewrite the original branch instruction with this new target
1074        TI->setSuccessor(i, NewTarget);
1075      }
1076  }
1077
1078  // Now that we've done the deed, simplify the switch instruction.
1079  Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType();
1080  switch (NumExitBlocks) {
1081  case 0:
1082    // There are no successors (the block containing the switch itself), which
1083    // means that previously this was the last part of the function, and hence
1084    // this should be rewritten as a `ret'
1085
1086    // Check if the function should return a value
1087    if (OldFnRetTy->isVoidTy()) {
1088      ReturnInst::Create(Context, nullptr, TheSwitch);  // Return void
1089    } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) {
1090      // return what we have
1091      ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch);
1092    } else {
1093      // Otherwise we must have code extracted an unwind or something, just
1094      // return whatever we want.
1095      ReturnInst::Create(Context,
1096                         Constant::getNullValue(OldFnRetTy), TheSwitch);
1097    }
1098
1099    TheSwitch->eraseFromParent();
1100    break;
1101  case 1:
1102    // Only a single destination, change the switch into an unconditional
1103    // branch.
1104    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch);
1105    TheSwitch->eraseFromParent();
1106    break;
1107  case 2:
1108    BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2),
1109                       call, TheSwitch);
1110    TheSwitch->eraseFromParent();
1111    break;
1112  default:
1113    // Otherwise, make the default destination of the switch instruction be one
1114    // of the other successors.
1115    TheSwitch->setCondition(call);
1116    TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks));
1117    // Remove redundant case
1118    TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1));
1119    break;
1120  }
1121
1122  return call;
1123}
1124
1125void CodeExtractor::moveCodeToFunction(Function *newFunction) {
1126  Function *oldFunc = (*Blocks.begin())->getParent();
1127  Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
1128  Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
1129
1130  for (BasicBlock *Block : Blocks) {
1131    // Delete the basic block from the old function, and the list of blocks
1132    oldBlocks.remove(Block);
1133
1134    // Insert this basic block into the new function
1135    newBlocks.push_back(Block);
1136  }
1137}
1138
1139void CodeExtractor::calculateNewCallTerminatorWeights(
1140    BasicBlock *CodeReplacer,
1141    DenseMap<BasicBlock *, BlockFrequency> &ExitWeights,
1142    BranchProbabilityInfo *BPI) {
1143  using Distribution = BlockFrequencyInfoImplBase::Distribution;
1144  using BlockNode = BlockFrequencyInfoImplBase::BlockNode;
1145
1146  // Update the branch weights for the exit block.
1147  Instruction *TI = CodeReplacer->getTerminator();
1148  SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0);
1149
1150  // Block Frequency distribution with dummy node.
1151  Distribution BranchDist;
1152
1153  // Add each of the frequencies of the successors.
1154  for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) {
1155    BlockNode ExitNode(i);
1156    uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency();
1157    if (ExitFreq != 0)
1158      BranchDist.addExit(ExitNode, ExitFreq);
1159    else
1160      BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero());
1161  }
1162
1163  // Check for no total weight.
1164  if (BranchDist.Total == 0)
1165    return;
1166
1167  // Normalize the distribution so that they can fit in unsigned.
1168  BranchDist.normalize();
1169
1170  // Create normalized branch weights and set the metadata.
1171  for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) {
1172    const auto &Weight = BranchDist.Weights[I];
1173
1174    // Get the weight and update the current BFI.
1175    BranchWeights[Weight.TargetNode.Index] = Weight.Amount;
1176    BranchProbability BP(Weight.Amount, BranchDist.Total);
1177    BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP);
1178  }
1179  TI->setMetadata(
1180      LLVMContext::MD_prof,
1181      MDBuilder(TI->getContext()).createBranchWeights(BranchWeights));
1182}
1183
1184/// Scan the extraction region for lifetime markers which reference inputs.
1185/// Erase these markers. Return the inputs which were referenced.
1186///
1187/// The extraction region is defined by a set of blocks (\p Blocks), and a set
1188/// of allocas which will be moved from the caller function into the extracted
1189/// function (\p SunkAllocas).
1190static SetVector<Value *>
1191eraseLifetimeMarkersOnInputs(const SetVector<BasicBlock *> &Blocks,
1192                             const SetVector<Value *> &SunkAllocas) {
1193  SetVector<Value *> InputObjectsWithLifetime;
1194  for (BasicBlock *BB : Blocks) {
1195    for (auto It = BB->begin(), End = BB->end(); It != End;) {
1196      auto *II = dyn_cast<IntrinsicInst>(&*It);
1197      ++It;
1198      if (!II || !II->isLifetimeStartOrEnd())
1199        continue;
1200
1201      // Get the memory operand of the lifetime marker. If the underlying
1202      // object is a sunk alloca, or is otherwise defined in the extraction
1203      // region, the lifetime marker must not be erased.
1204      Value *Mem = II->getOperand(1)->stripInBoundsOffsets();
1205      if (SunkAllocas.count(Mem) || definedInRegion(Blocks, Mem))
1206        continue;
1207
1208      InputObjectsWithLifetime.insert(Mem);
1209      II->eraseFromParent();
1210    }
1211  }
1212  return InputObjectsWithLifetime;
1213}
1214
1215/// Insert lifetime start/end markers surrounding the call to the new function
1216/// for objects defined in the caller.
1217static void insertLifetimeMarkersSurroundingCall(
1218    Module *M, const SetVector<Value *> &InputObjectsWithLifetime,
1219    CallInst *TheCall) {
1220  if (InputObjectsWithLifetime.empty())
1221    return;
1222
1223  LLVMContext &Ctx = M->getContext();
1224  auto Int8PtrTy = Type::getInt8PtrTy(Ctx);
1225  auto NegativeOne = ConstantInt::getSigned(Type::getInt64Ty(Ctx), -1);
1226  auto LifetimeStartFn = llvm::Intrinsic::getDeclaration(
1227      M, llvm::Intrinsic::lifetime_start, Int8PtrTy);
1228  auto LifetimeEndFn = llvm::Intrinsic::getDeclaration(
1229      M, llvm::Intrinsic::lifetime_end, Int8PtrTy);
1230  for (Value *Mem : InputObjectsWithLifetime) {
1231    assert((!isa<Instruction>(Mem) ||
1232            cast<Instruction>(Mem)->getFunction() == TheCall->getFunction()) &&
1233           "Input memory not defined in original function");
1234    Value *MemAsI8Ptr = nullptr;
1235    if (Mem->getType() == Int8PtrTy)
1236      MemAsI8Ptr = Mem;
1237    else
1238      MemAsI8Ptr =
1239          CastInst::CreatePointerCast(Mem, Int8PtrTy, "lt.cast", TheCall);
1240
1241    auto StartMarker =
1242        CallInst::Create(LifetimeStartFn, {NegativeOne, MemAsI8Ptr});
1243    StartMarker->insertBefore(TheCall);
1244    auto EndMarker = CallInst::Create(LifetimeEndFn, {NegativeOne, MemAsI8Ptr});
1245    EndMarker->insertAfter(TheCall);
1246  }
1247}
1248
1249Function *CodeExtractor::extractCodeRegion() {
1250  if (!isEligible())
1251    return nullptr;
1252
1253  // Assumption: this is a single-entry code region, and the header is the first
1254  // block in the region.
1255  BasicBlock *header = *Blocks.begin();
1256  Function *oldFunction = header->getParent();
1257
1258  // For functions with varargs, check that varargs handling is only done in the
1259  // outlined function, i.e vastart and vaend are only used in outlined blocks.
1260  if (AllowVarArgs && oldFunction->getFunctionType()->isVarArg()) {
1261    auto containsVarArgIntrinsic = [](Instruction &I) {
1262      if (const CallInst *CI = dyn_cast<CallInst>(&I))
1263        if (const Function *F = CI->getCalledFunction())
1264          return F->getIntrinsicID() == Intrinsic::vastart ||
1265                 F->getIntrinsicID() == Intrinsic::vaend;
1266      return false;
1267    };
1268
1269    for (auto &BB : *oldFunction) {
1270      if (Blocks.count(&BB))
1271        continue;
1272      if (llvm::any_of(BB, containsVarArgIntrinsic))
1273        return nullptr;
1274    }
1275  }
1276  ValueSet inputs, outputs, SinkingCands, HoistingCands;
1277  BasicBlock *CommonExit = nullptr;
1278
1279  // Calculate the entry frequency of the new function before we change the root
1280  //   block.
1281  BlockFrequency EntryFreq;
1282  if (BFI) {
1283    assert(BPI && "Both BPI and BFI are required to preserve profile info");
1284    for (BasicBlock *Pred : predecessors(header)) {
1285      if (Blocks.count(Pred))
1286        continue;
1287      EntryFreq +=
1288          BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header);
1289    }
1290  }
1291
1292  // If we have any return instructions in the region, split those blocks so
1293  // that the return is not in the region.
1294  splitReturnBlocks();
1295
1296  // Calculate the exit blocks for the extracted region and the total exit
1297  // weights for each of those blocks.
1298  DenseMap<BasicBlock *, BlockFrequency> ExitWeights;
1299  SmallPtrSet<BasicBlock *, 1> ExitBlocks;
1300  for (BasicBlock *Block : Blocks) {
1301    for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE;
1302         ++SI) {
1303      if (!Blocks.count(*SI)) {
1304        // Update the branch weight for this successor.
1305        if (BFI) {
1306          BlockFrequency &BF = ExitWeights[*SI];
1307          BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI);
1308        }
1309        ExitBlocks.insert(*SI);
1310      }
1311    }
1312  }
1313  NumExitBlocks = ExitBlocks.size();
1314
1315  // If we have to split PHI nodes of the entry or exit blocks, do so now.
1316  severSplitPHINodesOfEntry(header);
1317  severSplitPHINodesOfExits(ExitBlocks);
1318
1319  // This takes place of the original loop
1320  BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(),
1321                                                "codeRepl", oldFunction,
1322                                                header);
1323
1324  // The new function needs a root node because other nodes can branch to the
1325  // head of the region, but the entry node of a function cannot have preds.
1326  BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(),
1327                                               "newFuncRoot");
1328  auto *BranchI = BranchInst::Create(header);
1329  // If the original function has debug info, we have to add a debug location
1330  // to the new branch instruction from the artificial entry block.
1331  // We use the debug location of the first instruction in the extracted
1332  // blocks, as there is no other equivalent line in the source code.
1333  if (oldFunction->getSubprogram()) {
1334    any_of(Blocks, [&BranchI](const BasicBlock *BB) {
1335      return any_of(*BB, [&BranchI](const Instruction &I) {
1336        if (!I.getDebugLoc())
1337          return false;
1338        BranchI->setDebugLoc(I.getDebugLoc());
1339        return true;
1340      });
1341    });
1342  }
1343  newFuncRoot->getInstList().push_back(BranchI);
1344
1345  findAllocas(SinkingCands, HoistingCands, CommonExit);
1346  assert(HoistingCands.empty() || CommonExit);
1347
1348  // Find inputs to, outputs from the code region.
1349  findInputsOutputs(inputs, outputs, SinkingCands);
1350
1351  // Now sink all instructions which only have non-phi uses inside the region
1352  for (auto *II : SinkingCands)
1353    cast<Instruction>(II)->moveBefore(*newFuncRoot,
1354                                      newFuncRoot->getFirstInsertionPt());
1355
1356  if (!HoistingCands.empty()) {
1357    auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
1358    Instruction *TI = HoistToBlock->getTerminator();
1359    for (auto *II : HoistingCands)
1360      cast<Instruction>(II)->moveBefore(TI);
1361  }
1362
1363  // Collect objects which are inputs to the extraction region and also
1364  // referenced by lifetime start/end markers within it. The effects of these
1365  // markers must be replicated in the calling function to prevent the stack
1366  // coloring pass from merging slots which store input objects.
1367  ValueSet InputObjectsWithLifetime =
1368      eraseLifetimeMarkersOnInputs(Blocks, SinkingCands);
1369
1370  // Construct new function based on inputs/outputs & add allocas for all defs.
1371  Function *newFunction =
1372      constructFunction(inputs, outputs, header, newFuncRoot, codeReplacer,
1373                        oldFunction, oldFunction->getParent());
1374
1375  // Update the entry count of the function.
1376  if (BFI) {
1377    auto Count = BFI->getProfileCountFromFreq(EntryFreq.getFrequency());
1378    if (Count.hasValue())
1379      newFunction->setEntryCount(
1380          ProfileCount(Count.getValue(), Function::PCT_Real)); // FIXME
1381    BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency());
1382  }
1383
1384  CallInst *TheCall =
1385      emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
1386
1387  moveCodeToFunction(newFunction);
1388
1389  // Replicate the effects of any lifetime start/end markers which referenced
1390  // input objects in the extraction region by placing markers around the call.
1391  insertLifetimeMarkersSurroundingCall(oldFunction->getParent(),
1392                                       InputObjectsWithLifetime, TheCall);
1393
1394  // Propagate personality info to the new function if there is one.
1395  if (oldFunction->hasPersonalityFn())
1396    newFunction->setPersonalityFn(oldFunction->getPersonalityFn());
1397
1398  // Update the branch weights for the exit block.
1399  if (BFI && NumExitBlocks > 1)
1400    calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI);
1401
1402  // Loop over all of the PHI nodes in the header and exit blocks, and change
1403  // any references to the old incoming edge to be the new incoming edge.
1404  for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) {
1405    PHINode *PN = cast<PHINode>(I);
1406    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
1407      if (!Blocks.count(PN->getIncomingBlock(i)))
1408        PN->setIncomingBlock(i, newFuncRoot);
1409  }
1410
1411  for (BasicBlock *ExitBB : ExitBlocks)
1412    for (PHINode &PN : ExitBB->phis()) {
1413      Value *IncomingCodeReplacerVal = nullptr;
1414      for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
1415        // Ignore incoming values from outside of the extracted region.
1416        if (!Blocks.count(PN.getIncomingBlock(i)))
1417          continue;
1418
1419        // Ensure that there is only one incoming value from codeReplacer.
1420        if (!IncomingCodeReplacerVal) {
1421          PN.setIncomingBlock(i, codeReplacer);
1422          IncomingCodeReplacerVal = PN.getIncomingValue(i);
1423        } else
1424          assert(IncomingCodeReplacerVal == PN.getIncomingValue(i) &&
1425                 "PHI has two incompatbile incoming values from codeRepl");
1426      }
1427    }
1428
1429  // Erase debug info intrinsics. Variable updates within the new function are
1430  // invisible to debuggers. This could be improved by defining a DISubprogram
1431  // for the new function.
1432  for (BasicBlock &BB : *newFunction) {
1433    auto BlockIt = BB.begin();
1434    // Remove debug info intrinsics from the new function.
1435    while (BlockIt != BB.end()) {
1436      Instruction *Inst = &*BlockIt;
1437      ++BlockIt;
1438      if (isa<DbgInfoIntrinsic>(Inst))
1439        Inst->eraseFromParent();
1440    }
1441    // Remove debug info intrinsics which refer to values in the new function
1442    // from the old function.
1443    SmallVector<DbgVariableIntrinsic *, 4> DbgUsers;
1444    for (Instruction &I : BB)
1445      findDbgUsers(DbgUsers, &I);
1446    for (DbgVariableIntrinsic *DVI : DbgUsers)
1447      DVI->eraseFromParent();
1448  }
1449
1450  // Mark the new function `noreturn` if applicable. Terminators which resume
1451  // exception propagation are treated as returning instructions. This is to
1452  // avoid inserting traps after calls to outlined functions which unwind.
1453  bool doesNotReturn = none_of(*newFunction, [](const BasicBlock &BB) {
1454    const Instruction *Term = BB.getTerminator();
1455    return isa<ReturnInst>(Term) || isa<ResumeInst>(Term);
1456  });
1457  if (doesNotReturn)
1458    newFunction->setDoesNotReturn();
1459
1460  LLVM_DEBUG(if (verifyFunction(*newFunction, &errs())) {
1461    newFunction->dump();
1462    report_fatal_error("verification of newFunction failed!");
1463  });
1464  LLVM_DEBUG(if (verifyFunction(*oldFunction))
1465             report_fatal_error("verification of oldFunction failed!"));
1466  return newFunction;
1467}
1468