LowerSwitch.cpp revision 198892
//===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// The LowerSwitch transformation rewrites switch instructions with a sequence
// of branches, which allows targets to get away with not implementing the
// switch instruction until it is convenient.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
#include "llvm/Constants.h"
#include "llvm/Function.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Pass.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cassert>
#include <vector>
using namespace llvm;

namespace {
/// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch
/// instructions.  Note that this cannot be a BasicBlock pass because it
/// modifies the CFG!
34162210Simp class LowerSwitch : public FunctionPass { 35162210Simp public: 362827Sjkh static char ID; // Pass identification, replacement for typeid 372827Sjkh LowerSwitch() : FunctionPass(&ID) {} 38179184Sjb 39241711Sjhb virtual bool runOnFunction(Function &F); 40179184Sjb 412827Sjkh virtual void getAnalysisUsage(AnalysisUsage &AU) const { 42179184Sjb // This is a cluster of orthogonal Transforms 432827Sjkh AU.addPreserved<UnifyFunctionExitNodes>(); 442827Sjkh AU.addPreservedID(PromoteMemoryToRegisterID); 451638Srgrimes AU.addPreservedID(LowerInvokePassID); 462827Sjkh } 471638Srgrimes 4818529Sbde struct CaseRange { 4918529Sbde Constant* Low; 501638Srgrimes Constant* High; 5142450Sjdp BasicBlock* BB; 521638Srgrimes 53220755Sdim CaseRange() : Low(0), High(0), BB(0) { } 541638Srgrimes CaseRange(Constant* low, Constant* high, BasicBlock* bb) : 5596512Sru Low(low), High(high), BB(bb) { } 56211725Simp }; 5796512Sru 5896512Sru typedef std::vector<CaseRange> CaseVector; 5996512Sru typedef std::vector<CaseRange>::iterator CaseItr; 6096512Sru private: 6196512Sru void processSwitchInst(SwitchInst *SI); 6296512Sru 63126890Strhodes BasicBlock* switchConvert(CaseItr Begin, CaseItr End, Value* Val, 64126890Strhodes BasicBlock* OrigBlock, BasicBlock* Default); 65236114Sdes BasicBlock* newLeafBlock(CaseRange& Leaf, Value* Val, 66236114Sdes BasicBlock* OrigBlock, BasicBlock* Default); 67241711Sjhb unsigned Clusterify(CaseVector& Cases, SwitchInst *SI); 68236114Sdes }; 691638Srgrimes 70236114Sdes /// The comparison function for sorting the switch case values in the vector. 71241711Sjhb /// WARNING: Case ranges should be disjoint! 
721638Srgrimes struct CaseCmp { 7342450Sjdp bool operator () (const LowerSwitch::CaseRange& C1, 74236114Sdes const LowerSwitch::CaseRange& C2) { 75241711Sjhb 761844Swollman const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 77237814Sdim const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 78236114Sdes return CI1->getValue().slt(CI2->getValue()); 79236114Sdes } 8036673Sdt }; 81236114Sdes} 821844Swollman 8342450Sjdpchar LowerSwitch::ID = 0; 84236114Sdesstatic RegisterPass<LowerSwitch> 851844SwollmanX("lowerswitch", "Lower SwitchInst's to branches"); 861844Swollman 87127027Strhodes// Publically exposed interface to pass... 88241711Sjhbconst PassInfo *const llvm::LowerSwitchID = &X; 891844Swollman// createLowerSwitchPass - Interface to this file... 9042450SjdpFunctionPass *llvm::createLowerSwitchPass() { 911844Swollman return new LowerSwitch(); 92241711Sjhb} 931844Swollman 94117173Srubool LowerSwitch::runOnFunction(Function &F) { 95117159Sru bool Changed = false; 96241711Sjhb 971638Srgrimes for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 98117173Sru BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks 99217100Skib 100217100Skib if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 101241711Sjhb Changed = true; 102117173Sru processSwitchInst(SI); 103117173Sru } 104217100Skib } 105117173Sru 106241711Sjhb return Changed; 107117173Sru} 1081844Swollman 109217100Skib// operator<< - Used for debugging purposes. 
110241711Sjhb// 1111844Swollmanstatic raw_ostream& operator<<(raw_ostream &O, 11242450Sjdp const LowerSwitch::CaseVector &C) ATTRIBUTE_USED; 113217100Skibstatic raw_ostream& operator<<(raw_ostream &O, 114241711Sjhb const LowerSwitch::CaseVector &C) { 1151844Swollman O << "["; 11696512Sru 1171638Srgrimes for (LowerSwitch::CaseVector::const_iterator B = C.begin(), 118168317Skan E = C.end(); B != E; ) { 119156772Sdeischen O << *B->Low << " -" << *B->High; 120178047Skan if (++B != E) O << ", "; 121168317Skan } 122169822Sru 123156772Sdeischen return O << "]"; 124156772Sdeischen} 125156772Sdeischen 126156772Sdeischen// switchConvert - Convert the switch statement into a binary lookup of 12799362Sru// the case values. The function recursively builds this tree. 12899362Sru// 12999362SruBasicBlock* LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, 13099362Sru Value* Val, BasicBlock* OrigBlock, 13196512Sru BasicBlock* Default) 13296512Sru{ 1331638Srgrimes unsigned Size = End - Begin; 13496512Sru 13596512Sru if (Size == 1) 13696512Sru return newLeafBlock(*Begin, Val, OrigBlock, Default); 137163683Sru 138163683Sru unsigned Mid = Size / 2; 139163683Sru std::vector<CaseRange> LHS(Begin, Begin + Mid); 140163683Sru DEBUG(errs() << "LHS: " << LHS << "\n"); 141163683Sru std::vector<CaseRange> RHS(Begin + Mid, End); 14296512Sru DEBUG(errs() << "RHS: " << RHS << "\n"); 14399362Sru 1441638Srgrimes CaseRange& Pivot = *(Begin + Mid); 14596512Sru DEBUG(errs() << "Pivot ==> " 14695114Sobrien << cast<ConstantInt>(Pivot.Low)->getValue() << " -" 147156854Sru << cast<ConstantInt>(Pivot.High)->getValue() << "\n"); 14896512Sru 14996512Sru BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val, 15095306Sru OrigBlock, Default); 15196512Sru BasicBlock* RBranch = switchConvert(RHS.begin(), RHS.end(), Val, 15296512Sru OrigBlock, Default); 15396512Sru 154163683Sru // Create a new node that checks if the value is < pivot. 
Go to the 155163683Sru // left branch if it is and right branch if not. 156163683Sru Function* F = OrigBlock->getParent(); 157163683Sru BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); 158163683Sru Function::iterator FI = OrigBlock; 15996512Sru F->getBasicBlockList().insert(++FI, NewNode); 16074805Sru 1611844Swollman ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, 16299362Sru Val, Pivot.Low, "Pivot"); 16399362Sru NewNode->getInstList().push_back(Comp); 16496512Sru BranchInst::Create(LBranch, RBranch, Comp, NewNode); 16599362Sru return NewNode; 1661844Swollman} 16796512Sru 16896512Sru// newLeafBlock - Create a new leaf block for the binary lookup tree. It 1691638Srgrimes// checks if the switch's value == the case's value. If not, then it 170229380Skib// jumps to the default branch. At this point in the tree, the value 171229380Skib// can't be another valid case value, so the jump to the "default" branch 172229380Skib// is warranted. 173229380Skib// 174229380SkibBasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, 175212423Srpaulo BasicBlock* OrigBlock, 176212423Srpaulo BasicBlock* Default) 177212423Srpaulo{ 17842915Sjdp Function* F = OrigBlock->getParent(); 179212423Srpaulo BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); 18042915Sjdp Function::iterator FI = OrigBlock; 18196512Sru F->getBasicBlockList().insert(++FI, NewLeaf); 18242915Sjdp 18396512Sru // Emit comparison 18442915Sjdp ICmpInst* Comp = NULL; 185163683Sru if (Leaf.Low == Leaf.High) { 186229380Skib // Make the seteq instruction... 
18796512Sru Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, 188163683Sru Leaf.Low, "SwitchLeaf"); 189163683Sru } else { 190229380Skib // Make range comparison 191163683Sru if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) { 192163683Sru // Val >= Min && Val <= Hi --> Val <= Hi 19328945Speter Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, 194241711Sjhb "SwitchLeaf"); 195241711Sjhb } else if (cast<ConstantInt>(Leaf.Low)->isZero()) { 196163683Sru // Val >= 0 && Val <= Hi --> Val <=u Hi 197241711Sjhb Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, 1981844Swollman "SwitchLeaf"); 199156813Sru } else { 20096512Sru // Emit V-Lo <=u Hi-Lo 20196512Sru Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 20296512Sru Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 2032353Sbde Val->getName()+".off", 20496512Sru NewLeaf); 20596512Sru Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 20696512Sru Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, 2073859Sbde "SwitchLeaf"); 2081844Swollman } 209139106Sru } 21096512Sru 21196512Sru // Make the conditional branch... 21296512Sru BasicBlock* Succ = Leaf.BB; 21396512Sru BranchInst::Create(Succ, Default, Comp, NewLeaf); 21492491Smarkm 21596512Sru // If there were any PHI nodes in this successor, rewrite one entry 21696512Sru // from OrigBlock to come from NewLeaf. 
21792491Smarkm for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 21892491Smarkm PHINode* PN = cast<PHINode>(I); 2191638Srgrimes // Remove all but one incoming entries from the cluster 220144893Sharti uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() - 22196512Sru cast<ConstantInt>(Leaf.Low)->getSExtValue(); 22296512Sru for (uint64_t j = 0; j < Range; ++j) { 22396512Sru PN->removeIncomingValue(OrigBlock); 224156813Sru } 22596512Sru 2261638Srgrimes int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 2271638Srgrimes assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 22834179Sbde PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 22924750Sbde } 23042450Sjdp 23124750Sbde return NewLeaf; 23224750Sbde} 233139107Sru 23431809Sbde// Clusterify - Transform simple list of Cases into list of CaseRange's 23542915Sjdpunsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 23627910Sasami unsigned numCmps = 0; 23728945Speter 2381638Srgrimes // Start with "simple" cases 2391638Srgrimes for (unsigned i = 1; i < SI->getNumSuccessors(); ++i) 2401638Srgrimes Cases.push_back(CaseRange(SI->getSuccessorValue(i), 241136019Sru SI->getSuccessorValue(i), 242139111Sru SI->getSuccessor(i))); 2432298Swollman std::sort(Cases.begin(), Cases.end(), CaseCmp()); 2442298Swollman 245136019Sru // Merge case into clusters 246136019Sru if (Cases.size()>=2) 2472298Swollman for (CaseItr I=Cases.begin(), J=next(Cases.begin()); J!=Cases.end(); ) { 24849328Shoek int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue(); 24949328Shoek int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue(); 25049328Shoek BasicBlock* nextBB = J->BB; 25149328Shoek BasicBlock* currentBB = I->BB; 25256971Sru 25349328Shoek // If the two neighboring cases go to the same destination, merge them 25449328Shoek // into a single case. 
25549328Shoek if ((nextValue-currentValue==1) && (currentBB == nextBB)) { 25649328Shoek I->High = J->High; 25799362Sru J = Cases.erase(J); 25895306Sru } else { 25999343Sru I = J++; 26095306Sru } 261172832Sru } 26292980Sdes 26349328Shoek for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { 26496512Sru if (I->Low != I->High) 265156854Sru // A range counts double, since it requires two compares. 26692980Sdes ++numCmps; 26749328Shoek } 2681638Srgrimes 269116144Sobrien return numCmps; 270100872Sru} 27149328Shoek 27242915Sjdp// processSwitchInst - Replace the specified switch instruction with a sequence 27342915Sjdp// of chained if-then insts in a balanced binary search. 274235534Sjlh// 275235534Sjlhvoid LowerSwitch::processSwitchInst(SwitchInst *SI) { 276235534Sjlh BasicBlock *CurBlock = SI->getParent(); 277235534Sjlh BasicBlock *OrigBlock = CurBlock; 278235534Sjlh Function *F = CurBlock->getParent(); 279235534Sjlh Value *Val = SI->getOperand(0); // The value we are switching on... 280235534Sjlh BasicBlock* Default = SI->getDefaultDest(); 281235534Sjlh 282235534Sjlh // If there is only the default destination, don't bother with the code below. 283235534Sjlh if (SI->getNumOperands() == 2) { 284235534Sjlh BranchInst::Create(SI->getDefaultDest(), CurBlock); 285235534Sjlh CurBlock->getInstList().erase(SI); 286235534Sjlh return; 287235534Sjlh } 288235534Sjlh 289235534Sjlh // Create a new, empty default block so that the new hierarchy of 290235534Sjlh // if-then statements go to this and the PHI nodes are happy. 291235534Sjlh BasicBlock* NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); 292235534Sjlh F->getBasicBlockList().insert(Default, NewDefault); 293119846Sru 294119846Sru BranchInst::Create(Default, NewDefault); 295119846Sru 296119846Sru // If there is an entry in any PHI nodes for the default edge, make sure 297119730Speter // to update them as well. 
298119846Sru for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { 299119846Sru PHINode *PN = cast<PHINode>(I); 300119846Sru int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 3011844Swollman assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 30228945Speter PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); 303235534Sjlh } 304235534Sjlh 305235534Sjlh // Prepare cases vector. 306156813Sru CaseVector Cases; 307100872Sru unsigned numCmps = Clusterify(Cases, SI); 30849328Shoek 3091844Swollman DEBUG(errs() << "Clusterify finished. Total clusters: " << Cases.size() 310139106Sru << ". Total compares: " << numCmps << "\n"); 311100872Sru DEBUG(errs() << "Cases: " << Cases << "\n"); 31296462Sru (void)numCmps; 31396462Sru 314144893Sharti BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val, 31596462Sru OrigBlock, NewDefault); 316248350Sbrooks 317141503Sphantom // Branch to our shiny new if-then stuff... 31897769Sru BranchInst::Create(SwitchBlock, OrigBlock); 31996668Sru 320248350Sbrooks // We are now done with the switch instruction, delete it. 321248350Sbrooks CurBlock->getInstList().erase(SI); 32299256Sru} 32396462Sru