// CodeExtractor.cpp revision 314564
1169689Skan//===- CodeExtractor.cpp - Pull code region into a new function -----------===// 2169689Skan// 3169689Skan// The LLVM Compiler Infrastructure 4169689Skan// 5169689Skan// This file is distributed under the University of Illinois Open Source 6169689Skan// License. See LICENSE.TXT for details. 7169689Skan// 8169689Skan//===----------------------------------------------------------------------===// 9169689Skan// 10169689Skan// This file implements the interface to tear out a code region, such as an 11169689Skan// individual loop or a parallel section, into a new function, replacing it with 12169689Skan// a call to the new function. 13169689Skan// 14169689Skan//===----------------------------------------------------------------------===// 15169689Skan 16169689Skan#include "llvm/Transforms/Utils/CodeExtractor.h" 17169689Skan#include "llvm/ADT/STLExtras.h" 18169689Skan#include "llvm/ADT/SetVector.h" 19169689Skan#include "llvm/ADT/StringExtras.h" 20169689Skan#include "llvm/Analysis/BlockFrequencyInfo.h" 21169689Skan#include "llvm/Analysis/BlockFrequencyInfoImpl.h" 22169689Skan#include "llvm/Analysis/BranchProbabilityInfo.h" 23169689Skan#include "llvm/Analysis/LoopInfo.h" 24169689Skan#include "llvm/Analysis/RegionInfo.h" 25169689Skan#include "llvm/Analysis/RegionIterator.h" 26169689Skan#include "llvm/IR/Constants.h" 27169689Skan#include "llvm/IR/DerivedTypes.h" 28169689Skan#include "llvm/IR/Dominators.h" 29169689Skan#include "llvm/IR/Instructions.h" 30169689Skan#include "llvm/IR/Intrinsics.h" 31169689Skan#include "llvm/IR/LLVMContext.h" 32169689Skan#include "llvm/IR/MDBuilder.h" 33169689Skan#include "llvm/IR/Module.h" 34169689Skan#include "llvm/IR/Verifier.h" 35169689Skan#include "llvm/Pass.h" 36169689Skan#include "llvm/Support/BlockFrequency.h" 37169689Skan#include "llvm/Support/CommandLine.h" 38169689Skan#include "llvm/Support/Debug.h" 39169689Skan#include "llvm/Support/ErrorHandling.h" 40169689Skan#include "llvm/Support/raw_ostream.h" 41169689Skan#include 
"llvm/Transforms/Utils/BasicBlockUtils.h" 42169689Skan#include <algorithm> 43169689Skan#include <set> 44169689Skanusing namespace llvm; 45169689Skan 46169689Skan#define DEBUG_TYPE "code-extractor" 47169689Skan 48169689Skan// Provide a command-line option to aggregate function arguments into a struct 49169689Skan// for functions produced by the code extractor. This is useful when converting 50169689Skan// extracted functions to pthread-based code, as only one argument (void*) can 51169689Skan// be passed in to pthread_create(). 52169689Skanstatic cl::opt<bool> 53169689SkanAggregateArgsOpt("aggregate-extracted-args", cl::Hidden, 54169689Skan cl::desc("Aggregate arguments to code-extracted functions")); 55169689Skan 56169689Skan/// \brief Test whether a block is valid for extraction. 57169689Skanbool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { 58169689Skan // Landing pads must be in the function where they were inserted for cleanup. 59169689Skan if (BB.isEHPad()) 60169689Skan return false; 61169689Skan 62169689Skan // Don't hoist code containing allocas, invokes, or vastarts. 63169689Skan for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) { 64169689Skan if (isa<AllocaInst>(I) || isa<InvokeInst>(I)) 65169689Skan return false; 66169689Skan if (const CallInst *CI = dyn_cast<CallInst>(I)) 67169689Skan if (const Function *F = CI->getCalledFunction()) 68169689Skan if (F->getIntrinsicID() == Intrinsic::vastart) 69169689Skan return false; 70169689Skan } 71169689Skan 72169689Skan return true; 73169689Skan} 74169689Skan 75169689Skan/// \brief Build a set of blocks to extract if the input blocks are viable. 
76169689Skantemplate <typename IteratorT> 77169689Skanstatic SetVector<BasicBlock *> buildExtractionBlockSet(IteratorT BBBegin, 78169689Skan IteratorT BBEnd) { 79169689Skan SetVector<BasicBlock *> Result; 80169689Skan 81169689Skan assert(BBBegin != BBEnd); 82169689Skan 83169689Skan // Loop over the blocks, adding them to our set-vector, and aborting with an 84169689Skan // empty set if we encounter invalid blocks. 85169689Skan do { 86169689Skan if (!Result.insert(*BBBegin)) 87169689Skan llvm_unreachable("Repeated basic blocks in extraction input"); 88169689Skan 89169689Skan if (!CodeExtractor::isBlockValidForExtraction(**BBBegin)) { 90169689Skan Result.clear(); 91169689Skan return Result; 92169689Skan } 93169689Skan } while (++BBBegin != BBEnd); 94169689Skan 95169689Skan#ifndef NDEBUG 96169689Skan for (SetVector<BasicBlock *>::iterator I = std::next(Result.begin()), 97169689Skan E = Result.end(); 98169689Skan I != E; ++I) 99169689Skan for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I); 100169689Skan PI != PE; ++PI) 101169689Skan assert(Result.count(*PI) && 102169689Skan "No blocks in this region may have entries from outside the region" 103169689Skan " except for the first block!"); 104169689Skan#endif 105169689Skan 106169689Skan return Result; 107169689Skan} 108169689Skan 109169689Skan/// \brief Helper to call buildExtractionBlockSet with an ArrayRef. 110169689Skanstatic SetVector<BasicBlock *> 111169689SkanbuildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) { 112169689Skan return buildExtractionBlockSet(BBs.begin(), BBs.end()); 113169689Skan} 114169689Skan 115169689Skan/// \brief Helper to call buildExtractionBlockSet with a RegionNode. 116169689Skanstatic SetVector<BasicBlock *> 117169689SkanbuildExtractionBlockSet(const RegionNode &RN) { 118169689Skan if (!RN.isSubRegion()) 119169689Skan // Just a single BasicBlock. 
120169689Skan return buildExtractionBlockSet(RN.getNodeAs<BasicBlock>()); 121169689Skan 122169689Skan const Region &R = *RN.getNodeAs<Region>(); 123169689Skan 124169689Skan return buildExtractionBlockSet(R.block_begin(), R.block_end()); 125169689Skan} 126169689Skan 127169689SkanCodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs, 128169689Skan BlockFrequencyInfo *BFI, 129169689Skan BranchProbabilityInfo *BPI) 130169689Skan : DT(nullptr), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), 131169689Skan BPI(BPI), Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} 132169689Skan 133169689SkanCodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, 134169689Skan bool AggregateArgs, BlockFrequencyInfo *BFI, 135169689Skan BranchProbabilityInfo *BPI) 136169689Skan : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), 137169689Skan BPI(BPI), Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} 138169689Skan 139169689SkanCodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, 140169689Skan BlockFrequencyInfo *BFI, 141169689Skan BranchProbabilityInfo *BPI) 142169689Skan : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), 143169689Skan BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks())), 144169689Skan NumExitBlocks(~0U) {} 145169689Skan 146169689SkanCodeExtractor::CodeExtractor(DominatorTree &DT, const RegionNode &RN, 147169689Skan bool AggregateArgs, BlockFrequencyInfo *BFI, 148169689Skan BranchProbabilityInfo *BPI) 149169689Skan : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), 150169689Skan BPI(BPI), Blocks(buildExtractionBlockSet(RN)), NumExitBlocks(~0U) {} 151169689Skan 152169689Skan/// definedInRegion - Return true if the specified value is defined in the 153169689Skan/// extracted region. 
154169689Skanstatic bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) { 155169689Skan if (Instruction *I = dyn_cast<Instruction>(V)) 156169689Skan if (Blocks.count(I->getParent())) 157169689Skan return true; 158169689Skan return false; 159169689Skan} 160169689Skan 161169689Skan/// definedInCaller - Return true if the specified value is defined in the 162169689Skan/// function being code extracted, but not in the region being extracted. 163169689Skan/// These values must be passed in as live-ins to the function. 164169689Skanstatic bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) { 165169689Skan if (isa<Argument>(V)) return true; 166169689Skan if (Instruction *I = dyn_cast<Instruction>(V)) 167169689Skan if (!Blocks.count(I->getParent())) 168169689Skan return true; 169169689Skan return false; 170169689Skan} 171169689Skan 172169689Skanvoid CodeExtractor::findInputsOutputs(ValueSet &Inputs, 173169689Skan ValueSet &Outputs) const { 174169689Skan for (BasicBlock *BB : Blocks) { 175169689Skan // If a used value is defined outside the region, it's an input. If an 176169689Skan // instruction is used outside the region, it's an output. 177169689Skan for (Instruction &II : *BB) { 178169689Skan for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE; 179169689Skan ++OI) 180169689Skan if (definedInCaller(Blocks, *OI)) 181169689Skan Inputs.insert(*OI); 182169689Skan 183169689Skan for (User *U : II.users()) 184169689Skan if (!definedInRegion(Blocks, U)) { 185169689Skan Outputs.insert(&II); 186169689Skan break; 187169689Skan } 188169689Skan } 189169689Skan } 190169689Skan} 191169689Skan 192169689Skan/// severSplitPHINodes - If a PHI node has multiple inputs from outside of the 193169689Skan/// region, we need to split the entry block of the region so that the PHI node 194169689Skan/// is easier to deal with. 
195169689Skanvoid CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { 196169689Skan unsigned NumPredsFromRegion = 0; 197169689Skan unsigned NumPredsOutsideRegion = 0; 198169689Skan 199169689Skan if (Header != &Header->getParent()->getEntryBlock()) { 200169689Skan PHINode *PN = dyn_cast<PHINode>(Header->begin()); 201169689Skan if (!PN) return; // No PHI nodes. 202169689Skan 203169689Skan // If the header node contains any PHI nodes, check to see if there is more 204169689Skan // than one entry from outside the region. If so, we need to sever the 205169689Skan // header block into two. 206169689Skan for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 207169689Skan if (Blocks.count(PN->getIncomingBlock(i))) 208169689Skan ++NumPredsFromRegion; 209169689Skan else 210169689Skan ++NumPredsOutsideRegion; 211169689Skan 212169689Skan // If there is one (or fewer) predecessor from outside the region, we don't 213169689Skan // need to do anything special. 214169689Skan if (NumPredsOutsideRegion <= 1) return; 215169689Skan } 216169689Skan 217169689Skan // Otherwise, we need to split the header block into two pieces: one 218169689Skan // containing PHI nodes merging values from outside of the region, and a 219169689Skan // second that contains all of the code for the block and merges back any 220169689Skan // incoming values from inside of the region. 221169689Skan BasicBlock::iterator AfterPHIs = Header->getFirstNonPHI()->getIterator(); 222169689Skan BasicBlock *NewBB = Header->splitBasicBlock(AfterPHIs, 223169689Skan Header->getName()+".ce"); 224169689Skan 225169689Skan // We only want to code extract the second block now, and it becomes the new 226169689Skan // header of the region. 227169689Skan BasicBlock *OldPred = Header; 228169689Skan Blocks.remove(OldPred); 229169689Skan Blocks.insert(NewBB); 230169689Skan Header = NewBB; 231169689Skan 232169689Skan // Okay, update dominator sets. 
The blocks that dominate the new one are the 233169689Skan // blocks that dominate TIBB plus the new block itself. 234169689Skan if (DT) 235169689Skan DT->splitBlock(NewBB); 236169689Skan 237169689Skan // Okay, now we need to adjust the PHI nodes and any branches from within the 238169689Skan // region to go to the new header block instead of the old header block. 239169689Skan if (NumPredsFromRegion) { 240169689Skan PHINode *PN = cast<PHINode>(OldPred->begin()); 241169689Skan // Loop over all of the predecessors of OldPred that are in the region, 242169689Skan // changing them to branch to NewBB instead. 243169689Skan for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 244169689Skan if (Blocks.count(PN->getIncomingBlock(i))) { 245169689Skan TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator(); 246169689Skan TI->replaceUsesOfWith(OldPred, NewBB); 247169689Skan } 248169689Skan 249169689Skan // Okay, everything within the region is now branching to the right block, we 250169689Skan // just have to update the PHI nodes now, inserting PHI nodes into NewBB. 251169689Skan for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) { 252169689Skan PHINode *PN = cast<PHINode>(AfterPHIs); 253169689Skan // Create a new PHI node in the new region, which has an incoming value 254169689Skan // from OldPred of PN. 255169689Skan PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion, 256169689Skan PN->getName() + ".ce", &NewBB->front()); 257169689Skan NewPN->addIncoming(PN, OldPred); 258169689Skan 259169689Skan // Loop over all of the incoming value in PN, moving them to NewPN if they 260169689Skan // are from the extracted region. 
261169689Skan for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) { 262169689Skan if (Blocks.count(PN->getIncomingBlock(i))) { 263169689Skan NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i)); 264169689Skan PN->removeIncomingValue(i); 265169689Skan --i; 266169689Skan } 267169689Skan } 268169689Skan } 269169689Skan } 270169689Skan} 271169689Skan 272169689Skanvoid CodeExtractor::splitReturnBlocks() { 273169689Skan for (BasicBlock *Block : Blocks) 274169689Skan if (ReturnInst *RI = dyn_cast<ReturnInst>(Block->getTerminator())) { 275169689Skan BasicBlock *New = 276169689Skan Block->splitBasicBlock(RI->getIterator(), Block->getName() + ".ret"); 277169689Skan if (DT) { 278169689Skan // Old dominates New. New node dominates all other nodes dominated 279169689Skan // by Old. 280169689Skan DomTreeNode *OldNode = DT->getNode(Block); 281169689Skan SmallVector<DomTreeNode *, 8> Children(OldNode->begin(), 282169689Skan OldNode->end()); 283169689Skan 284169689Skan DomTreeNode *NewNode = DT->addNewBlock(New, Block); 285169689Skan 286169689Skan for (DomTreeNode *I : Children) 287169689Skan DT->changeImmediateDominator(I, NewNode); 288169689Skan } 289169689Skan } 290169689Skan} 291169689Skan 292169689Skan/// constructFunction - make a function based on inputs and outputs, as follows: 293169689Skan/// f(in0, ..., inN, out0, ..., outN) 294169689Skan/// 295169689SkanFunction *CodeExtractor::constructFunction(const ValueSet &inputs, 296169689Skan const ValueSet &outputs, 297169689Skan BasicBlock *header, 298169689Skan BasicBlock *newRootNode, 299169689Skan BasicBlock *newHeader, 300169689Skan Function *oldFunction, 301169689Skan Module *M) { 302169689Skan DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); 303169689Skan DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); 304169689Skan 305169689Skan // This function returns unsigned, outputs will go back by reference. 
306169689Skan switch (NumExitBlocks) { 307169689Skan case 0: 308169689Skan case 1: RetTy = Type::getVoidTy(header->getContext()); break; 309169689Skan case 2: RetTy = Type::getInt1Ty(header->getContext()); break; 310169689Skan default: RetTy = Type::getInt16Ty(header->getContext()); break; 311169689Skan } 312169689Skan 313169689Skan std::vector<Type*> paramTy; 314169689Skan 315169689Skan // Add the types of the input values to the function's argument list 316169689Skan for (Value *value : inputs) { 317169689Skan DEBUG(dbgs() << "value used in func: " << *value << "\n"); 318169689Skan paramTy.push_back(value->getType()); 319169689Skan } 320169689Skan 321169689Skan // Add the types of the output values to the function's argument list. 322169689Skan for (Value *output : outputs) { 323169689Skan DEBUG(dbgs() << "instr used in func: " << *output << "\n"); 324169689Skan if (AggregateArgs) 325169689Skan paramTy.push_back(output->getType()); 326169689Skan else 327169689Skan paramTy.push_back(PointerType::getUnqual(output->getType())); 328169689Skan } 329169689Skan 330169689Skan DEBUG({ 331169689Skan dbgs() << "Function type: " << *RetTy << " f("; 332169689Skan for (Type *i : paramTy) 333169689Skan dbgs() << *i << ", "; 334169689Skan dbgs() << ")\n"; 335169689Skan }); 336169689Skan 337169689Skan StructType *StructTy; 338169689Skan if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { 339169689Skan StructTy = StructType::get(M->getContext(), paramTy); 340169689Skan paramTy.clear(); 341169689Skan paramTy.push_back(PointerType::getUnqual(StructTy)); 342169689Skan } 343169689Skan FunctionType *funcType = 344169689Skan FunctionType::get(RetTy, paramTy, false); 345169689Skan 346169689Skan // Create the new function 347169689Skan Function *newFunction = Function::Create(funcType, 348169689Skan GlobalValue::InternalLinkage, 349169689Skan oldFunction->getName() + "_" + 350169689Skan header->getName(), M); 351169689Skan // If the old function is no-throw, so is the new one. 
352169689Skan if (oldFunction->doesNotThrow()) 353169689Skan newFunction->setDoesNotThrow(); 354169689Skan 355169689Skan // Inherit the uwtable attribute if we need to. 356169689Skan if (oldFunction->hasUWTable()) 357169689Skan newFunction->setHasUWTable(); 358169689Skan 359169689Skan // Inherit all of the target dependent attributes. 360169689Skan // (e.g. If the extracted region contains a call to an x86.sse 361169689Skan // instruction we need to make sure that the extracted region has the 362169689Skan // "target-features" attribute allowing it to be lowered. 363169689Skan // FIXME: This should be changed to check to see if a specific 364169689Skan // attribute can not be inherited. 365169689Skan AttributeSet OldFnAttrs = oldFunction->getAttributes().getFnAttributes(); 366169689Skan AttrBuilder AB(OldFnAttrs, AttributeSet::FunctionIndex); 367169689Skan for (auto Attr : AB.td_attrs()) 368169689Skan newFunction->addFnAttr(Attr.first, Attr.second); 369169689Skan 370169689Skan newFunction->getBasicBlockList().push_back(newRootNode); 371169689Skan 372 // Create an iterator to name all of the arguments we inserted. 373 Function::arg_iterator AI = newFunction->arg_begin(); 374 375 // Rewrite all users of the inputs in the extracted region to use the 376 // arguments (or appropriate addressing into struct) instead. 
377 for (unsigned i = 0, e = inputs.size(); i != e; ++i) { 378 Value *RewriteVal; 379 if (AggregateArgs) { 380 Value *Idx[2]; 381 Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); 382 Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i); 383 TerminatorInst *TI = newFunction->begin()->getTerminator(); 384 GetElementPtrInst *GEP = GetElementPtrInst::Create( 385 StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI); 386 RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI); 387 } else 388 RewriteVal = &*AI++; 389 390 std::vector<User*> Users(inputs[i]->user_begin(), inputs[i]->user_end()); 391 for (User *use : Users) 392 if (Instruction *inst = dyn_cast<Instruction>(use)) 393 if (Blocks.count(inst->getParent())) 394 inst->replaceUsesOfWith(inputs[i], RewriteVal); 395 } 396 397 // Set names for input and output arguments. 398 if (!AggregateArgs) { 399 AI = newFunction->arg_begin(); 400 for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI) 401 AI->setName(inputs[i]->getName()); 402 for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI) 403 AI->setName(outputs[i]->getName()+".out"); 404 } 405 406 // Rewrite branches to basic blocks outside of the loop to new dummy blocks 407 // within the new function. This must be done before we lose track of which 408 // blocks were originally in the code region. 
409 std::vector<User*> Users(header->user_begin(), header->user_end()); 410 for (unsigned i = 0, e = Users.size(); i != e; ++i) 411 // The BasicBlock which contains the branch is not in the region 412 // modify the branch target to a new block 413 if (TerminatorInst *TI = dyn_cast<TerminatorInst>(Users[i])) 414 if (!Blocks.count(TI->getParent()) && 415 TI->getParent()->getParent() == oldFunction) 416 TI->replaceUsesOfWith(header, newHeader); 417 418 return newFunction; 419} 420 421/// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI 422/// that uses the value within the basic block, and return the predecessor 423/// block associated with that use, or return 0 if none is found. 424static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) { 425 for (Use &U : Used->uses()) { 426 PHINode *P = dyn_cast<PHINode>(U.getUser()); 427 if (P && P->getParent() == BB) 428 return P->getIncomingBlock(U); 429 } 430 431 return nullptr; 432} 433 434/// emitCallAndSwitchStatement - This method sets up the caller side by adding 435/// the call instruction, splitting any PHI nodes in the header block as 436/// necessary. 
437void CodeExtractor:: 438emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, 439 ValueSet &inputs, ValueSet &outputs) { 440 // Emit a call to the new function, passing in: *pointer to struct (if 441 // aggregating parameters), or plan inputs and allocated memory for outputs 442 std::vector<Value*> params, StructValues, ReloadOutputs, Reloads; 443 444 LLVMContext &Context = newFunction->getContext(); 445 446 // Add inputs as params, or to be filled into the struct 447 for (Value *input : inputs) 448 if (AggregateArgs) 449 StructValues.push_back(input); 450 else 451 params.push_back(input); 452 453 // Create allocas for the outputs 454 for (Value *output : outputs) { 455 if (AggregateArgs) { 456 StructValues.push_back(output); 457 } else { 458 AllocaInst *alloca = 459 new AllocaInst(output->getType(), nullptr, output->getName() + ".loc", 460 &codeReplacer->getParent()->front().front()); 461 ReloadOutputs.push_back(alloca); 462 params.push_back(alloca); 463 } 464 } 465 466 StructType *StructArgTy = nullptr; 467 AllocaInst *Struct = nullptr; 468 if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { 469 std::vector<Type*> ArgTypes; 470 for (ValueSet::iterator v = StructValues.begin(), 471 ve = StructValues.end(); v != ve; ++v) 472 ArgTypes.push_back((*v)->getType()); 473 474 // Allocate a struct at the beginning of this function 475 StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); 476 Struct = new AllocaInst(StructArgTy, nullptr, "structArg", 477 &codeReplacer->getParent()->front().front()); 478 params.push_back(Struct); 479 480 for (unsigned i = 0, e = inputs.size(); i != e; ++i) { 481 Value *Idx[2]; 482 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 483 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i); 484 GetElementPtrInst *GEP = GetElementPtrInst::Create( 485 StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName()); 486 codeReplacer->getInstList().push_back(GEP); 487 StoreInst *SI = new 
StoreInst(StructValues[i], GEP); 488 codeReplacer->getInstList().push_back(SI); 489 } 490 } 491 492 // Emit the call to the function 493 CallInst *call = CallInst::Create(newFunction, params, 494 NumExitBlocks > 1 ? "targetBlock" : ""); 495 codeReplacer->getInstList().push_back(call); 496 497 Function::arg_iterator OutputArgBegin = newFunction->arg_begin(); 498 unsigned FirstOut = inputs.size(); 499 if (!AggregateArgs) 500 std::advance(OutputArgBegin, inputs.size()); 501 502 // Reload the outputs passed in by reference 503 for (unsigned i = 0, e = outputs.size(); i != e; ++i) { 504 Value *Output = nullptr; 505 if (AggregateArgs) { 506 Value *Idx[2]; 507 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 508 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); 509 GetElementPtrInst *GEP = GetElementPtrInst::Create( 510 StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName()); 511 codeReplacer->getInstList().push_back(GEP); 512 Output = GEP; 513 } else { 514 Output = ReloadOutputs[i]; 515 } 516 LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload"); 517 Reloads.push_back(load); 518 codeReplacer->getInstList().push_back(load); 519 std::vector<User*> Users(outputs[i]->user_begin(), outputs[i]->user_end()); 520 for (unsigned u = 0, e = Users.size(); u != e; ++u) { 521 Instruction *inst = cast<Instruction>(Users[u]); 522 if (!Blocks.count(inst->getParent())) 523 inst->replaceUsesOfWith(outputs[i], load); 524 } 525 } 526 527 // Now we can emit a switch statement using the call as a value. 528 SwitchInst *TheSwitch = 529 SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)), 530 codeReplacer, 0, codeReplacer); 531 532 // Since there may be multiple exits from the original region, make the new 533 // function return an unsigned, switch on that number. 
This loop iterates 534 // over all of the blocks in the extracted region, updating any terminator 535 // instructions in the to-be-extracted region that branch to blocks that are 536 // not in the region to be extracted. 537 std::map<BasicBlock*, BasicBlock*> ExitBlockMap; 538 539 unsigned switchVal = 0; 540 for (BasicBlock *Block : Blocks) { 541 TerminatorInst *TI = Block->getTerminator(); 542 for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) 543 if (!Blocks.count(TI->getSuccessor(i))) { 544 BasicBlock *OldTarget = TI->getSuccessor(i); 545 // add a new basic block which returns the appropriate value 546 BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; 547 if (!NewTarget) { 548 // If we don't already have an exit stub for this non-extracted 549 // destination, create one now! 550 NewTarget = BasicBlock::Create(Context, 551 OldTarget->getName() + ".exitStub", 552 newFunction); 553 unsigned SuccNum = switchVal++; 554 555 Value *brVal = nullptr; 556 switch (NumExitBlocks) { 557 case 0: 558 case 1: break; // No value needed. 559 case 2: // Conditional branch, return a bool 560 brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); 561 break; 562 default: 563 brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); 564 break; 565 } 566 567 ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget); 568 569 // Update the switch instruction. 
570 TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), 571 SuccNum), 572 OldTarget); 573 574 // Restore values just before we exit 575 Function::arg_iterator OAI = OutputArgBegin; 576 for (unsigned out = 0, e = outputs.size(); out != e; ++out) { 577 // For an invoke, the normal destination is the only one that is 578 // dominated by the result of the invocation 579 BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent(); 580 581 bool DominatesDef = true; 582 583 BasicBlock *NormalDest = nullptr; 584 if (auto *Invoke = dyn_cast<InvokeInst>(outputs[out])) 585 NormalDest = Invoke->getNormalDest(); 586 587 if (NormalDest) { 588 DefBlock = NormalDest; 589 590 // Make sure we are looking at the original successor block, not 591 // at a newly inserted exit block, which won't be in the dominator 592 // info. 593 for (const auto &I : ExitBlockMap) 594 if (DefBlock == I.second) { 595 DefBlock = I.first; 596 break; 597 } 598 599 // In the extract block case, if the block we are extracting ends 600 // with an invoke instruction, make sure that we don't emit a 601 // store of the invoke value for the unwind block. 602 if (!DT && DefBlock != OldTarget) 603 DominatesDef = false; 604 } 605 606 if (DT) { 607 DominatesDef = DT->dominates(DefBlock, OldTarget); 608 609 // If the output value is used by a phi in the target block, 610 // then we need to test for dominance of the phi's predecessor 611 // instead. Unfortunately, this a little complicated since we 612 // have already rewritten uses of the value to uses of the reload. 
613 BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out], 614 OldTarget); 615 if (pred && DT && DT->dominates(DefBlock, pred)) 616 DominatesDef = true; 617 } 618 619 if (DominatesDef) { 620 if (AggregateArgs) { 621 Value *Idx[2]; 622 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 623 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), 624 FirstOut+out); 625 GetElementPtrInst *GEP = GetElementPtrInst::Create( 626 StructArgTy, &*OAI, Idx, "gep_" + outputs[out]->getName(), 627 NTRet); 628 new StoreInst(outputs[out], GEP, NTRet); 629 } else { 630 new StoreInst(outputs[out], &*OAI, NTRet); 631 } 632 } 633 // Advance output iterator even if we don't emit a store 634 if (!AggregateArgs) ++OAI; 635 } 636 } 637 638 // rewrite the original branch instruction with this new target 639 TI->setSuccessor(i, NewTarget); 640 } 641 } 642 643 // Now that we've done the deed, simplify the switch instruction. 644 Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); 645 switch (NumExitBlocks) { 646 case 0: 647 // There are no successors (the block containing the switch itself), which 648 // means that previously this was the last part of the function, and hence 649 // this should be rewritten as a `ret' 650 651 // Check if the function should return a value 652 if (OldFnRetTy->isVoidTy()) { 653 ReturnInst::Create(Context, nullptr, TheSwitch); // Return void 654 } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { 655 // return what we have 656 ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); 657 } else { 658 // Otherwise we must have code extracted an unwind or something, just 659 // return whatever we want. 660 ReturnInst::Create(Context, 661 Constant::getNullValue(OldFnRetTy), TheSwitch); 662 } 663 664 TheSwitch->eraseFromParent(); 665 break; 666 case 1: 667 // Only a single destination, change the switch into an unconditional 668 // branch. 
669 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch); 670 TheSwitch->eraseFromParent(); 671 break; 672 case 2: 673 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), 674 call, TheSwitch); 675 TheSwitch->eraseFromParent(); 676 break; 677 default: 678 // Otherwise, make the default destination of the switch instruction be one 679 // of the other successors. 680 TheSwitch->setCondition(call); 681 TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks)); 682 // Remove redundant case 683 TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1)); 684 break; 685 } 686} 687 688void CodeExtractor::moveCodeToFunction(Function *newFunction) { 689 Function *oldFunc = (*Blocks.begin())->getParent(); 690 Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); 691 Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); 692 693 for (BasicBlock *Block : Blocks) { 694 // Delete the basic block from the old function, and the list of blocks 695 oldBlocks.remove(Block); 696 697 // Insert this basic block into the new function 698 newBlocks.push_back(Block); 699 } 700} 701 702void CodeExtractor::calculateNewCallTerminatorWeights( 703 BasicBlock *CodeReplacer, 704 DenseMap<BasicBlock *, BlockFrequency> &ExitWeights, 705 BranchProbabilityInfo *BPI) { 706 typedef BlockFrequencyInfoImplBase::Distribution Distribution; 707 typedef BlockFrequencyInfoImplBase::BlockNode BlockNode; 708 709 // Update the branch weights for the exit block. 710 TerminatorInst *TI = CodeReplacer->getTerminator(); 711 SmallVector<unsigned, 8> BranchWeights(TI->getNumSuccessors(), 0); 712 713 // Block Frequency distribution with dummy node. 714 Distribution BranchDist; 715 716 // Add each of the frequencies of the successors. 
717 for (unsigned i = 0, e = TI->getNumSuccessors(); i < e; ++i) { 718 BlockNode ExitNode(i); 719 uint64_t ExitFreq = ExitWeights[TI->getSuccessor(i)].getFrequency(); 720 if (ExitFreq != 0) 721 BranchDist.addExit(ExitNode, ExitFreq); 722 else 723 BPI->setEdgeProbability(CodeReplacer, i, BranchProbability::getZero()); 724 } 725 726 // Check for no total weight. 727 if (BranchDist.Total == 0) 728 return; 729 730 // Normalize the distribution so that they can fit in unsigned. 731 BranchDist.normalize(); 732 733 // Create normalized branch weights and set the metadata. 734 for (unsigned I = 0, E = BranchDist.Weights.size(); I < E; ++I) { 735 const auto &Weight = BranchDist.Weights[I]; 736 737 // Get the weight and update the current BFI. 738 BranchWeights[Weight.TargetNode.Index] = Weight.Amount; 739 BranchProbability BP(Weight.Amount, BranchDist.Total); 740 BPI->setEdgeProbability(CodeReplacer, Weight.TargetNode.Index, BP); 741 } 742 TI->setMetadata( 743 LLVMContext::MD_prof, 744 MDBuilder(TI->getContext()).createBranchWeights(BranchWeights)); 745} 746 747Function *CodeExtractor::extractCodeRegion() { 748 if (!isEligible()) 749 return nullptr; 750 751 ValueSet inputs, outputs; 752 753 // Assumption: this is a single-entry code region, and the header is the first 754 // block in the region. 755 BasicBlock *header = *Blocks.begin(); 756 757 // Calculate the entry frequency of the new function before we change the root 758 // block. 759 BlockFrequency EntryFreq; 760 if (BFI) { 761 assert(BPI && "Both BPI and BFI are required to preserve profile info"); 762 for (BasicBlock *Pred : predecessors(header)) { 763 if (Blocks.count(Pred)) 764 continue; 765 EntryFreq += 766 BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, header); 767 } 768 } 769 770 // If we have to split PHI nodes or the entry block, do so now. 771 severSplitPHINodes(header); 772 773 // If we have any return instructions in the region, split those blocks so 774 // that the return is not in the region. 
775 splitReturnBlocks(); 776 777 Function *oldFunction = header->getParent(); 778 779 // This takes place of the original loop 780 BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), 781 "codeRepl", oldFunction, 782 header); 783 784 // The new function needs a root node because other nodes can branch to the 785 // head of the region, but the entry node of a function cannot have preds. 786 BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), 787 "newFuncRoot"); 788 newFuncRoot->getInstList().push_back(BranchInst::Create(header)); 789 790 // Find inputs to, outputs from the code region. 791 findInputsOutputs(inputs, outputs); 792 793 // Calculate the exit blocks for the extracted region and the total exit 794 // weights for each of those blocks. 795 DenseMap<BasicBlock *, BlockFrequency> ExitWeights; 796 SmallPtrSet<BasicBlock *, 1> ExitBlocks; 797 for (BasicBlock *Block : Blocks) { 798 for (succ_iterator SI = succ_begin(Block), SE = succ_end(Block); SI != SE; 799 ++SI) { 800 if (!Blocks.count(*SI)) { 801 // Update the branch weight for this successor. 802 if (BFI) { 803 BlockFrequency &BF = ExitWeights[*SI]; 804 BF += BFI->getBlockFreq(Block) * BPI->getEdgeProbability(Block, *SI); 805 } 806 ExitBlocks.insert(*SI); 807 } 808 } 809 } 810 NumExitBlocks = ExitBlocks.size(); 811 812 // Construct new function based on inputs/outputs & add allocas for all defs. 813 Function *newFunction = constructFunction(inputs, outputs, header, 814 newFuncRoot, 815 codeReplacer, oldFunction, 816 oldFunction->getParent()); 817 818 // Update the entry count of the function. 
819 if (BFI) { 820 Optional<uint64_t> EntryCount = 821 BFI->getProfileCountFromFreq(EntryFreq.getFrequency()); 822 if (EntryCount.hasValue()) 823 newFunction->setEntryCount(EntryCount.getValue()); 824 BFI->setBlockFreq(codeReplacer, EntryFreq.getFrequency()); 825 } 826 827 emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); 828 829 moveCodeToFunction(newFunction); 830 831 // Update the branch weights for the exit block. 832 if (BFI && NumExitBlocks > 1) 833 calculateNewCallTerminatorWeights(codeReplacer, ExitWeights, BPI); 834 835 // Loop over all of the PHI nodes in the header block, and change any 836 // references to the old incoming edge to be the new incoming edge. 837 for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) { 838 PHINode *PN = cast<PHINode>(I); 839 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 840 if (!Blocks.count(PN->getIncomingBlock(i))) 841 PN->setIncomingBlock(i, newFuncRoot); 842 } 843 844 // Look at all successors of the codeReplacer block. If any of these blocks 845 // had PHI nodes in them, we need to update the "from" block to be the code 846 // replacer, not the original block in the extracted region. 847 std::vector<BasicBlock*> Succs(succ_begin(codeReplacer), 848 succ_end(codeReplacer)); 849 for (unsigned i = 0, e = Succs.size(); i != e; ++i) 850 for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) { 851 PHINode *PN = cast<PHINode>(I); 852 std::set<BasicBlock*> ProcessedPreds; 853 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 854 if (Blocks.count(PN->getIncomingBlock(i))) { 855 if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second) 856 PN->setIncomingBlock(i, codeReplacer); 857 else { 858 // There were multiple entries in the PHI for this block, now there 859 // is only one, so remove the duplicated entries. 
860 PN->removeIncomingValue(i, false); 861 --i; --e; 862 } 863 } 864 } 865 866 //cerr << "NEW FUNCTION: " << *newFunction; 867 // verifyFunction(*newFunction); 868 869 // cerr << "OLD FUNCTION: " << *oldFunction; 870 // verifyFunction(*oldFunction); 871 872 DEBUG(if (verifyFunction(*newFunction)) 873 report_fatal_error("verifyFunction failed!")); 874 return newFunction; 875} 876