LowerSwitch.cpp revision 276479
//===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// The LowerSwitch transformation rewrites switch instructions with a sequence
// of branches, which allows targets to get away with not implementing the
// switch instruction until it is convenient.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/Pass.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
#include <algorithm>
#include <vector>
using namespace llvm;

#define DEBUG_TYPE "lower-switch"

namespace {
/// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch
/// instructions.
37156230Smux class LowerSwitch : public FunctionPass { 38156230Smux public: 39156230Smux static char ID; // Pass identification, replacement for typeid 40156230Smux LowerSwitch() : FunctionPass(ID) { 41156230Smux initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); 42156230Smux } 43156230Smux 44156230Smux bool runOnFunction(Function &F) override; 45156230Smux 46156230Smux void getAnalysisUsage(AnalysisUsage &AU) const override { 47186781Slulf // This is a cluster of orthogonal Transforms 48156230Smux AU.addPreserved<UnifyFunctionExitNodes>(); 49156230Smux AU.addPreserved("mem2reg"); 50156230Smux AU.addPreservedID(LowerInvokePassID); 51 } 52 53 struct CaseRange { 54 Constant* Low; 55 Constant* High; 56 BasicBlock* BB; 57 58 CaseRange(Constant *low = nullptr, Constant *high = nullptr, 59 BasicBlock *bb = nullptr) : 60 Low(low), High(high), BB(bb) { } 61 }; 62 63 typedef std::vector<CaseRange> CaseVector; 64 typedef std::vector<CaseRange>::iterator CaseItr; 65 private: 66 void processSwitchInst(SwitchInst *SI); 67 68 BasicBlock *switchConvert(CaseItr Begin, CaseItr End, 69 ConstantInt *LowerBound, ConstantInt *UpperBound, 70 Value *Val, BasicBlock *Predecessor, 71 BasicBlock *OrigBlock, BasicBlock *Default); 72 BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, 73 BasicBlock *Default); 74 unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); 75 }; 76 77 /// The comparison function for sorting the switch case values in the vector. 78 /// WARNING: Case ranges should be disjoint! 
79 struct CaseCmp { 80 bool operator () (const LowerSwitch::CaseRange& C1, 81 const LowerSwitch::CaseRange& C2) { 82 83 const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 84 const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 85 return CI1->getValue().slt(CI2->getValue()); 86 } 87 }; 88} 89 90char LowerSwitch::ID = 0; 91INITIALIZE_PASS(LowerSwitch, "lowerswitch", 92 "Lower SwitchInst's to branches", false, false) 93 94// Publicly exposed interface to pass... 95char &llvm::LowerSwitchID = LowerSwitch::ID; 96// createLowerSwitchPass - Interface to this file... 97FunctionPass *llvm::createLowerSwitchPass() { 98 return new LowerSwitch(); 99} 100 101bool LowerSwitch::runOnFunction(Function &F) { 102 bool Changed = false; 103 104 for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 105 BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks 106 107 if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 108 Changed = true; 109 processSwitchInst(SI); 110 } 111 } 112 113 return Changed; 114} 115 116// operator<< - Used for debugging purposes. 
117// 118static raw_ostream& operator<<(raw_ostream &O, 119 const LowerSwitch::CaseVector &C) 120 LLVM_ATTRIBUTE_USED; 121static raw_ostream& operator<<(raw_ostream &O, 122 const LowerSwitch::CaseVector &C) { 123 O << "["; 124 125 for (LowerSwitch::CaseVector::const_iterator B = C.begin(), 126 E = C.end(); B != E; ) { 127 O << *B->Low << " -" << *B->High; 128 if (++B != E) O << ", "; 129 } 130 131 return O << "]"; 132} 133 134static void fixPhis(BasicBlock *Succ, 135 BasicBlock *OrigBlock, 136 BasicBlock *NewNode) { 137 for (BasicBlock::iterator I = Succ->begin(), 138 E = Succ->getFirstNonPHI(); 139 I != E; ++I) { 140 PHINode *PN = cast<PHINode>(I); 141 142 for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) { 143 if (PN->getIncomingBlock(I) == OrigBlock) 144 PN->setIncomingBlock(I, NewNode); 145 } 146 } 147} 148 149// switchConvert - Convert the switch statement into a binary lookup of 150// the case values. The function recursively builds this tree. 151// LowerBound and UpperBound are used to keep track of the bounds for Val 152// that have already been checked by a block emitted by one of the previous 153// calls to switchConvert in the call stack. 154BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, 155 ConstantInt *LowerBound, 156 ConstantInt *UpperBound, Value *Val, 157 BasicBlock *Predecessor, 158 BasicBlock *OrigBlock, 159 BasicBlock *Default) { 160 unsigned Size = End - Begin; 161 162 if (Size == 1) { 163 // Check if the Case Range is perfectly squeezed in between 164 // already checked Upper and Lower bounds. If it is then we can avoid 165 // emitting the code that checks if the value actually falls in the range 166 // because the bounds already tell us so. 
167 if (Begin->Low == LowerBound && Begin->High == UpperBound) { 168 fixPhis(Begin->BB, OrigBlock, Predecessor); 169 return Begin->BB; 170 } 171 return newLeafBlock(*Begin, Val, OrigBlock, Default); 172 } 173 174 unsigned Mid = Size / 2; 175 std::vector<CaseRange> LHS(Begin, Begin + Mid); 176 DEBUG(dbgs() << "LHS: " << LHS << "\n"); 177 std::vector<CaseRange> RHS(Begin + Mid, End); 178 DEBUG(dbgs() << "RHS: " << RHS << "\n"); 179 180 CaseRange &Pivot = *(Begin + Mid); 181 DEBUG(dbgs() << "Pivot ==> " 182 << cast<ConstantInt>(Pivot.Low)->getValue() 183 << " -" << cast<ConstantInt>(Pivot.High)->getValue() << "\n"); 184 185 // NewLowerBound here should never be the integer minimal value. 186 // This is because it is computed from a case range that is never 187 // the smallest, so there is always a case range that has at least 188 // a smaller value. 189 ConstantInt *NewLowerBound = cast<ConstantInt>(Pivot.Low); 190 ConstantInt *NewUpperBound; 191 192 // If we don't have a Default block then it means that we can never 193 // have a value outside of a case range, so set the UpperBound to the highest 194 // value in the LHS part of the case ranges. 195 if (Default != nullptr) { 196 // Because NewLowerBound is never the smallest representable integer 197 // it is safe here to subtract one. 
198 NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), 199 NewLowerBound->getValue() - 1); 200 } else { 201 CaseItr LastLHS = LHS.begin() + LHS.size() - 1; 202 NewUpperBound = cast<ConstantInt>(LastLHS->High); 203 } 204 205 DEBUG(dbgs() << "LHS Bounds ==> "; 206 if (LowerBound) { 207 dbgs() << cast<ConstantInt>(LowerBound)->getSExtValue(); 208 } else { 209 dbgs() << "NONE"; 210 } 211 dbgs() << " - " << NewUpperBound->getSExtValue() << "\n"; 212 dbgs() << "RHS Bounds ==> "; 213 dbgs() << NewLowerBound->getSExtValue() << " - "; 214 if (UpperBound) { 215 dbgs() << cast<ConstantInt>(UpperBound)->getSExtValue() << "\n"; 216 } else { 217 dbgs() << "NONE\n"; 218 }); 219 220 // Create a new node that checks if the value is < pivot. Go to the 221 // left branch if it is and right branch if not. 222 Function* F = OrigBlock->getParent(); 223 BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); 224 225 ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, 226 Val, Pivot.Low, "Pivot"); 227 228 BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, 229 NewUpperBound, Val, NewNode, OrigBlock, 230 Default); 231 BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, 232 UpperBound, Val, NewNode, OrigBlock, 233 Default); 234 235 Function::iterator FI = OrigBlock; 236 F->getBasicBlockList().insert(++FI, NewNode); 237 NewNode->getInstList().push_back(Comp); 238 239 BranchInst::Create(LBranch, RBranch, Comp, NewNode); 240 return NewNode; 241} 242 243// newLeafBlock - Create a new leaf block for the binary lookup tree. It 244// checks if the switch's value == the case's value. If not, then it 245// jumps to the default branch. At this point in the tree, the value 246// can't be another valid case value, so the jump to the "default" branch 247// is warranted. 
248// 249BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, 250 BasicBlock* OrigBlock, 251 BasicBlock* Default) 252{ 253 Function* F = OrigBlock->getParent(); 254 BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); 255 Function::iterator FI = OrigBlock; 256 F->getBasicBlockList().insert(++FI, NewLeaf); 257 258 // Emit comparison 259 ICmpInst* Comp = nullptr; 260 if (Leaf.Low == Leaf.High) { 261 // Make the seteq instruction... 262 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, 263 Leaf.Low, "SwitchLeaf"); 264 } else { 265 // Make range comparison 266 if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) { 267 // Val >= Min && Val <= Hi --> Val <= Hi 268 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, 269 "SwitchLeaf"); 270 } else if (cast<ConstantInt>(Leaf.Low)->isZero()) { 271 // Val >= 0 && Val <= Hi --> Val <=u Hi 272 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, 273 "SwitchLeaf"); 274 } else { 275 // Emit V-Lo <=u Hi-Lo 276 Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 277 Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 278 Val->getName()+".off", 279 NewLeaf); 280 Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 281 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, 282 "SwitchLeaf"); 283 } 284 } 285 286 // Make the conditional branch... 287 BasicBlock* Succ = Leaf.BB; 288 BranchInst::Create(Succ, Default, Comp, NewLeaf); 289 290 // If there were any PHI nodes in this successor, rewrite one entry 291 // from OrigBlock to come from NewLeaf. 
292 for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 293 PHINode* PN = cast<PHINode>(I); 294 // Remove all but one incoming entries from the cluster 295 uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() - 296 cast<ConstantInt>(Leaf.Low)->getSExtValue(); 297 for (uint64_t j = 0; j < Range; ++j) { 298 PN->removeIncomingValue(OrigBlock); 299 } 300 301 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 302 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 303 PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 304 } 305 306 return NewLeaf; 307} 308 309// Clusterify - Transform simple list of Cases into list of CaseRange's 310unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 311 unsigned numCmps = 0; 312 313 // Start with "simple" cases 314 for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) 315 Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(), 316 i.getCaseSuccessor())); 317 318 std::sort(Cases.begin(), Cases.end(), CaseCmp()); 319 320 // Merge case into clusters 321 if (Cases.size()>=2) 322 for (CaseItr I = Cases.begin(), J = std::next(Cases.begin()); 323 J != Cases.end();) { 324 int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue(); 325 int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue(); 326 BasicBlock* nextBB = J->BB; 327 BasicBlock* currentBB = I->BB; 328 329 // If the two neighboring cases go to the same destination, merge them 330 // into a single case. 331 if ((nextValue-currentValue==1) && (currentBB == nextBB)) { 332 I->High = J->High; 333 J = Cases.erase(J); 334 } else { 335 I = J++; 336 } 337 } 338 339 for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { 340 if (I->Low != I->High) 341 // A range counts double, since it requires two compares. 
342 ++numCmps; 343 } 344 345 return numCmps; 346} 347 348// processSwitchInst - Replace the specified switch instruction with a sequence 349// of chained if-then insts in a balanced binary search. 350// 351void LowerSwitch::processSwitchInst(SwitchInst *SI) { 352 BasicBlock *CurBlock = SI->getParent(); 353 BasicBlock *OrigBlock = CurBlock; 354 Function *F = CurBlock->getParent(); 355 Value *Val = SI->getCondition(); // The value we are switching on... 356 BasicBlock* Default = SI->getDefaultDest(); 357 358 // If there is only the default destination, don't bother with the code below. 359 if (!SI->getNumCases()) { 360 BranchInst::Create(SI->getDefaultDest(), CurBlock); 361 CurBlock->getInstList().erase(SI); 362 return; 363 } 364 365 const bool DefaultIsUnreachable = 366 Default->size() == 1 && isa<UnreachableInst>(Default->getTerminator()); 367 // Create a new, empty default block so that the new hierarchy of 368 // if-then statements go to this and the PHI nodes are happy. 369 // if the default block is set as an unreachable we avoid creating one 370 // because will never be a valid target. 371 BasicBlock *NewDefault = nullptr; 372 if (!DefaultIsUnreachable) { 373 NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); 374 F->getBasicBlockList().insert(Default, NewDefault); 375 376 BranchInst::Create(Default, NewDefault); 377 } 378 // If there is an entry in any PHI nodes for the default edge, make sure 379 // to update them as well. 380 for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { 381 PHINode *PN = cast<PHINode>(I); 382 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 383 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 384 PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); 385 } 386 387 // Prepare cases vector. 388 CaseVector Cases; 389 unsigned numCmps = Clusterify(Cases, SI); 390 391 DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() 392 << ". 
Total compares: " << numCmps << "\n"); 393 DEBUG(dbgs() << "Cases: " << Cases << "\n"); 394 (void)numCmps; 395 396 ConstantInt *UpperBound = nullptr; 397 ConstantInt *LowerBound = nullptr; 398 399 // Optimize the condition where Default is an unreachable block. In this case 400 // we can make the bounds tightly fitted around the case value ranges, 401 // because we know that the value passed to the switch should always be 402 // exactly one of the case values. 403 if (DefaultIsUnreachable) { 404 CaseItr LastCase = Cases.begin() + Cases.size() - 1; 405 UpperBound = cast<ConstantInt>(LastCase->High); 406 LowerBound = cast<ConstantInt>(Cases.begin()->Low); 407 } 408 BasicBlock *SwitchBlock = 409 switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, 410 OrigBlock, OrigBlock, NewDefault); 411 412 // Branch to our shiny new if-then stuff... 413 BranchInst::Create(SwitchBlock, OrigBlock); 414 415 // We are now done with the switch instruction, delete it. 416 CurBlock->getInstList().erase(SI); 417 418 pred_iterator PI = pred_begin(Default), E = pred_end(Default); 419 // If the Default block has no more predecessors just remove it 420 if (PI == E) { 421 DeleteDeadBlock(Default); 422 } 423} 424