CodeExtractor.cpp revision 218893
1193323Sed//===- CodeExtractor.cpp - Pull code region into a new function -----------===// 2193323Sed// 3193323Sed// The LLVM Compiler Infrastructure 4193323Sed// 5193323Sed// This file is distributed under the University of Illinois Open Source 6193323Sed// License. See LICENSE.TXT for details. 7193323Sed// 8193323Sed//===----------------------------------------------------------------------===// 9193323Sed// 10193323Sed// This file implements the interface to tear out a code region, such as an 11193323Sed// individual loop or a parallel section, into a new function, replacing it with 12193323Sed// a call to the new function. 13193323Sed// 14193323Sed//===----------------------------------------------------------------------===// 15193323Sed 16193323Sed#include "llvm/Transforms/Utils/FunctionUtils.h" 17193323Sed#include "llvm/Constants.h" 18193323Sed#include "llvm/DerivedTypes.h" 19193323Sed#include "llvm/Instructions.h" 20193323Sed#include "llvm/Intrinsics.h" 21198090Srdivacky#include "llvm/LLVMContext.h" 22193323Sed#include "llvm/Module.h" 23193323Sed#include "llvm/Pass.h" 24193323Sed#include "llvm/Analysis/Dominators.h" 25193323Sed#include "llvm/Analysis/LoopInfo.h" 26193323Sed#include "llvm/Analysis/Verifier.h" 27193323Sed#include "llvm/Transforms/Utils/BasicBlockUtils.h" 28193323Sed#include "llvm/Support/CommandLine.h" 29193323Sed#include "llvm/Support/Debug.h" 30198090Srdivacky#include "llvm/Support/ErrorHandling.h" 31198090Srdivacky#include "llvm/Support/raw_ostream.h" 32202375Srdivacky#include "llvm/ADT/SetVector.h" 33193323Sed#include "llvm/ADT/StringExtras.h" 34193323Sed#include <algorithm> 35193323Sed#include <set> 36193323Sedusing namespace llvm; 37193323Sed 38193323Sed// Provide a command-line option to aggregate function arguments into a struct 39193323Sed// for functions produced by the code extractor. This is useful when converting 40193323Sed// extracted functions to pthread-based code, as only one argument (void*) can 41193323Sed// be passed in to pthread_create(). 42193323Sedstatic cl::opt<bool> 43193323SedAggregateArgsOpt("aggregate-extracted-args", cl::Hidden, 44193323Sed cl::desc("Aggregate arguments to code-extracted functions")); 45193323Sed 46193323Sednamespace { 47198892Srdivacky class CodeExtractor { 48202375Srdivacky typedef SetVector<Value*> Values; 49202375Srdivacky SetVector<BasicBlock*> BlocksToExtract; 50193323Sed DominatorTree* DT; 51193323Sed bool AggregateArgs; 52193323Sed unsigned NumExitBlocks; 53193323Sed const Type *RetTy; 54193323Sed public: 55193323Sed CodeExtractor(DominatorTree* dt = 0, bool AggArgs = false) 56193323Sed : DT(dt), AggregateArgs(AggArgs||AggregateArgsOpt), NumExitBlocks(~0U) {} 57193323Sed 58193323Sed Function *ExtractCodeRegion(const std::vector<BasicBlock*> &code); 59193323Sed 60193323Sed bool isEligible(const std::vector<BasicBlock*> &code); 61193323Sed 62193323Sed private: 63193323Sed /// definedInRegion - Return true if the specified value is defined in the 64193323Sed /// extracted region. 65193323Sed bool definedInRegion(Value *V) const { 66193323Sed if (Instruction *I = dyn_cast<Instruction>(V)) 67193323Sed if (BlocksToExtract.count(I->getParent())) 68193323Sed return true; 69193323Sed return false; 70193323Sed } 71193323Sed 72193323Sed /// definedInCaller - Return true if the specified value is defined in the 73193323Sed /// function being code extracted, but not in the region being extracted. 74193323Sed /// These values must be passed in as live-ins to the function. 75193323Sed bool definedInCaller(Value *V) const { 76193323Sed if (isa<Argument>(V)) return true; 77193323Sed if (Instruction *I = dyn_cast<Instruction>(V)) 78193323Sed if (!BlocksToExtract.count(I->getParent())) 79193323Sed return true; 80193323Sed return false; 81193323Sed } 82193323Sed 83193323Sed void severSplitPHINodes(BasicBlock *&Header); 84193323Sed void splitReturnBlocks(); 85193323Sed void findInputsOutputs(Values &inputs, Values &outputs); 86193323Sed 87193323Sed Function *constructFunction(const Values &inputs, 88193323Sed const Values &outputs, 89193323Sed BasicBlock *header, 90193323Sed BasicBlock *newRootNode, BasicBlock *newHeader, 91193323Sed Function *oldFunction, Module *M); 92193323Sed 93193323Sed void moveCodeToFunction(Function *newFunction); 94193323Sed 95193323Sed void emitCallAndSwitchStatement(Function *newFunction, 96193323Sed BasicBlock *newHeader, 97193323Sed Values &inputs, 98193323Sed Values &outputs); 99193323Sed 100193323Sed }; 101193323Sed} 102193323Sed 103193323Sed/// severSplitPHINodes - If a PHI node has multiple inputs from outside of the 104193323Sed/// region, we need to split the entry block of the region so that the PHI node 105193323Sed/// is easier to deal with. 106193323Sedvoid CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { 107193323Sed bool HasPredsFromRegion = false; 108193323Sed unsigned NumPredsOutsideRegion = 0; 109193323Sed 110193323Sed if (Header != &Header->getParent()->getEntryBlock()) { 111193323Sed PHINode *PN = dyn_cast<PHINode>(Header->begin()); 112193323Sed if (!PN) return; // No PHI nodes. 113193323Sed 114193323Sed // If the header node contains any PHI nodes, check to see if there is more 115193323Sed // than one entry from outside the region. If so, we need to sever the 116193323Sed // header block into two. 117193323Sed for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 118193323Sed if (BlocksToExtract.count(PN->getIncomingBlock(i))) 119193323Sed HasPredsFromRegion = true; 120193323Sed else 121193323Sed ++NumPredsOutsideRegion; 122193323Sed 123193323Sed // If there is one (or fewer) predecessor from outside the region, we don't 124193323Sed // need to do anything special. 125193323Sed if (NumPredsOutsideRegion <= 1) return; 126193323Sed } 127193323Sed 128193323Sed // Otherwise, we need to split the header block into two pieces: one 129193323Sed // containing PHI nodes merging values from outside of the region, and a 130193323Sed // second that contains all of the code for the block and merges back any 131193323Sed // incoming values from inside of the region. 132193323Sed BasicBlock::iterator AfterPHIs = Header->getFirstNonPHI(); 133193323Sed BasicBlock *NewBB = Header->splitBasicBlock(AfterPHIs, 134193323Sed Header->getName()+".ce"); 135193323Sed 136193323Sed // We only want to code extract the second block now, and it becomes the new 137193323Sed // header of the region. 138193323Sed BasicBlock *OldPred = Header; 139202375Srdivacky BlocksToExtract.remove(OldPred); 140193323Sed BlocksToExtract.insert(NewBB); 141193323Sed Header = NewBB; 142193323Sed 143193323Sed // Okay, update dominator sets. The blocks that dominate the new one are the 144193323Sed // blocks that dominate TIBB plus the new block itself. 145193323Sed if (DT) 146193323Sed DT->splitBlock(NewBB); 147193323Sed 148193323Sed // Okay, now we need to adjust the PHI nodes and any branches from within the 149193323Sed // region to go to the new header block instead of the old header block. 150193323Sed if (HasPredsFromRegion) { 151193323Sed PHINode *PN = cast<PHINode>(OldPred->begin()); 152193323Sed // Loop over all of the predecessors of OldPred that are in the region, 153193323Sed // changing them to branch to NewBB instead. 154193323Sed for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 155193323Sed if (BlocksToExtract.count(PN->getIncomingBlock(i))) { 156193323Sed TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator(); 157193323Sed TI->replaceUsesOfWith(OldPred, NewBB); 158193323Sed } 159193323Sed 160193323Sed // Okay, everthing within the region is now branching to the right block, we 161193323Sed // just have to update the PHI nodes now, inserting PHI nodes into NewBB. 162193323Sed for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) { 163193323Sed PHINode *PN = cast<PHINode>(AfterPHIs); 164193323Sed // Create a new PHI node in the new region, which has an incoming value 165193323Sed // from OldPred of PN. 166193323Sed PHINode *NewPN = PHINode::Create(PN->getType(), PN->getName()+".ce", 167193323Sed NewBB->begin()); 168193323Sed NewPN->addIncoming(PN, OldPred); 169193323Sed 170193323Sed // Loop over all of the incoming value in PN, moving them to NewPN if they 171193323Sed // are from the extracted region. 172193323Sed for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) { 173193323Sed if (BlocksToExtract.count(PN->getIncomingBlock(i))) { 174193323Sed NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i)); 175193323Sed PN->removeIncomingValue(i); 176193323Sed --i; 177193323Sed } 178193323Sed } 179193323Sed } 180193323Sed } 181193323Sed} 182193323Sed 183193323Sedvoid CodeExtractor::splitReturnBlocks() { 184202375Srdivacky for (SetVector<BasicBlock*>::iterator I = BlocksToExtract.begin(), 185193323Sed E = BlocksToExtract.end(); I != E; ++I) 186198090Srdivacky if (ReturnInst *RI = dyn_cast<ReturnInst>((*I)->getTerminator())) { 187198090Srdivacky BasicBlock *New = (*I)->splitBasicBlock(RI, (*I)->getName()+".ret"); 188198090Srdivacky if (DT) { 189218893Sdim // Old dominates New. New node dominates all other nodes dominated 190218893Sdim // by Old. 191198090Srdivacky DomTreeNode *OldNode = DT->getNode(*I); 192198090Srdivacky SmallVector<DomTreeNode*, 8> Children; 193198090Srdivacky for (DomTreeNode::iterator DI = OldNode->begin(), DE = OldNode->end(); 194198090Srdivacky DI != DE; ++DI) 195198090Srdivacky Children.push_back(*DI); 196198090Srdivacky 197198090Srdivacky DomTreeNode *NewNode = DT->addNewBlock(New, *I); 198198090Srdivacky 199198090Srdivacky for (SmallVector<DomTreeNode*, 8>::iterator I = Children.begin(), 200198090Srdivacky E = Children.end(); I != E; ++I) 201198090Srdivacky DT->changeImmediateDominator(*I, NewNode); 202198090Srdivacky } 203198090Srdivacky } 204193323Sed} 205193323Sed 206193323Sed// findInputsOutputs - Find inputs to, outputs from the code region. 207193323Sed// 208193323Sedvoid CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs) { 209193323Sed std::set<BasicBlock*> ExitBlocks; 210202375Srdivacky for (SetVector<BasicBlock*>::const_iterator ci = BlocksToExtract.begin(), 211193323Sed ce = BlocksToExtract.end(); ci != ce; ++ci) { 212193323Sed BasicBlock *BB = *ci; 213193323Sed 214193323Sed for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { 215193323Sed // If a used value is defined outside the region, it's an input. If an 216193323Sed // instruction is used outside the region, it's an output. 217193323Sed for (User::op_iterator O = I->op_begin(), E = I->op_end(); O != E; ++O) 218193323Sed if (definedInCaller(*O)) 219202375Srdivacky inputs.insert(*O); 220193323Sed 221193323Sed // Consider uses of this instruction (outputs). 222193323Sed for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); 223193323Sed UI != E; ++UI) 224193323Sed if (!definedInRegion(*UI)) { 225202375Srdivacky outputs.insert(I); 226193323Sed break; 227193323Sed } 228193323Sed } // for: insts 229193323Sed 230193323Sed // Keep track of the exit blocks from the region. 231193323Sed TerminatorInst *TI = BB->getTerminator(); 232193323Sed for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) 233193323Sed if (!BlocksToExtract.count(TI->getSuccessor(i))) 234193323Sed ExitBlocks.insert(TI->getSuccessor(i)); 235193323Sed } // for: basic blocks 236193323Sed 237193323Sed NumExitBlocks = ExitBlocks.size(); 238193323Sed} 239193323Sed 240193323Sed/// constructFunction - make a function based on inputs and outputs, as follows: 241193323Sed/// f(in0, ..., inN, out0, ..., outN) 242193323Sed/// 243193323SedFunction *CodeExtractor::constructFunction(const Values &inputs, 244193323Sed const Values &outputs, 245193323Sed BasicBlock *header, 246193323Sed BasicBlock *newRootNode, 247193323Sed BasicBlock *newHeader, 248193323Sed Function *oldFunction, 249193323Sed Module *M) { 250202375Srdivacky DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); 251202375Srdivacky DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); 252193323Sed 253193323Sed // This function returns unsigned, outputs will go back by reference. 254193323Sed switch (NumExitBlocks) { 255193323Sed case 0: 256198090Srdivacky case 1: RetTy = Type::getVoidTy(header->getContext()); break; 257198090Srdivacky case 2: RetTy = Type::getInt1Ty(header->getContext()); break; 258198090Srdivacky default: RetTy = Type::getInt16Ty(header->getContext()); break; 259193323Sed } 260193323Sed 261193323Sed std::vector<const Type*> paramTy; 262193323Sed 263193323Sed // Add the types of the input values to the function's argument list 264193323Sed for (Values::const_iterator i = inputs.begin(), 265193323Sed e = inputs.end(); i != e; ++i) { 266193323Sed const Value *value = *i; 267202375Srdivacky DEBUG(dbgs() << "value used in func: " << *value << "\n"); 268193323Sed paramTy.push_back(value->getType()); 269193323Sed } 270193323Sed 271193323Sed // Add the types of the output values to the function's argument list. 272193323Sed for (Values::const_iterator I = outputs.begin(), E = outputs.end(); 273193323Sed I != E; ++I) { 274202375Srdivacky DEBUG(dbgs() << "instr used in func: " << **I << "\n"); 275193323Sed if (AggregateArgs) 276193323Sed paramTy.push_back((*I)->getType()); 277193323Sed else 278193323Sed paramTy.push_back(PointerType::getUnqual((*I)->getType())); 279193323Sed } 280193323Sed 281202375Srdivacky DEBUG(dbgs() << "Function type: " << *RetTy << " f("); 282193323Sed for (std::vector<const Type*>::iterator i = paramTy.begin(), 283193323Sed e = paramTy.end(); i != e; ++i) 284202375Srdivacky DEBUG(dbgs() << **i << ", "); 285202375Srdivacky DEBUG(dbgs() << ")\n"); 286193323Sed 287193323Sed if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { 288198090Srdivacky PointerType *StructPtr = 289198090Srdivacky PointerType::getUnqual(StructType::get(M->getContext(), paramTy)); 290193323Sed paramTy.clear(); 291193323Sed paramTy.push_back(StructPtr); 292193323Sed } 293198090Srdivacky const FunctionType *funcType = 294198090Srdivacky FunctionType::get(RetTy, paramTy, false); 295193323Sed 296193323Sed // Create the new function 297193323Sed Function *newFunction = Function::Create(funcType, 298193323Sed GlobalValue::InternalLinkage, 299193323Sed oldFunction->getName() + "_" + 300193323Sed header->getName(), M); 301193323Sed // If the old function is no-throw, so is the new one. 302193323Sed if (oldFunction->doesNotThrow()) 303193323Sed newFunction->setDoesNotThrow(true); 304193323Sed 305193323Sed newFunction->getBasicBlockList().push_back(newRootNode); 306193323Sed 307193323Sed // Create an iterator to name all of the arguments we inserted. 308193323Sed Function::arg_iterator AI = newFunction->arg_begin(); 309193323Sed 310193323Sed // Rewrite all users of the inputs in the extracted region to use the 311193323Sed // arguments (or appropriate addressing into struct) instead. 312193323Sed for (unsigned i = 0, e = inputs.size(); i != e; ++i) { 313193323Sed Value *RewriteVal; 314193323Sed if (AggregateArgs) { 315193323Sed Value *Idx[2]; 316198090Srdivacky Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); 317198090Srdivacky Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i); 318193323Sed TerminatorInst *TI = newFunction->begin()->getTerminator(); 319198090Srdivacky GetElementPtrInst *GEP = 320198090Srdivacky GetElementPtrInst::Create(AI, Idx, Idx+2, 321198090Srdivacky "gep_" + inputs[i]->getName(), TI); 322198090Srdivacky RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI); 323193323Sed } else 324193323Sed RewriteVal = AI++; 325193323Sed 326193323Sed std::vector<User*> Users(inputs[i]->use_begin(), inputs[i]->use_end()); 327193323Sed for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end(); 328193323Sed use != useE; ++use) 329193323Sed if (Instruction* inst = dyn_cast<Instruction>(*use)) 330193323Sed if (BlocksToExtract.count(inst->getParent())) 331193323Sed inst->replaceUsesOfWith(inputs[i], RewriteVal); 332193323Sed } 333193323Sed 334193323Sed // Set names for input and output arguments. 335193323Sed if (!AggregateArgs) { 336193323Sed AI = newFunction->arg_begin(); 337193323Sed for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI) 338193323Sed AI->setName(inputs[i]->getName()); 339193323Sed for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI) 340193323Sed AI->setName(outputs[i]->getName()+".out"); 341193323Sed } 342193323Sed 343193323Sed // Rewrite branches to basic blocks outside of the loop to new dummy blocks 344193323Sed // within the new function. This must be done before we lose track of which 345193323Sed // blocks were originally in the code region. 346193323Sed std::vector<User*> Users(header->use_begin(), header->use_end()); 347193323Sed for (unsigned i = 0, e = Users.size(); i != e; ++i) 348193323Sed // The BasicBlock which contains the branch is not in the region 349193323Sed // modify the branch target to a new block 350193323Sed if (TerminatorInst *TI = dyn_cast<TerminatorInst>(Users[i])) 351193323Sed if (!BlocksToExtract.count(TI->getParent()) && 352193323Sed TI->getParent()->getParent() == oldFunction) 353193323Sed TI->replaceUsesOfWith(header, newHeader); 354193323Sed 355193323Sed return newFunction; 356193323Sed} 357193323Sed 358198090Srdivacky/// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI 359198090Srdivacky/// that uses the value within the basic block, and return the predecessor 360198090Srdivacky/// block associated with that use, or return 0 if none is found. 361198090Srdivackystatic BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) { 362198090Srdivacky for (Value::use_iterator UI = Used->use_begin(), 363198090Srdivacky UE = Used->use_end(); UI != UE; ++UI) { 364198090Srdivacky PHINode *P = dyn_cast<PHINode>(*UI); 365198090Srdivacky if (P && P->getParent() == BB) 366198090Srdivacky return P->getIncomingBlock(UI); 367198090Srdivacky } 368198090Srdivacky 369198090Srdivacky return 0; 370198090Srdivacky} 371198090Srdivacky 372193323Sed/// emitCallAndSwitchStatement - This method sets up the caller side by adding 373193323Sed/// the call instruction, splitting any PHI nodes in the header block as 374193323Sed/// necessary. 375193323Sedvoid CodeExtractor:: 376193323SedemitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, 377193323Sed Values &inputs, Values &outputs) { 378193323Sed // Emit a call to the new function, passing in: *pointer to struct (if 379193323Sed // aggregating parameters), or plan inputs and allocated memory for outputs 380198090Srdivacky std::vector<Value*> params, StructValues, ReloadOutputs, Reloads; 381198090Srdivacky 382198090Srdivacky LLVMContext &Context = newFunction->getContext(); 383193323Sed 384193323Sed // Add inputs as params, or to be filled into the struct 385193323Sed for (Values::iterator i = inputs.begin(), e = inputs.end(); i != e; ++i) 386193323Sed if (AggregateArgs) 387193323Sed StructValues.push_back(*i); 388193323Sed else 389193323Sed params.push_back(*i); 390193323Sed 391193323Sed // Create allocas for the outputs 392193323Sed for (Values::iterator i = outputs.begin(), e = outputs.end(); i != e; ++i) { 393193323Sed if (AggregateArgs) { 394193323Sed StructValues.push_back(*i); 395193323Sed } else { 396193323Sed AllocaInst *alloca = 397193323Sed new AllocaInst((*i)->getType(), 0, (*i)->getName()+".loc", 398193323Sed codeReplacer->getParent()->begin()->begin()); 399193323Sed ReloadOutputs.push_back(alloca); 400193323Sed params.push_back(alloca); 401193323Sed } 402193323Sed } 403193323Sed 404193323Sed AllocaInst *Struct = 0; 405193323Sed if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { 406193323Sed std::vector<const Type*> ArgTypes; 407193323Sed for (Values::iterator v = StructValues.begin(), 408193323Sed ve = StructValues.end(); v != ve; ++v) 409193323Sed ArgTypes.push_back((*v)->getType()); 410193323Sed 411193323Sed // Allocate a struct at the beginning of this function 412198090Srdivacky Type *StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); 413193323Sed Struct = 414193323Sed new AllocaInst(StructArgTy, 0, "structArg", 415193323Sed codeReplacer->getParent()->begin()->begin()); 416193323Sed params.push_back(Struct); 417193323Sed 418193323Sed for (unsigned i = 0, e = inputs.size(); i != e; ++i) { 419193323Sed Value *Idx[2]; 420198090Srdivacky Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 421198090Srdivacky Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i); 422193323Sed GetElementPtrInst *GEP = 423193323Sed GetElementPtrInst::Create(Struct, Idx, Idx + 2, 424193323Sed "gep_" + StructValues[i]->getName()); 425193323Sed codeReplacer->getInstList().push_back(GEP); 426193323Sed StoreInst *SI = new StoreInst(StructValues[i], GEP); 427193323Sed codeReplacer->getInstList().push_back(SI); 428193323Sed } 429193323Sed } 430193323Sed 431193323Sed // Emit the call to the function 432193323Sed CallInst *call = CallInst::Create(newFunction, params.begin(), params.end(), 433193323Sed NumExitBlocks > 1 ? "targetBlock" : ""); 434193323Sed codeReplacer->getInstList().push_back(call); 435193323Sed 436193323Sed Function::arg_iterator OutputArgBegin = newFunction->arg_begin(); 437193323Sed unsigned FirstOut = inputs.size(); 438193323Sed if (!AggregateArgs) 439193323Sed std::advance(OutputArgBegin, inputs.size()); 440193323Sed 441193323Sed // Reload the outputs passed in by reference 442193323Sed for (unsigned i = 0, e = outputs.size(); i != e; ++i) { 443193323Sed Value *Output = 0; 444193323Sed if (AggregateArgs) { 445193323Sed Value *Idx[2]; 446198090Srdivacky Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 447198090Srdivacky Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); 448193323Sed GetElementPtrInst *GEP 449193323Sed = GetElementPtrInst::Create(Struct, Idx, Idx + 2, 450193323Sed "gep_reload_" + outputs[i]->getName()); 451193323Sed codeReplacer->getInstList().push_back(GEP); 452193323Sed Output = GEP; 453193323Sed } else { 454193323Sed Output = ReloadOutputs[i]; 455193323Sed } 456193323Sed LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload"); 457198090Srdivacky Reloads.push_back(load); 458193323Sed codeReplacer->getInstList().push_back(load); 459193323Sed std::vector<User*> Users(outputs[i]->use_begin(), outputs[i]->use_end()); 460193323Sed for (unsigned u = 0, e = Users.size(); u != e; ++u) { 461193323Sed Instruction *inst = cast<Instruction>(Users[u]); 462193323Sed if (!BlocksToExtract.count(inst->getParent())) 463193323Sed inst->replaceUsesOfWith(outputs[i], load); 464193323Sed } 465193323Sed } 466193323Sed 467193323Sed // Now we can emit a switch statement using the call as a value. 468193323Sed SwitchInst *TheSwitch = 469198090Srdivacky SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)), 470193323Sed codeReplacer, 0, codeReplacer); 471193323Sed 472193323Sed // Since there may be multiple exits from the original region, make the new 473193323Sed // function return an unsigned, switch on that number. This loop iterates 474193323Sed // over all of the blocks in the extracted region, updating any terminator 475193323Sed // instructions in the to-be-extracted region that branch to blocks that are 476193323Sed // not in the region to be extracted. 477193323Sed std::map<BasicBlock*, BasicBlock*> ExitBlockMap; 478193323Sed 479193323Sed unsigned switchVal = 0; 480202375Srdivacky for (SetVector<BasicBlock*>::const_iterator i = BlocksToExtract.begin(), 481193323Sed e = BlocksToExtract.end(); i != e; ++i) { 482193323Sed TerminatorInst *TI = (*i)->getTerminator(); 483193323Sed for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) 484193323Sed if (!BlocksToExtract.count(TI->getSuccessor(i))) { 485193323Sed BasicBlock *OldTarget = TI->getSuccessor(i); 486193323Sed // add a new basic block which returns the appropriate value 487193323Sed BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; 488193323Sed if (!NewTarget) { 489193323Sed // If we don't already have an exit stub for this non-extracted 490193323Sed // destination, create one now! 491198090Srdivacky NewTarget = BasicBlock::Create(Context, 492198090Srdivacky OldTarget->getName() + ".exitStub", 493193323Sed newFunction); 494193323Sed unsigned SuccNum = switchVal++; 495193323Sed 496193323Sed Value *brVal = 0; 497193323Sed switch (NumExitBlocks) { 498193323Sed case 0: 499193323Sed case 1: break; // No value needed. 500193323Sed case 2: // Conditional branch, return a bool 501198090Srdivacky brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); 502193323Sed break; 503193323Sed default: 504198090Srdivacky brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); 505193323Sed break; 506193323Sed } 507193323Sed 508198090Srdivacky ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget); 509193323Sed 510193323Sed // Update the switch instruction. 511198090Srdivacky TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), 512198090Srdivacky SuccNum), 513193323Sed OldTarget); 514193323Sed 515193323Sed // Restore values just before we exit 516193323Sed Function::arg_iterator OAI = OutputArgBegin; 517193323Sed for (unsigned out = 0, e = outputs.size(); out != e; ++out) { 518193323Sed // For an invoke, the normal destination is the only one that is 519193323Sed // dominated by the result of the invocation 520193323Sed BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent(); 521193323Sed 522193323Sed bool DominatesDef = true; 523193323Sed 524193323Sed if (InvokeInst *Invoke = dyn_cast<InvokeInst>(outputs[out])) { 525193323Sed DefBlock = Invoke->getNormalDest(); 526193323Sed 527193323Sed // Make sure we are looking at the original successor block, not 528193323Sed // at a newly inserted exit block, which won't be in the dominator 529193323Sed // info. 530193323Sed for (std::map<BasicBlock*, BasicBlock*>::iterator I = 531193323Sed ExitBlockMap.begin(), E = ExitBlockMap.end(); I != E; ++I) 532193323Sed if (DefBlock == I->second) { 533193323Sed DefBlock = I->first; 534193323Sed break; 535193323Sed } 536193323Sed 537193323Sed // In the extract block case, if the block we are extracting ends 538193323Sed // with an invoke instruction, make sure that we don't emit a 539193323Sed // store of the invoke value for the unwind block. 540193323Sed if (!DT && DefBlock != OldTarget) 541193323Sed DominatesDef = false; 542193323Sed } 543193323Sed 544198090Srdivacky if (DT) { 545193323Sed DominatesDef = DT->dominates(DefBlock, OldTarget); 546198090Srdivacky 547198090Srdivacky // If the output value is used by a phi in the target block, 548198090Srdivacky // then we need to test for dominance of the phi's predecessor 549198090Srdivacky // instead. Unfortunately, this a little complicated since we 550198090Srdivacky // have already rewritten uses of the value to uses of the reload. 551198090Srdivacky BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out], 552198090Srdivacky OldTarget); 553198090Srdivacky if (pred && DT && DT->dominates(DefBlock, pred)) 554198090Srdivacky DominatesDef = true; 555198090Srdivacky } 556193323Sed 557193323Sed if (DominatesDef) { 558193323Sed if (AggregateArgs) { 559193323Sed Value *Idx[2]; 560198090Srdivacky Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 561198090Srdivacky Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), 562198090Srdivacky FirstOut+out); 563193323Sed GetElementPtrInst *GEP = 564193323Sed GetElementPtrInst::Create(OAI, Idx, Idx + 2, 565193323Sed "gep_" + outputs[out]->getName(), 566193323Sed NTRet); 567193323Sed new StoreInst(outputs[out], GEP, NTRet); 568193323Sed } else { 569193323Sed new StoreInst(outputs[out], OAI, NTRet); 570193323Sed } 571193323Sed } 572193323Sed // Advance output iterator even if we don't emit a store 573193323Sed if (!AggregateArgs) ++OAI; 574193323Sed } 575193323Sed } 576193323Sed 577193323Sed // rewrite the original branch instruction with this new target 578193323Sed TI->setSuccessor(i, NewTarget); 579193323Sed } 580193323Sed } 581193323Sed 582193323Sed // Now that we've done the deed, simplify the switch instruction. 583193323Sed const Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); 584193323Sed switch (NumExitBlocks) { 585193323Sed case 0: 586193323Sed // There are no successors (the block containing the switch itself), which 587193323Sed // means that previously this was the last part of the function, and hence 588193323Sed // this should be rewritten as a `ret' 589193323Sed 590193323Sed // Check if the function should return a value 591202375Srdivacky if (OldFnRetTy->isVoidTy()) { 592198090Srdivacky ReturnInst::Create(Context, 0, TheSwitch); // Return void 593193323Sed } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { 594193323Sed // return what we have 595198090Srdivacky ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); 596193323Sed } else { 597193323Sed // Otherwise we must have code extracted an unwind or something, just 598193323Sed // return whatever we want. 599198090Srdivacky ReturnInst::Create(Context, 600198090Srdivacky Constant::getNullValue(OldFnRetTy), TheSwitch); 601193323Sed } 602193323Sed 603193323Sed TheSwitch->eraseFromParent(); 604193323Sed break; 605193323Sed case 1: 606193323Sed // Only a single destination, change the switch into an unconditional 607193323Sed // branch. 608193323Sed BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch); 609193323Sed TheSwitch->eraseFromParent(); 610193323Sed break; 611193323Sed case 2: 612193323Sed BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), 613193323Sed call, TheSwitch); 614193323Sed TheSwitch->eraseFromParent(); 615193323Sed break; 616193323Sed default: 617193323Sed // Otherwise, make the default destination of the switch instruction be one 618193323Sed // of the other successors. 619193323Sed TheSwitch->setOperand(0, call); 620193323Sed TheSwitch->setSuccessor(0, TheSwitch->getSuccessor(NumExitBlocks)); 621193323Sed TheSwitch->removeCase(NumExitBlocks); // Remove redundant case 622193323Sed break; 623193323Sed } 624193323Sed} 625193323Sed 626193323Sedvoid CodeExtractor::moveCodeToFunction(Function *newFunction) { 627193323Sed Function *oldFunc = (*BlocksToExtract.begin())->getParent(); 628193323Sed Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); 629193323Sed Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); 630193323Sed 631202375Srdivacky for (SetVector<BasicBlock*>::const_iterator i = BlocksToExtract.begin(), 632193323Sed e = BlocksToExtract.end(); i != e; ++i) { 633193323Sed // Delete the basic block from the old function, and the list of blocks 634193323Sed oldBlocks.remove(*i); 635193323Sed 636193323Sed // Insert this basic block into the new function 637193323Sed newBlocks.push_back(*i); 638193323Sed } 639193323Sed} 640193323Sed 641193323Sed/// ExtractRegion - Removes a loop from a function, replaces it with a call to 642193323Sed/// new function. Returns pointer to the new function. 643193323Sed/// 644193323Sed/// algorithm: 645193323Sed/// 646193323Sed/// find inputs and outputs for the region 647193323Sed/// 648193323Sed/// for inputs: add to function as args, map input instr* to arg# 649193323Sed/// for outputs: add allocas for scalars, 650193323Sed/// add to func as args, map output instr* to arg# 651193323Sed/// 652193323Sed/// rewrite func to use argument #s instead of instr* 653193323Sed/// 654193323Sed/// for each scalar output in the function: at every exit, store intermediate 655193323Sed/// computed result back into memory. 656193323Sed/// 657193323SedFunction *CodeExtractor:: 658193323SedExtractCodeRegion(const std::vector<BasicBlock*> &code) { 659193323Sed if (!isEligible(code)) 660193323Sed return 0; 661193323Sed 662193323Sed // 1) Find inputs, outputs 663193323Sed // 2) Construct new function 664193323Sed // * Add allocas for defs, pass as args by reference 665193323Sed // * Pass in uses as args 666193323Sed // 3) Move code region, add call instr to func 667193323Sed // 668193323Sed BlocksToExtract.insert(code.begin(), code.end()); 669193323Sed 670193323Sed Values inputs, outputs; 671193323Sed 672193323Sed // Assumption: this is a single-entry code region, and the header is the first 673193323Sed // block in the region. 674193323Sed BasicBlock *header = code[0]; 675193323Sed 676193323Sed for (unsigned i = 1, e = code.size(); i != e; ++i) 677193323Sed for (pred_iterator PI = pred_begin(code[i]), E = pred_end(code[i]); 678193323Sed PI != E; ++PI) 679193323Sed assert(BlocksToExtract.count(*PI) && 680193323Sed "No blocks in this region may have entries from outside the region" 681193323Sed " except for the first block!"); 682193323Sed 683193323Sed // If we have to split PHI nodes or the entry block, do so now. 684193323Sed severSplitPHINodes(header); 685193323Sed 686193323Sed // If we have any return instructions in the region, split those blocks so 687193323Sed // that the return is not in the region. 688193323Sed splitReturnBlocks(); 689193323Sed 690193323Sed Function *oldFunction = header->getParent(); 691193323Sed 692193323Sed // This takes place of the original loop 693198090Srdivacky BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), 694198090Srdivacky "codeRepl", oldFunction, 695193323Sed header); 696193323Sed 697193323Sed // The new function needs a root node because other nodes can branch to the 698193323Sed // head of the region, but the entry node of a function cannot have preds. 699198090Srdivacky BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), 700198090Srdivacky "newFuncRoot"); 701193323Sed newFuncRoot->getInstList().push_back(BranchInst::Create(header)); 702193323Sed 703193323Sed // Find inputs to, outputs from the code region. 704193323Sed findInputsOutputs(inputs, outputs); 705193323Sed 706193323Sed // Construct new function based on inputs/outputs & add allocas for all defs. 707193323Sed Function *newFunction = constructFunction(inputs, outputs, header, 708193323Sed newFuncRoot, 709193323Sed codeReplacer, oldFunction, 710193323Sed oldFunction->getParent()); 711193323Sed 712193323Sed emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); 713193323Sed 714193323Sed moveCodeToFunction(newFunction); 715193323Sed 716193323Sed // Loop over all of the PHI nodes in the header block, and change any 717193323Sed // references to the old incoming edge to be the new incoming edge. 718193323Sed for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) { 719193323Sed PHINode *PN = cast<PHINode>(I); 720193323Sed for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 721193323Sed if (!BlocksToExtract.count(PN->getIncomingBlock(i))) 722193323Sed PN->setIncomingBlock(i, newFuncRoot); 723193323Sed } 724193323Sed 725193323Sed // Look at all successors of the codeReplacer block. If any of these blocks 726193323Sed // had PHI nodes in them, we need to update the "from" block to be the code 727193323Sed // replacer, not the original block in the extracted region. 728193323Sed std::vector<BasicBlock*> Succs(succ_begin(codeReplacer), 729193323Sed succ_end(codeReplacer)); 730193323Sed for (unsigned i = 0, e = Succs.size(); i != e; ++i) 731193323Sed for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) { 732193323Sed PHINode *PN = cast<PHINode>(I); 733193323Sed std::set<BasicBlock*> ProcessedPreds; 734193323Sed for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 735193323Sed if (BlocksToExtract.count(PN->getIncomingBlock(i))) { 736193323Sed if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second) 737193323Sed PN->setIncomingBlock(i, codeReplacer); 738193323Sed else { 739193323Sed // There were multiple entries in the PHI for this block, now there 740193323Sed // is only one, so remove the duplicated entries. 741193323Sed PN->removeIncomingValue(i, false); 742193323Sed --i; --e; 743193323Sed } 744193323Sed } 745193323Sed } 746193323Sed 747193323Sed //cerr << "NEW FUNCTION: " << *newFunction; 748193323Sed // verifyFunction(*newFunction); 749193323Sed 750193323Sed // cerr << "OLD FUNCTION: " << *oldFunction; 751193323Sed // verifyFunction(*oldFunction); 752193323Sed 753198090Srdivacky DEBUG(if (verifyFunction(*newFunction)) 754207618Srdivacky report_fatal_error("verifyFunction failed!")); 755193323Sed return newFunction; 756193323Sed} 757193323Sed 758193323Sedbool CodeExtractor::isEligible(const std::vector<BasicBlock*> &code) { 759193323Sed // Deny code region if it contains allocas or vastarts. 760193323Sed for (std::vector<BasicBlock*>::const_iterator BB = code.begin(), e=code.end(); 761193323Sed BB != e; ++BB) 762193323Sed for (BasicBlock::const_iterator I = (*BB)->begin(), Ie = (*BB)->end(); 763193323Sed I != Ie; ++I) 764193323Sed if (isa<AllocaInst>(*I)) 765193323Sed return false; 766193323Sed else if (const CallInst *CI = dyn_cast<CallInst>(I)) 767193323Sed if (const Function *F = CI->getCalledFunction()) 768193323Sed if (F->getIntrinsicID() == Intrinsic::vastart) 769193323Sed return false; 770193323Sed return true; 771193323Sed} 772193323Sed 773193323Sed 774193323Sed/// ExtractCodeRegion - slurp a sequence of basic blocks into a brand new 775193323Sed/// function 776193323Sed/// 777193323SedFunction* llvm::ExtractCodeRegion(DominatorTree &DT, 778193323Sed const std::vector<BasicBlock*> &code, 779193323Sed bool AggregateArgs) { 780193323Sed return CodeExtractor(&DT, AggregateArgs).ExtractCodeRegion(code); 781193323Sed} 782193323Sed 783193323Sed/// ExtractBasicBlock - slurp a natural loop into a brand new function 784193323Sed/// 785193323SedFunction* llvm::ExtractLoop(DominatorTree &DT, Loop *L, bool AggregateArgs) { 786193323Sed return CodeExtractor(&DT, AggregateArgs).ExtractCodeRegion(L->getBlocks()); 787193323Sed} 788193323Sed 789193323Sed/// ExtractBasicBlock - slurp a basic block into a brand new function 790193323Sed/// 791193323SedFunction* llvm::ExtractBasicBlock(BasicBlock *BB, bool AggregateArgs) { 792193323Sed std::vector<BasicBlock*> Blocks; 793193323Sed Blocks.push_back(BB); 794193323Sed return CodeExtractor(0, AggregateArgs).ExtractCodeRegion(Blocks); 795193323Sed} 796