1194178Sed//===- PartialInlining.cpp - Inline parts of functions --------------------===// 2194178Sed// 3194178Sed// The LLVM Compiler Infrastructure 4194178Sed// 5194178Sed// This file is distributed under the University of Illinois Open Source 6194178Sed// License. See LICENSE.TXT for details. 7194178Sed// 8194178Sed//===----------------------------------------------------------------------===// 9194178Sed// 10194178Sed// This pass performs partial inlining, typically by inlining an if statement 11194178Sed// that surrounds the body of the function. 12194178Sed// 13194178Sed//===----------------------------------------------------------------------===// 14194178Sed 15194178Sed#define DEBUG_TYPE "partialinlining" 16194178Sed#include "llvm/Transforms/IPO.h" 17249423Sdim#include "llvm/ADT/Statistic.h" 18249423Sdim#include "llvm/Analysis/Dominators.h" 19249423Sdim#include "llvm/IR/Instructions.h" 20249423Sdim#include "llvm/IR/Module.h" 21194178Sed#include "llvm/Pass.h" 22249423Sdim#include "llvm/Support/CFG.h" 23194178Sed#include "llvm/Transforms/Utils/Cloning.h" 24239462Sdim#include "llvm/Transforms/Utils/CodeExtractor.h" 25194178Sedusing namespace llvm; 26194178Sed 27194612SedSTATISTIC(NumPartialInlined, "Number of functions partially inlined"); 28194612Sed 29194178Sednamespace { 30198892Srdivacky struct PartialInliner : public ModulePass { 31194178Sed virtual void getAnalysisUsage(AnalysisUsage &AU) const { } 32194178Sed static char ID; // Pass identification, replacement for typeid 33218893Sdim PartialInliner() : ModulePass(ID) { 34218893Sdim initializePartialInlinerPass(*PassRegistry::getPassRegistry()); 35218893Sdim } 36194178Sed 37194178Sed bool runOnModule(Module& M); 38194178Sed 39194178Sed private: 40194178Sed Function* unswitchFunction(Function* F); 41194178Sed }; 42194178Sed} 43194178Sed 44194178Sedchar PartialInliner::ID = 0; 45212904SdimINITIALIZE_PASS(PartialInliner, "partial-inliner", 46218893Sdim "Partial Inliner", false, false) 47194178Sed 48194178SedModulePass* llvm::createPartialInliningPass() { return new PartialInliner(); } 49194178Sed 50194178SedFunction* PartialInliner::unswitchFunction(Function* F) { 51194178Sed // First, verify that this function is an unswitching candidate... 52194178Sed BasicBlock* entryBlock = F->begin(); 53198090Srdivacky BranchInst *BR = dyn_cast<BranchInst>(entryBlock->getTerminator()); 54198090Srdivacky if (!BR || BR->isUnconditional()) 55194178Sed return 0; 56194178Sed 57194178Sed BasicBlock* returnBlock = 0; 58194178Sed BasicBlock* nonReturnBlock = 0; 59194178Sed unsigned returnCount = 0; 60194178Sed for (succ_iterator SI = succ_begin(entryBlock), SE = succ_end(entryBlock); 61194178Sed SI != SE; ++SI) 62194178Sed if (isa<ReturnInst>((*SI)->getTerminator())) { 63194178Sed returnBlock = *SI; 64194178Sed returnCount++; 65194178Sed } else 66194178Sed nonReturnBlock = *SI; 67194178Sed 68194178Sed if (returnCount != 1) 69194178Sed return 0; 70194178Sed 71194178Sed // Clone the function, so that we can hack away on it. 72218893Sdim ValueToValueMapTy VMap; 73212904Sdim Function* duplicateFunction = CloneFunction(F, VMap, 74212904Sdim /*ModuleLevelChanges=*/false); 75194178Sed duplicateFunction->setLinkage(GlobalValue::InternalLinkage); 76194178Sed F->getParent()->getFunctionList().push_back(duplicateFunction); 77210299Sed BasicBlock* newEntryBlock = cast<BasicBlock>(VMap[entryBlock]); 78210299Sed BasicBlock* newReturnBlock = cast<BasicBlock>(VMap[returnBlock]); 79210299Sed BasicBlock* newNonReturnBlock = cast<BasicBlock>(VMap[nonReturnBlock]); 80194178Sed 81194178Sed // Go ahead and update all uses to the duplicate, so that we can just 82194178Sed // use the inliner functionality when we're done hacking. 83194178Sed F->replaceAllUsesWith(duplicateFunction); 84194178Sed 85194178Sed // Special hackery is needed with PHI nodes that have inputs from more than 86194178Sed // one extracted block. For simplicity, just split the PHIs into a two-level 87194178Sed // sequence of PHIs, some of which will go in the extracted region, and some 88194178Sed // of which will go outside. 89194178Sed BasicBlock* preReturn = newReturnBlock; 90194178Sed newReturnBlock = newReturnBlock->splitBasicBlock( 91194178Sed newReturnBlock->getFirstNonPHI()); 92194178Sed BasicBlock::iterator I = preReturn->begin(); 93194178Sed BasicBlock::iterator Ins = newReturnBlock->begin(); 94194178Sed while (I != preReturn->end()) { 95194178Sed PHINode* OldPhi = dyn_cast<PHINode>(I); 96194178Sed if (!OldPhi) break; 97194178Sed 98221345Sdim PHINode* retPhi = PHINode::Create(OldPhi->getType(), 2, "", Ins); 99194178Sed OldPhi->replaceAllUsesWith(retPhi); 100194178Sed Ins = newReturnBlock->getFirstNonPHI(); 101194178Sed 102194178Sed retPhi->addIncoming(I, preReturn); 103194178Sed retPhi->addIncoming(OldPhi->getIncomingValueForBlock(newEntryBlock), 104194178Sed newEntryBlock); 105194178Sed OldPhi->removeIncomingValue(newEntryBlock); 106194178Sed 107194178Sed ++I; 108194178Sed } 109194178Sed newEntryBlock->getTerminator()->replaceUsesOfWith(preReturn, newReturnBlock); 110194178Sed 111194178Sed // Gather up the blocks that we're going to extract. 112194178Sed std::vector<BasicBlock*> toExtract; 113194178Sed toExtract.push_back(newNonReturnBlock); 114194178Sed for (Function::iterator FI = duplicateFunction->begin(), 115194178Sed FE = duplicateFunction->end(); FI != FE; ++FI) 116194178Sed if (&*FI != newEntryBlock && &*FI != newReturnBlock && 117194178Sed &*FI != newNonReturnBlock) 118194178Sed toExtract.push_back(FI); 119194178Sed 120194178Sed // The CodeExtractor needs a dominator tree. 121194178Sed DominatorTree DT; 122194178Sed DT.runOnFunction(*duplicateFunction); 123194178Sed 124203954Srdivacky // Extract the body of the if. 125239462Sdim Function* extractedFunction 126239462Sdim = CodeExtractor(toExtract, &DT).extractCodeRegion(); 127194178Sed 128207618Srdivacky InlineFunctionInfo IFI; 129207618Srdivacky 130194178Sed // Inline the top-level if test into all callers. 131194178Sed std::vector<User*> Users(duplicateFunction->use_begin(), 132194178Sed duplicateFunction->use_end()); 133194178Sed for (std::vector<User*>::iterator UI = Users.begin(), UE = Users.end(); 134194178Sed UI != UE; ++UI) 135207618Srdivacky if (CallInst *CI = dyn_cast<CallInst>(*UI)) 136207618Srdivacky InlineFunction(CI, IFI); 137207618Srdivacky else if (InvokeInst *II = dyn_cast<InvokeInst>(*UI)) 138207618Srdivacky InlineFunction(II, IFI); 139194178Sed 140194178Sed // Ditch the duplicate, since we're done with it, and rewrite all remaining 141194178Sed // users (function pointers, etc.) back to the original function. 142194178Sed duplicateFunction->replaceAllUsesWith(F); 143194178Sed duplicateFunction->eraseFromParent(); 144194178Sed 145194612Sed ++NumPartialInlined; 146194612Sed 147194178Sed return extractedFunction; 148194178Sed} 149194178Sed 150194178Sedbool PartialInliner::runOnModule(Module& M) { 151194178Sed std::vector<Function*> worklist; 152194178Sed worklist.reserve(M.size()); 153194178Sed for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI) 154194178Sed if (!FI->use_empty() && !FI->isDeclaration()) 155202375Srdivacky worklist.push_back(&*FI); 156194178Sed 157194178Sed bool changed = false; 158194178Sed while (!worklist.empty()) { 159194178Sed Function* currFunc = worklist.back(); 160194178Sed worklist.pop_back(); 161194178Sed 162194178Sed if (currFunc->use_empty()) continue; 163194178Sed 164194178Sed bool recursive = false; 165194178Sed for (Function::use_iterator UI = currFunc->use_begin(), 166194178Sed UE = currFunc->use_end(); UI != UE; ++UI) 167212904Sdim if (Instruction* I = dyn_cast<Instruction>(*UI)) 168194178Sed if (I->getParent()->getParent() == currFunc) { 169194178Sed recursive = true; 170194178Sed break; 171194178Sed } 172194178Sed if (recursive) continue; 173194178Sed 174194178Sed 175194178Sed if (Function* newFunc = unswitchFunction(currFunc)) { 176194178Sed worklist.push_back(newFunc); 177194178Sed changed = true; 178194178Sed } 179194178Sed 180194178Sed } 181194178Sed 182194178Sed return changed; 183195340Sed} 184