LowerSwitch.cpp revision 193323
1193323Sed//===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===// 2193323Sed// 3193323Sed// The LLVM Compiler Infrastructure 4193323Sed// 5193323Sed// This file is distributed under the University of Illinois Open Source 6193323Sed// License. See LICENSE.TXT for details. 7193323Sed// 8193323Sed//===----------------------------------------------------------------------===// 9193323Sed// 10193323Sed// The LowerSwitch transformation rewrites switch instructions with a sequence 11193323Sed// of branches, which allows targets to get away with not implementing the 12193323Sed// switch instruction until it is convenient. 13193323Sed// 14193323Sed//===----------------------------------------------------------------------===// 15193323Sed 16193323Sed#include "llvm/Transforms/Scalar.h" 17193323Sed#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" 18193323Sed#include "llvm/Constants.h" 19193323Sed#include "llvm/Function.h" 20193323Sed#include "llvm/Instructions.h" 21193323Sed#include "llvm/Pass.h" 22193323Sed#include "llvm/ADT/STLExtras.h" 23193323Sed#include "llvm/Support/Debug.h" 24193323Sed#include "llvm/Support/Compiler.h" 25193323Sed#include "llvm/Support/raw_ostream.h" 26193323Sed#include <algorithm> 27193323Sedusing namespace llvm; 28193323Sed 29193323Sednamespace { 30193323Sed /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch 31193323Sed /// instructions. Note that this cannot be a BasicBlock pass because it 32193323Sed /// modifies the CFG! 33193323Sed class VISIBILITY_HIDDEN LowerSwitch : public FunctionPass { 34193323Sed public: 35193323Sed static char ID; // Pass identification, replacement for typeid 36193323Sed LowerSwitch() : FunctionPass(&ID) {} 37193323Sed 38193323Sed virtual bool runOnFunction(Function &F); 39193323Sed 40193323Sed virtual void getAnalysisUsage(AnalysisUsage &AU) const { 41193323Sed // This is a cluster of orthogonal Transforms 42193323Sed AU.addPreserved<UnifyFunctionExitNodes>(); 43193323Sed AU.addPreservedID(PromoteMemoryToRegisterID); 44193323Sed AU.addPreservedID(LowerInvokePassID); 45193323Sed AU.addPreservedID(LowerAllocationsID); 46193323Sed } 47193323Sed 48193323Sed struct CaseRange { 49193323Sed Constant* Low; 50193323Sed Constant* High; 51193323Sed BasicBlock* BB; 52193323Sed 53193323Sed CaseRange() : Low(0), High(0), BB(0) { } 54193323Sed CaseRange(Constant* low, Constant* high, BasicBlock* bb) : 55193323Sed Low(low), High(high), BB(bb) { } 56193323Sed }; 57193323Sed 58193323Sed typedef std::vector<CaseRange> CaseVector; 59193323Sed typedef std::vector<CaseRange>::iterator CaseItr; 60193323Sed private: 61193323Sed void processSwitchInst(SwitchInst *SI); 62193323Sed 63193323Sed BasicBlock* switchConvert(CaseItr Begin, CaseItr End, Value* Val, 64193323Sed BasicBlock* OrigBlock, BasicBlock* Default); 65193323Sed BasicBlock* newLeafBlock(CaseRange& Leaf, Value* Val, 66193323Sed BasicBlock* OrigBlock, BasicBlock* Default); 67193323Sed unsigned Clusterify(CaseVector& Cases, SwitchInst *SI); 68193323Sed }; 69193323Sed 70193323Sed /// The comparison function for sorting the switch case values in the vector. 71193323Sed /// WARNING: Case ranges should be disjoint! 72193323Sed struct CaseCmp { 73193323Sed bool operator () (const LowerSwitch::CaseRange& C1, 74193323Sed const LowerSwitch::CaseRange& C2) { 75193323Sed 76193323Sed const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 77193323Sed const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 78193323Sed return CI1->getValue().slt(CI2->getValue()); 79193323Sed } 80193323Sed }; 81193323Sed} 82193323Sed 83193323Sedchar LowerSwitch::ID = 0; 84193323Sedstatic RegisterPass<LowerSwitch> 85193323SedX("lowerswitch", "Lower SwitchInst's to branches"); 86193323Sed 87193323Sed// Publically exposed interface to pass... 88193323Sedconst PassInfo *const llvm::LowerSwitchID = &X; 89193323Sed// createLowerSwitchPass - Interface to this file... 90193323SedFunctionPass *llvm::createLowerSwitchPass() { 91193323Sed return new LowerSwitch(); 92193323Sed} 93193323Sed 94193323Sedbool LowerSwitch::runOnFunction(Function &F) { 95193323Sed bool Changed = false; 96193323Sed 97193323Sed for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 98193323Sed BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks 99193323Sed 100193323Sed if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 101193323Sed Changed = true; 102193323Sed processSwitchInst(SI); 103193323Sed } 104193323Sed } 105193323Sed 106193323Sed return Changed; 107193323Sed} 108193323Sed 109193323Sed// operator<< - Used for debugging purposes. 110193323Sed// 111193323Sedstatic std::ostream& operator<<(std::ostream &O, 112193323Sed const LowerSwitch::CaseVector &C) { 113193323Sed O << "["; 114193323Sed 115193323Sed for (LowerSwitch::CaseVector::const_iterator B = C.begin(), 116193323Sed E = C.end(); B != E; ) { 117193323Sed O << *B->Low << " -" << *B->High; 118193323Sed if (++B != E) O << ", "; 119193323Sed } 120193323Sed 121193323Sed return O << "]"; 122193323Sed} 123193323Sed 124193323Sedstatic OStream& operator<<(OStream &O, const LowerSwitch::CaseVector &C) { 125193323Sed if (O.stream()) *O.stream() << C; 126193323Sed return O; 127193323Sed} 128193323Sed 129193323Sed// switchConvert - Convert the switch statement into a binary lookup of 130193323Sed// the case values. The function recursively builds this tree. 131193323Sed// 132193323SedBasicBlock* LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, 133193323Sed Value* Val, BasicBlock* OrigBlock, 134193323Sed BasicBlock* Default) 135193323Sed{ 136193323Sed unsigned Size = End - Begin; 137193323Sed 138193323Sed if (Size == 1) 139193323Sed return newLeafBlock(*Begin, Val, OrigBlock, Default); 140193323Sed 141193323Sed unsigned Mid = Size / 2; 142193323Sed std::vector<CaseRange> LHS(Begin, Begin + Mid); 143193323Sed DOUT << "LHS: " << LHS << "\n"; 144193323Sed std::vector<CaseRange> RHS(Begin + Mid, End); 145193323Sed DOUT << "RHS: " << RHS << "\n"; 146193323Sed 147193323Sed CaseRange& Pivot = *(Begin + Mid); 148193323Sed DEBUG(errs() << "Pivot ==> " 149193323Sed << cast<ConstantInt>(Pivot.Low)->getValue() << " -" 150193323Sed << cast<ConstantInt>(Pivot.High)->getValue() << "\n"); 151193323Sed 152193323Sed BasicBlock* LBranch = switchConvert(LHS.begin(), LHS.end(), Val, 153193323Sed OrigBlock, Default); 154193323Sed BasicBlock* RBranch = switchConvert(RHS.begin(), RHS.end(), Val, 155193323Sed OrigBlock, Default); 156193323Sed 157193323Sed // Create a new node that checks if the value is < pivot. Go to the 158193323Sed // left branch if it is and right branch if not. 159193323Sed Function* F = OrigBlock->getParent(); 160193323Sed BasicBlock* NewNode = BasicBlock::Create("NodeBlock"); 161193323Sed Function::iterator FI = OrigBlock; 162193323Sed F->getBasicBlockList().insert(++FI, NewNode); 163193323Sed 164193323Sed ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, Val, Pivot.Low, "Pivot"); 165193323Sed NewNode->getInstList().push_back(Comp); 166193323Sed BranchInst::Create(LBranch, RBranch, Comp, NewNode); 167193323Sed return NewNode; 168193323Sed} 169193323Sed 170193323Sed// newLeafBlock - Create a new leaf block for the binary lookup tree. It 171193323Sed// checks if the switch's value == the case's value. If not, then it 172193323Sed// jumps to the default branch. At this point in the tree, the value 173193323Sed// can't be another valid case value, so the jump to the "default" branch 174193323Sed// is warranted. 175193323Sed// 176193323SedBasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, 177193323Sed BasicBlock* OrigBlock, 178193323Sed BasicBlock* Default) 179193323Sed{ 180193323Sed Function* F = OrigBlock->getParent(); 181193323Sed BasicBlock* NewLeaf = BasicBlock::Create("LeafBlock"); 182193323Sed Function::iterator FI = OrigBlock; 183193323Sed F->getBasicBlockList().insert(++FI, NewLeaf); 184193323Sed 185193323Sed // Emit comparison 186193323Sed ICmpInst* Comp = NULL; 187193323Sed if (Leaf.Low == Leaf.High) { 188193323Sed // Make the seteq instruction... 189193323Sed Comp = new ICmpInst(ICmpInst::ICMP_EQ, Val, Leaf.Low, 190193323Sed "SwitchLeaf", NewLeaf); 191193323Sed } else { 192193323Sed // Make range comparison 193193323Sed if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) { 194193323Sed // Val >= Min && Val <= Hi --> Val <= Hi 195193323Sed Comp = new ICmpInst(ICmpInst::ICMP_SLE, Val, Leaf.High, 196193323Sed "SwitchLeaf", NewLeaf); 197193323Sed } else if (cast<ConstantInt>(Leaf.Low)->isZero()) { 198193323Sed // Val >= 0 && Val <= Hi --> Val <=u Hi 199193323Sed Comp = new ICmpInst(ICmpInst::ICMP_ULE, Val, Leaf.High, 200193323Sed "SwitchLeaf", NewLeaf); 201193323Sed } else { 202193323Sed // Emit V-Lo <=u Hi-Lo 203193323Sed Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 204193323Sed Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 205193323Sed Val->getName()+".off", 206193323Sed NewLeaf); 207193323Sed Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 208193323Sed Comp = new ICmpInst(ICmpInst::ICMP_ULE, Add, UpperBound, 209193323Sed "SwitchLeaf", NewLeaf); 210193323Sed } 211193323Sed } 212193323Sed 213193323Sed // Make the conditional branch... 214193323Sed BasicBlock* Succ = Leaf.BB; 215193323Sed BranchInst::Create(Succ, Default, Comp, NewLeaf); 216193323Sed 217193323Sed // If there were any PHI nodes in this successor, rewrite one entry 218193323Sed // from OrigBlock to come from NewLeaf. 219193323Sed for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 220193323Sed PHINode* PN = cast<PHINode>(I); 221193323Sed // Remove all but one incoming entries from the cluster 222193323Sed uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() - 223193323Sed cast<ConstantInt>(Leaf.Low)->getSExtValue(); 224193323Sed for (uint64_t j = 0; j < Range; ++j) { 225193323Sed PN->removeIncomingValue(OrigBlock); 226193323Sed } 227193323Sed 228193323Sed int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 229193323Sed assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 230193323Sed PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 231193323Sed } 232193323Sed 233193323Sed return NewLeaf; 234193323Sed} 235193323Sed 236193323Sed// Clusterify - Transform simple list of Cases into list of CaseRange's 237193323Sedunsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 238193323Sed unsigned numCmps = 0; 239193323Sed 240193323Sed // Start with "simple" cases 241193323Sed for (unsigned i = 1; i < SI->getNumSuccessors(); ++i) 242193323Sed Cases.push_back(CaseRange(SI->getSuccessorValue(i), 243193323Sed SI->getSuccessorValue(i), 244193323Sed SI->getSuccessor(i))); 245193323Sed std::sort(Cases.begin(), Cases.end(), CaseCmp()); 246193323Sed 247193323Sed // Merge case into clusters 248193323Sed if (Cases.size()>=2) 249193323Sed for (CaseItr I=Cases.begin(), J=next(Cases.begin()); J!=Cases.end(); ) { 250193323Sed int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue(); 251193323Sed int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue(); 252193323Sed BasicBlock* nextBB = J->BB; 253193323Sed BasicBlock* currentBB = I->BB; 254193323Sed 255193323Sed // If the two neighboring cases go to the same destination, merge them 256193323Sed // into a single case. 257193323Sed if ((nextValue-currentValue==1) && (currentBB == nextBB)) { 258193323Sed I->High = J->High; 259193323Sed J = Cases.erase(J); 260193323Sed } else { 261193323Sed I = J++; 262193323Sed } 263193323Sed } 264193323Sed 265193323Sed for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { 266193323Sed if (I->Low != I->High) 267193323Sed // A range counts double, since it requires two compares. 268193323Sed ++numCmps; 269193323Sed } 270193323Sed 271193323Sed return numCmps; 272193323Sed} 273193323Sed 274193323Sed// processSwitchInst - Replace the specified switch instruction with a sequence 275193323Sed// of chained if-then insts in a balanced binary search. 276193323Sed// 277193323Sedvoid LowerSwitch::processSwitchInst(SwitchInst *SI) { 278193323Sed BasicBlock *CurBlock = SI->getParent(); 279193323Sed BasicBlock *OrigBlock = CurBlock; 280193323Sed Function *F = CurBlock->getParent(); 281193323Sed Value *Val = SI->getOperand(0); // The value we are switching on... 282193323Sed BasicBlock* Default = SI->getDefaultDest(); 283193323Sed 284193323Sed // If there is only the default destination, don't bother with the code below. 285193323Sed if (SI->getNumOperands() == 2) { 286193323Sed BranchInst::Create(SI->getDefaultDest(), CurBlock); 287193323Sed CurBlock->getInstList().erase(SI); 288193323Sed return; 289193323Sed } 290193323Sed 291193323Sed // Create a new, empty default block so that the new hierarchy of 292193323Sed // if-then statements go to this and the PHI nodes are happy. 293193323Sed BasicBlock* NewDefault = BasicBlock::Create("NewDefault"); 294193323Sed F->getBasicBlockList().insert(Default, NewDefault); 295193323Sed 296193323Sed BranchInst::Create(Default, NewDefault); 297193323Sed 298193323Sed // If there is an entry in any PHI nodes for the default edge, make sure 299193323Sed // to update them as well. 300193323Sed for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { 301193323Sed PHINode *PN = cast<PHINode>(I); 302193323Sed int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 303193323Sed assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 304193323Sed PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); 305193323Sed } 306193323Sed 307193323Sed // Prepare cases vector. 308193323Sed CaseVector Cases; 309193323Sed unsigned numCmps = Clusterify(Cases, SI); 310193323Sed 311193323Sed DOUT << "Clusterify finished. Total clusters: " << Cases.size() 312193323Sed << ". Total compares: " << numCmps << "\n"; 313193323Sed DOUT << "Cases: " << Cases << "\n"; 314193323Sed 315193323Sed BasicBlock* SwitchBlock = switchConvert(Cases.begin(), Cases.end(), Val, 316193323Sed OrigBlock, NewDefault); 317193323Sed 318193323Sed // Branch to our shiny new if-then stuff... 319193323Sed BranchInst::Create(SwitchBlock, OrigBlock); 320193323Sed 321193323Sed // We are now done with the switch instruction, delete it. 322193323Sed CurBlock->getInstList().erase(SI); 323193323Sed} 324