1194178Sed//===- PartialInlining.cpp - Inline parts of functions --------------------===//
2194178Sed//
3194178Sed//                     The LLVM Compiler Infrastructure
4194178Sed//
5194178Sed// This file is distributed under the University of Illinois Open Source
6194178Sed// License. See LICENSE.TXT for details.
7194178Sed//
8194178Sed//===----------------------------------------------------------------------===//
9194178Sed//
10194178Sed// This pass performs partial inlining, typically by inlining an if statement
11194178Sed// that surrounds the body of the function.
12194178Sed//
13194178Sed//===----------------------------------------------------------------------===//
14194178Sed
15194178Sed#define DEBUG_TYPE "partialinlining"
16194178Sed#include "llvm/Transforms/IPO.h"
17249423Sdim#include "llvm/ADT/Statistic.h"
18249423Sdim#include "llvm/Analysis/Dominators.h"
19249423Sdim#include "llvm/IR/Instructions.h"
20249423Sdim#include "llvm/IR/Module.h"
21194178Sed#include "llvm/Pass.h"
22249423Sdim#include "llvm/Support/CFG.h"
23194178Sed#include "llvm/Transforms/Utils/Cloning.h"
24239462Sdim#include "llvm/Transforms/Utils/CodeExtractor.h"
25194178Sedusing namespace llvm;
26194178Sed
27194612SedSTATISTIC(NumPartialInlined, "Number of functions partially inlined");
28194612Sed
29194178Sednamespace {
30198892Srdivacky  struct PartialInliner : public ModulePass {
31194178Sed    virtual void getAnalysisUsage(AnalysisUsage &AU) const { }
32194178Sed    static char ID; // Pass identification, replacement for typeid
33218893Sdim    PartialInliner() : ModulePass(ID) {
34218893Sdim      initializePartialInlinerPass(*PassRegistry::getPassRegistry());
35218893Sdim    }
36194178Sed
37194178Sed    bool runOnModule(Module& M);
38194178Sed
39194178Sed  private:
40194178Sed    Function* unswitchFunction(Function* F);
41194178Sed  };
42194178Sed}
43194178Sed
44194178Sedchar PartialInliner::ID = 0;
45212904SdimINITIALIZE_PASS(PartialInliner, "partial-inliner",
46218893Sdim                "Partial Inliner", false, false)
47194178Sed
48194178SedModulePass* llvm::createPartialInliningPass() { return new PartialInliner(); }
49194178Sed
50194178SedFunction* PartialInliner::unswitchFunction(Function* F) {
51194178Sed  // First, verify that this function is an unswitching candidate...
52194178Sed  BasicBlock* entryBlock = F->begin();
53198090Srdivacky  BranchInst *BR = dyn_cast<BranchInst>(entryBlock->getTerminator());
54198090Srdivacky  if (!BR || BR->isUnconditional())
55194178Sed    return 0;
56194178Sed
57194178Sed  BasicBlock* returnBlock = 0;
58194178Sed  BasicBlock* nonReturnBlock = 0;
59194178Sed  unsigned returnCount = 0;
60194178Sed  for (succ_iterator SI = succ_begin(entryBlock), SE = succ_end(entryBlock);
61194178Sed       SI != SE; ++SI)
62194178Sed    if (isa<ReturnInst>((*SI)->getTerminator())) {
63194178Sed      returnBlock = *SI;
64194178Sed      returnCount++;
65194178Sed    } else
66194178Sed      nonReturnBlock = *SI;
67194178Sed
68194178Sed  if (returnCount != 1)
69194178Sed    return 0;
70194178Sed
71194178Sed  // Clone the function, so that we can hack away on it.
72218893Sdim  ValueToValueMapTy VMap;
73212904Sdim  Function* duplicateFunction = CloneFunction(F, VMap,
74212904Sdim                                              /*ModuleLevelChanges=*/false);
75194178Sed  duplicateFunction->setLinkage(GlobalValue::InternalLinkage);
76194178Sed  F->getParent()->getFunctionList().push_back(duplicateFunction);
77210299Sed  BasicBlock* newEntryBlock = cast<BasicBlock>(VMap[entryBlock]);
78210299Sed  BasicBlock* newReturnBlock = cast<BasicBlock>(VMap[returnBlock]);
79210299Sed  BasicBlock* newNonReturnBlock = cast<BasicBlock>(VMap[nonReturnBlock]);
80194178Sed
81194178Sed  // Go ahead and update all uses to the duplicate, so that we can just
82194178Sed  // use the inliner functionality when we're done hacking.
83194178Sed  F->replaceAllUsesWith(duplicateFunction);
84194178Sed
85194178Sed  // Special hackery is needed with PHI nodes that have inputs from more than
86194178Sed  // one extracted block.  For simplicity, just split the PHIs into a two-level
87194178Sed  // sequence of PHIs, some of which will go in the extracted region, and some
88194178Sed  // of which will go outside.
89194178Sed  BasicBlock* preReturn = newReturnBlock;
90194178Sed  newReturnBlock = newReturnBlock->splitBasicBlock(
91194178Sed                                              newReturnBlock->getFirstNonPHI());
92194178Sed  BasicBlock::iterator I = preReturn->begin();
93194178Sed  BasicBlock::iterator Ins = newReturnBlock->begin();
94194178Sed  while (I != preReturn->end()) {
95194178Sed    PHINode* OldPhi = dyn_cast<PHINode>(I);
96194178Sed    if (!OldPhi) break;
97194178Sed
98221345Sdim    PHINode* retPhi = PHINode::Create(OldPhi->getType(), 2, "", Ins);
99194178Sed    OldPhi->replaceAllUsesWith(retPhi);
100194178Sed    Ins = newReturnBlock->getFirstNonPHI();
101194178Sed
102194178Sed    retPhi->addIncoming(I, preReturn);
103194178Sed    retPhi->addIncoming(OldPhi->getIncomingValueForBlock(newEntryBlock),
104194178Sed                        newEntryBlock);
105194178Sed    OldPhi->removeIncomingValue(newEntryBlock);
106194178Sed
107194178Sed    ++I;
108194178Sed  }
109194178Sed  newEntryBlock->getTerminator()->replaceUsesOfWith(preReturn, newReturnBlock);
110194178Sed
111194178Sed  // Gather up the blocks that we're going to extract.
112194178Sed  std::vector<BasicBlock*> toExtract;
113194178Sed  toExtract.push_back(newNonReturnBlock);
114194178Sed  for (Function::iterator FI = duplicateFunction->begin(),
115194178Sed       FE = duplicateFunction->end(); FI != FE; ++FI)
116194178Sed    if (&*FI != newEntryBlock && &*FI != newReturnBlock &&
117194178Sed        &*FI != newNonReturnBlock)
118194178Sed      toExtract.push_back(FI);
119194178Sed
120194178Sed  // The CodeExtractor needs a dominator tree.
121194178Sed  DominatorTree DT;
122194178Sed  DT.runOnFunction(*duplicateFunction);
123194178Sed
124203954Srdivacky  // Extract the body of the if.
125239462Sdim  Function* extractedFunction
126239462Sdim    = CodeExtractor(toExtract, &DT).extractCodeRegion();
127194178Sed
128207618Srdivacky  InlineFunctionInfo IFI;
129207618Srdivacky
130194178Sed  // Inline the top-level if test into all callers.
131194178Sed  std::vector<User*> Users(duplicateFunction->use_begin(),
132194178Sed                           duplicateFunction->use_end());
133194178Sed  for (std::vector<User*>::iterator UI = Users.begin(), UE = Users.end();
134194178Sed       UI != UE; ++UI)
135207618Srdivacky    if (CallInst *CI = dyn_cast<CallInst>(*UI))
136207618Srdivacky      InlineFunction(CI, IFI);
137207618Srdivacky    else if (InvokeInst *II = dyn_cast<InvokeInst>(*UI))
138207618Srdivacky      InlineFunction(II, IFI);
139194178Sed
140194178Sed  // Ditch the duplicate, since we're done with it, and rewrite all remaining
141194178Sed  // users (function pointers, etc.) back to the original function.
142194178Sed  duplicateFunction->replaceAllUsesWith(F);
143194178Sed  duplicateFunction->eraseFromParent();
144194178Sed
145194612Sed  ++NumPartialInlined;
146194612Sed
147194178Sed  return extractedFunction;
148194178Sed}
149194178Sed
150194178Sedbool PartialInliner::runOnModule(Module& M) {
151194178Sed  std::vector<Function*> worklist;
152194178Sed  worklist.reserve(M.size());
153194178Sed  for (Module::iterator FI = M.begin(), FE = M.end(); FI != FE; ++FI)
154194178Sed    if (!FI->use_empty() && !FI->isDeclaration())
155202375Srdivacky      worklist.push_back(&*FI);
156194178Sed
157194178Sed  bool changed = false;
158194178Sed  while (!worklist.empty()) {
159194178Sed    Function* currFunc = worklist.back();
160194178Sed    worklist.pop_back();
161194178Sed
162194178Sed    if (currFunc->use_empty()) continue;
163194178Sed
164194178Sed    bool recursive = false;
165194178Sed    for (Function::use_iterator UI = currFunc->use_begin(),
166194178Sed         UE = currFunc->use_end(); UI != UE; ++UI)
167212904Sdim      if (Instruction* I = dyn_cast<Instruction>(*UI))
168194178Sed        if (I->getParent()->getParent() == currFunc) {
169194178Sed          recursive = true;
170194178Sed          break;
171194178Sed        }
172194178Sed    if (recursive) continue;
173194178Sed
174194178Sed
175194178Sed    if (Function* newFunc = unswitchFunction(currFunc)) {
176194178Sed      worklist.push_back(newFunc);
177194178Sed      changed = true;
178194178Sed    }
179194178Sed
180194178Sed  }
181194178Sed
182194178Sed  return changed;
183195340Sed}
184