LowerSwitch.cpp revision 218893
1193323Sed//===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===// 2193323Sed// 3193323Sed// The LLVM Compiler Infrastructure 4193323Sed// 5193323Sed// This file is distributed under the University of Illinois Open Source 6193323Sed// License. See LICENSE.TXT for details. 7193323Sed// 8193323Sed//===----------------------------------------------------------------------===// 9193323Sed// 10193323Sed// The LowerSwitch transformation rewrites switch instructions with a sequence 11193323Sed// of branches, which allows targets to get away with not implementing the 12193323Sed// switch instruction until it is convenient. 13193323Sed// 14193323Sed//===----------------------------------------------------------------------===// 15193323Sed 16193323Sed#include "llvm/Transforms/Scalar.h" 17193323Sed#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" 18193323Sed#include "llvm/Constants.h" 19193323Sed#include "llvm/Function.h" 20193323Sed#include "llvm/Instructions.h" 21198090Srdivacky#include "llvm/LLVMContext.h" 22193323Sed#include "llvm/Pass.h" 23193323Sed#include "llvm/ADT/STLExtras.h" 24198892Srdivacky#include "llvm/Support/Compiler.h" 25193323Sed#include "llvm/Support/Debug.h" 26193323Sed#include "llvm/Support/raw_ostream.h" 27193323Sed#include <algorithm> 28193323Sedusing namespace llvm; 29193323Sed 30193323Sednamespace { 31193323Sed /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch 32212904Sdim /// instructions. 33198892Srdivacky class LowerSwitch : public FunctionPass { 34193323Sed public: 35193323Sed static char ID; // Pass identification, replacement for typeid 36218893Sdim LowerSwitch() : FunctionPass(ID) { 37218893Sdim initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); 38218893Sdim } 39193323Sed 40193323Sed virtual bool runOnFunction(Function &F); 41193323Sed 42193323Sed virtual void getAnalysisUsage(AnalysisUsage &AU) const { 43193323Sed // This is a cluster of orthogonal Transforms 44193323Sed AU.addPreserved<UnifyFunctionExitNodes>(); 45212904Sdim AU.addPreserved("mem2reg"); 46193323Sed AU.addPreservedID(LowerInvokePassID); 47193323Sed } 48193323Sed 49193323Sed struct CaseRange { 50193323Sed Constant* Low; 51193323Sed Constant* High; 52193323Sed BasicBlock* BB; 53193323Sed 54212904Sdim CaseRange(Constant *low = 0, Constant *high = 0, BasicBlock *bb = 0) : 55193323Sed Low(low), High(high), BB(bb) { } 56193323Sed }; 57193323Sed 58193323Sed typedef std::vector<CaseRange> CaseVector; 59193323Sed typedef std::vector<CaseRange>::iterator CaseItr; 60193323Sed private: 61193323Sed void processSwitchInst(SwitchInst *SI); 62193323Sed 63193323Sed BasicBlock* switchConvert(CaseItr Begin, CaseItr End, Value* Val, 64193323Sed BasicBlock* OrigBlock, BasicBlock* Default); 65193323Sed BasicBlock* newLeafBlock(CaseRange& Leaf, Value* Val, 66193323Sed BasicBlock* OrigBlock, BasicBlock* Default); 67193323Sed unsigned Clusterify(CaseVector& Cases, SwitchInst *SI); 68193323Sed }; 69193323Sed 70193323Sed /// The comparison function for sorting the switch case values in the vector. 71193323Sed /// WARNING: Case ranges should be disjoint! 72193323Sed struct CaseCmp { 73193323Sed bool operator () (const LowerSwitch::CaseRange& C1, 74193323Sed const LowerSwitch::CaseRange& C2) { 75193323Sed 76193323Sed const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 77193323Sed const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 78193323Sed return CI1->getValue().slt(CI2->getValue()); 79193323Sed } 80193323Sed }; 81193323Sed} 82193323Sed 83193323Sedchar LowerSwitch::ID = 0; 84212904SdimINITIALIZE_PASS(LowerSwitch, "lowerswitch", 85218893Sdim "Lower SwitchInst's to branches", false, false) 86193323Sed 87193323Sed// Publically exposed interface to pass... 88212904Sdimchar &llvm::LowerSwitchID = LowerSwitch::ID; 89193323Sed// createLowerSwitchPass - Interface to this file... 90193323SedFunctionPass *llvm::createLowerSwitchPass() { 91193323Sed return new LowerSwitch(); 92193323Sed} 93193323Sed 94193323Sedbool LowerSwitch::runOnFunction(Function &F) { 95193323Sed bool Changed = false; 96193323Sed 97193323Sed for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 98193323Sed BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks 99193323Sed 100193323Sed if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 101193323Sed Changed = true; 102193323Sed processSwitchInst(SI); 103193323Sed } 104193323Sed } 105193323Sed 106193323Sed return Changed; 107193323Sed} 108193323Sed 109193323Sed// operator<< - Used for debugging purposes. 110193323Sed// 111198090Srdivackystatic raw_ostream& operator<<(raw_ostream &O, 112218893Sdim const LowerSwitch::CaseVector &C) 113218893Sdim LLVM_ATTRIBUTE_USED; 114198090Srdivackystatic raw_ostream& operator<<(raw_ostream &O, 115198090Srdivacky const LowerSwitch::CaseVector &C) { 116193323Sed O << "["; 117193323Sed 118193323Sed for (LowerSwitch::CaseVector::const_iterator B = C.begin(), 119193323Sed E = C.end(); B != E; ) { 120193323Sed O << *B->Low << " -" << *B->High; 121193323Sed if (++B != E) O << ", "; 122193323Sed } 123193323Sed 124193323Sed return O << "]"; 125193323Sed} 126193323Sed 127193323Sed// switchConvert - Convert the switch statement into a binary lookup of 128193323Sed// the case values. The function recursively builds this tree. 129193323Sed// 130193323SedBasicBlock* LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, 131193323Sed Value* Val, BasicBlock* OrigBlock, 132193323Sed BasicBlock* Default) 133193323Sed{ 134193323Sed unsigned Size = End - Begin; 135193323Sed 136193323Sed if (Size == 1) 137193323Sed return newLeafBlock(*Begin, Val, OrigBlock, Default); 138193323Sed 139193323Sed unsigned Mid = Size / 2; 140193323Sed std::vector<CaseRange> LHS(Begin, Begin + Mid); 141202375Srdivacky DEBUG(dbgs() << "LHS: " << LHS << "\n"); 142193323Sed std::vector<CaseRange> RHS(Begin + Mid, End); 143202375Srdivacky DEBUG(dbgs() << "RHS: " << RHS << "\n"); 144193323Sed 145193323Sed CaseRange& Pivot = *(Begin + Mid); 146202375Srdivacky DEBUG(dbgs() << "Pivot ==> " 147193323Sed << cast<ConstantInt>(Pivot.Low)->getValue() << " -" 148193323Sed << cast<ConstantInt>(Pivot.High)->getValue() << "\n"); 149193323Sed 150193323Sed BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val, 151193323Sed OrigBlock, Default); 152193323Sed BasicBlock* RBranch = switchConvert(RHS.begin(), RHS.end(), Val, 153193323Sed OrigBlock, Default); 154193323Sed 155193323Sed // Create a new node that checks if the value is < pivot. Go to the 156193323Sed // left branch if it is and right branch if not. 157193323Sed Function* F = OrigBlock->getParent(); 158198090Srdivacky BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); 159193323Sed Function::iterator FI = OrigBlock; 160193323Sed F->getBasicBlockList().insert(++FI, NewNode); 161193323Sed 162198090Srdivacky ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, 163198090Srdivacky Val, Pivot.Low, "Pivot"); 164193323Sed NewNode->getInstList().push_back(Comp); 165193323Sed BranchInst::Create(LBranch, RBranch, Comp, NewNode); 166193323Sed return NewNode; 167193323Sed} 168193323Sed 169193323Sed// newLeafBlock - Create a new leaf block for the binary lookup tree. It 170193323Sed// checks if the switch's value == the case's value. If not, then it 171193323Sed// jumps to the default branch. At this point in the tree, the value 172193323Sed// can't be another valid case value, so the jump to the "default" branch 173193323Sed// is warranted. 174193323Sed// 175193323SedBasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, 176193323Sed BasicBlock* OrigBlock, 177193323Sed BasicBlock* Default) 178193323Sed{ 179193323Sed Function* F = OrigBlock->getParent(); 180198090Srdivacky BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); 181193323Sed Function::iterator FI = OrigBlock; 182193323Sed F->getBasicBlockList().insert(++FI, NewLeaf); 183193323Sed 184193323Sed // Emit comparison 185193323Sed ICmpInst* Comp = NULL; 186193323Sed if (Leaf.Low == Leaf.High) { 187193323Sed // Make the seteq instruction... 188198090Srdivacky Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, 189198090Srdivacky Leaf.Low, "SwitchLeaf"); 190193323Sed } else { 191193323Sed // Make range comparison 192193323Sed if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) { 193193323Sed // Val >= Min && Val <= Hi --> Val <= Hi 194198090Srdivacky Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, 195198090Srdivacky "SwitchLeaf"); 196193323Sed } else if (cast<ConstantInt>(Leaf.Low)->isZero()) { 197193323Sed // Val >= 0 && Val <= Hi --> Val <=u Hi 198198090Srdivacky Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, 199198090Srdivacky "SwitchLeaf"); 200193323Sed } else { 201193323Sed // Emit V-Lo <=u Hi-Lo 202193323Sed Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 203193323Sed Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 204193323Sed Val->getName()+".off", 205193323Sed NewLeaf); 206193323Sed Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 207198090Srdivacky Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, 208198090Srdivacky "SwitchLeaf"); 209193323Sed } 210193323Sed } 211193323Sed 212193323Sed // Make the conditional branch... 213193323Sed BasicBlock* Succ = Leaf.BB; 214193323Sed BranchInst::Create(Succ, Default, Comp, NewLeaf); 215193323Sed 216193323Sed // If there were any PHI nodes in this successor, rewrite one entry 217193323Sed // from OrigBlock to come from NewLeaf. 218193323Sed for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 219193323Sed PHINode* PN = cast<PHINode>(I); 220193323Sed // Remove all but one incoming entries from the cluster 221193323Sed uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() - 222193323Sed cast<ConstantInt>(Leaf.Low)->getSExtValue(); 223193323Sed for (uint64_t j = 0; j < Range; ++j) { 224193323Sed PN->removeIncomingValue(OrigBlock); 225193323Sed } 226193323Sed 227193323Sed int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 228193323Sed assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 229193323Sed PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 230193323Sed } 231193323Sed 232193323Sed return NewLeaf; 233193323Sed} 234193323Sed 235193323Sed// Clusterify - Transform simple list of Cases into list of CaseRange's 236193323Sedunsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 237193323Sed unsigned numCmps = 0; 238193323Sed 239193323Sed // Start with "simple" cases 240193323Sed for (unsigned i = 1; i < SI->getNumSuccessors(); ++i) 241193323Sed Cases.push_back(CaseRange(SI->getSuccessorValue(i), 242193323Sed SI->getSuccessorValue(i), 243193323Sed SI->getSuccessor(i))); 244193323Sed std::sort(Cases.begin(), Cases.end(), CaseCmp()); 245193323Sed 246193323Sed // Merge case into clusters 247193323Sed if (Cases.size()>=2) 248200581Srdivacky for (CaseItr I=Cases.begin(), J=llvm::next(Cases.begin()); J!=Cases.end(); ) { 249193323Sed int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue(); 250193323Sed int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue(); 251193323Sed BasicBlock* nextBB = J->BB; 252193323Sed BasicBlock* currentBB = I->BB; 253193323Sed 254193323Sed // If the two neighboring cases go to the same destination, merge them 255193323Sed // into a single case. 256193323Sed if ((nextValue-currentValue==1) && (currentBB == nextBB)) { 257193323Sed I->High = J->High; 258193323Sed J = Cases.erase(J); 259193323Sed } else { 260193323Sed I = J++; 261193323Sed } 262193323Sed } 263193323Sed 264193323Sed for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { 265193323Sed if (I->Low != I->High) 266193323Sed // A range counts double, since it requires two compares. 267193323Sed ++numCmps; 268193323Sed } 269193323Sed 270193323Sed return numCmps; 271193323Sed} 272193323Sed 273193323Sed// processSwitchInst - Replace the specified switch instruction with a sequence 274193323Sed// of chained if-then insts in a balanced binary search. 275193323Sed// 276193323Sedvoid LowerSwitch::processSwitchInst(SwitchInst *SI) { 277193323Sed BasicBlock *CurBlock = SI->getParent(); 278193323Sed BasicBlock *OrigBlock = CurBlock; 279193323Sed Function *F = CurBlock->getParent(); 280193323Sed Value *Val = SI->getOperand(0); // The value we are switching on... 281193323Sed BasicBlock* Default = SI->getDefaultDest(); 282193323Sed 283193323Sed // If there is only the default destination, don't bother with the code below. 284193323Sed if (SI->getNumOperands() == 2) { 285193323Sed BranchInst::Create(SI->getDefaultDest(), CurBlock); 286193323Sed CurBlock->getInstList().erase(SI); 287193323Sed return; 288193323Sed } 289193323Sed 290193323Sed // Create a new, empty default block so that the new hierarchy of 291193323Sed // if-then statements go to this and the PHI nodes are happy. 292198090Srdivacky BasicBlock* NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); 293193323Sed F->getBasicBlockList().insert(Default, NewDefault); 294193323Sed 295193323Sed BranchInst::Create(Default, NewDefault); 296193323Sed 297193323Sed // If there is an entry in any PHI nodes for the default edge, make sure 298193323Sed // to update them as well. 299193323Sed for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { 300193323Sed PHINode *PN = cast<PHINode>(I); 301193323Sed int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 302193323Sed assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 303193323Sed PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); 304193323Sed } 305193323Sed 306193323Sed // Prepare cases vector. 307193323Sed CaseVector Cases; 308193323Sed unsigned numCmps = Clusterify(Cases, SI); 309193323Sed 310202375Srdivacky DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() 311198090Srdivacky << ". Total compares: " << numCmps << "\n"); 312202375Srdivacky DEBUG(dbgs() << "Cases: " << Cases << "\n"); 313198090Srdivacky (void)numCmps; 314193323Sed 315193323Sed BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val, 316193323Sed OrigBlock, NewDefault); 317193323Sed 318193323Sed // Branch to our shiny new if-then stuff... 319193323Sed BranchInst::Create(SwitchBlock, OrigBlock); 320193323Sed 321193323Sed // We are now done with the switch instruction, delete it. 322193323Sed CurBlock->getInstList().erase(SI); 323193323Sed} 324