LowerSwitch.cpp revision 280031
1//===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===// 2// 3// The LLVM Compiler Infrastructure 4// 5// This file is distributed under the University of Illinois Open Source 6// License. See LICENSE.TXT for details. 7// 8//===----------------------------------------------------------------------===// 9// 10// The LowerSwitch transformation rewrites switch instructions with a sequence 11// of branches, which allows targets to get away with not implementing the 12// switch instruction until it is convenient. 13// 14//===----------------------------------------------------------------------===// 15 16#include "llvm/Transforms/Scalar.h" 17#include "llvm/Transforms/Utils/BasicBlockUtils.h" 18#include "llvm/ADT/STLExtras.h" 19#include "llvm/IR/Constants.h" 20#include "llvm/IR/Function.h" 21#include "llvm/IR/Instructions.h" 22#include "llvm/IR/LLVMContext.h" 23#include "llvm/IR/CFG.h" 24#include "llvm/Pass.h" 25#include "llvm/Support/Compiler.h" 26#include "llvm/Support/Debug.h" 27#include "llvm/Support/raw_ostream.h" 28#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" 29#include <algorithm> 30using namespace llvm; 31 32#define DEBUG_TYPE "lower-switch" 33 34namespace { 35 /// LowerSwitch Pass - Replace all SwitchInst instructions with chained branch 36 /// instructions. 37 class LowerSwitch : public FunctionPass { 38 public: 39 static char ID; // Pass identification, replacement for typeid 40 LowerSwitch() : FunctionPass(ID) { 41 initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); 42 } 43 44 bool runOnFunction(Function &F) override; 45 46 void getAnalysisUsage(AnalysisUsage &AU) const override { 47 // This is a cluster of orthogonal Transforms 48 AU.addPreserved<UnifyFunctionExitNodes>(); 49 AU.addPreservedID(LowerInvokePassID); 50 } 51 52 struct CaseRange { 53 Constant* Low; 54 Constant* High; 55 BasicBlock* BB; 56 57 CaseRange(Constant *low = nullptr, Constant *high = nullptr, 58 BasicBlock *bb = nullptr) : 59 Low(low), High(high), BB(bb) { } 60 }; 61 62 typedef std::vector<CaseRange> CaseVector; 63 typedef std::vector<CaseRange>::iterator CaseItr; 64 private: 65 void processSwitchInst(SwitchInst *SI); 66 67 BasicBlock *switchConvert(CaseItr Begin, CaseItr End, 68 ConstantInt *LowerBound, ConstantInt *UpperBound, 69 Value *Val, BasicBlock *Predecessor, 70 BasicBlock *OrigBlock, BasicBlock *Default); 71 BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, BasicBlock *OrigBlock, 72 BasicBlock *Default); 73 unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); 74 }; 75 76 /// The comparison function for sorting the switch case values in the vector. 77 /// WARNING: Case ranges should be disjoint! 78 struct CaseCmp { 79 bool operator () (const LowerSwitch::CaseRange& C1, 80 const LowerSwitch::CaseRange& C2) { 81 82 const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 83 const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 84 return CI1->getValue().slt(CI2->getValue()); 85 } 86 }; 87} 88 89char LowerSwitch::ID = 0; 90INITIALIZE_PASS(LowerSwitch, "lowerswitch", 91 "Lower SwitchInst's to branches", false, false) 92 93// Publicly exposed interface to pass... 94char &llvm::LowerSwitchID = LowerSwitch::ID; 95// createLowerSwitchPass - Interface to this file... 96FunctionPass *llvm::createLowerSwitchPass() { 97 return new LowerSwitch(); 98} 99 100bool LowerSwitch::runOnFunction(Function &F) { 101 bool Changed = false; 102 103 for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 104 BasicBlock *Cur = I++; // Advance over block so we don't traverse new blocks 105 106 if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 107 Changed = true; 108 processSwitchInst(SI); 109 } 110 } 111 112 return Changed; 113} 114 115// operator<< - Used for debugging purposes. 116// 117static raw_ostream& operator<<(raw_ostream &O, 118 const LowerSwitch::CaseVector &C) 119 LLVM_ATTRIBUTE_USED; 120static raw_ostream& operator<<(raw_ostream &O, 121 const LowerSwitch::CaseVector &C) { 122 O << "["; 123 124 for (LowerSwitch::CaseVector::const_iterator B = C.begin(), 125 E = C.end(); B != E; ) { 126 O << *B->Low << " -" << *B->High; 127 if (++B != E) O << ", "; 128 } 129 130 return O << "]"; 131} 132 133// \brief Update the first occurrence of the "switch statement" BB in the PHI 134// node with the "new" BB. The other occurrences will: 135// 136// 1) Be updated by subsequent calls to this function. Switch statements may 137// have more than one outcoming edge into the same BB if they all have the same 138// value. When the switch statement is converted these incoming edges are now 139// coming from multiple BBs. 140// 2) Removed if subsequent incoming values now share the same case, i.e., 141// multiple outcome edges are condensed into one. This is necessary to keep the 142// number of phi values equal to the number of branches to SuccBB. 143static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, 144 unsigned NumMergedCases) { 145 for (BasicBlock::iterator I = SuccBB->begin(), IE = SuccBB->getFirstNonPHI(); 146 I != IE; ++I) { 147 PHINode *PN = cast<PHINode>(I); 148 149 // Only update the first occurence. 150 unsigned Idx = 0, E = PN->getNumIncomingValues(); 151 unsigned LocalNumMergedCases = NumMergedCases; 152 for (; Idx != E; ++Idx) { 153 if (PN->getIncomingBlock(Idx) == OrigBB) { 154 PN->setIncomingBlock(Idx, NewBB); 155 break; 156 } 157 } 158 159 // Remove additional occurences coming from condensed cases and keep the 160 // number of incoming values equal to the number of branches to SuccBB. 161 for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx) 162 if (PN->getIncomingBlock(Idx) == OrigBB) { 163 PN->removeIncomingValue(Idx); 164 LocalNumMergedCases--; 165 } 166 } 167} 168 169// switchConvert - Convert the switch statement into a binary lookup of 170// the case values. The function recursively builds this tree. 171// LowerBound and UpperBound are used to keep track of the bounds for Val 172// that have already been checked by a block emitted by one of the previous 173// calls to switchConvert in the call stack. 174BasicBlock *LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, 175 ConstantInt *LowerBound, 176 ConstantInt *UpperBound, Value *Val, 177 BasicBlock *Predecessor, 178 BasicBlock *OrigBlock, 179 BasicBlock *Default) { 180 unsigned Size = End - Begin; 181 182 if (Size == 1) { 183 // Check if the Case Range is perfectly squeezed in between 184 // already checked Upper and Lower bounds. If it is then we can avoid 185 // emitting the code that checks if the value actually falls in the range 186 // because the bounds already tell us so. 187 if (Begin->Low == LowerBound && Begin->High == UpperBound) { 188 unsigned NumMergedCases = 0; 189 if (LowerBound && UpperBound) 190 NumMergedCases = 191 UpperBound->getSExtValue() - LowerBound->getSExtValue(); 192 fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); 193 return Begin->BB; 194 } 195 return newLeafBlock(*Begin, Val, OrigBlock, Default); 196 } 197 198 unsigned Mid = Size / 2; 199 std::vector<CaseRange> LHS(Begin, Begin + Mid); 200 DEBUG(dbgs() << "LHS: " << LHS << "\n"); 201 std::vector<CaseRange> RHS(Begin + Mid, End); 202 DEBUG(dbgs() << "RHS: " << RHS << "\n"); 203 204 CaseRange &Pivot = *(Begin + Mid); 205 DEBUG(dbgs() << "Pivot ==> " 206 << cast<ConstantInt>(Pivot.Low)->getValue() 207 << " -" << cast<ConstantInt>(Pivot.High)->getValue() << "\n"); 208 209 // NewLowerBound here should never be the integer minimal value. 210 // This is because it is computed from a case range that is never 211 // the smallest, so there is always a case range that has at least 212 // a smaller value. 213 ConstantInt *NewLowerBound = cast<ConstantInt>(Pivot.Low); 214 ConstantInt *NewUpperBound; 215 216 // If we don't have a Default block then it means that we can never 217 // have a value outside of a case range, so set the UpperBound to the highest 218 // value in the LHS part of the case ranges. 219 if (Default != nullptr) { 220 // Because NewLowerBound is never the smallest representable integer 221 // it is safe here to subtract one. 222 NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), 223 NewLowerBound->getValue() - 1); 224 } else { 225 CaseItr LastLHS = LHS.begin() + LHS.size() - 1; 226 NewUpperBound = cast<ConstantInt>(LastLHS->High); 227 } 228 229 DEBUG(dbgs() << "LHS Bounds ==> "; 230 if (LowerBound) { 231 dbgs() << cast<ConstantInt>(LowerBound)->getSExtValue(); 232 } else { 233 dbgs() << "NONE"; 234 } 235 dbgs() << " - " << NewUpperBound->getSExtValue() << "\n"; 236 dbgs() << "RHS Bounds ==> "; 237 dbgs() << NewLowerBound->getSExtValue() << " - "; 238 if (UpperBound) { 239 dbgs() << cast<ConstantInt>(UpperBound)->getSExtValue() << "\n"; 240 } else { 241 dbgs() << "NONE\n"; 242 }); 243 244 // Create a new node that checks if the value is < pivot. Go to the 245 // left branch if it is and right branch if not. 246 Function* F = OrigBlock->getParent(); 247 BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); 248 249 ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, 250 Val, Pivot.Low, "Pivot"); 251 252 BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, 253 NewUpperBound, Val, NewNode, OrigBlock, 254 Default); 255 BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, 256 UpperBound, Val, NewNode, OrigBlock, 257 Default); 258 259 Function::iterator FI = OrigBlock; 260 F->getBasicBlockList().insert(++FI, NewNode); 261 NewNode->getInstList().push_back(Comp); 262 263 BranchInst::Create(LBranch, RBranch, Comp, NewNode); 264 return NewNode; 265} 266 267// newLeafBlock - Create a new leaf block for the binary lookup tree. It 268// checks if the switch's value == the case's value. If not, then it 269// jumps to the default branch. At this point in the tree, the value 270// can't be another valid case value, so the jump to the "default" branch 271// is warranted. 272// 273BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, 274 BasicBlock* OrigBlock, 275 BasicBlock* Default) 276{ 277 Function* F = OrigBlock->getParent(); 278 BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); 279 Function::iterator FI = OrigBlock; 280 F->getBasicBlockList().insert(++FI, NewLeaf); 281 282 // Emit comparison 283 ICmpInst* Comp = nullptr; 284 if (Leaf.Low == Leaf.High) { 285 // Make the seteq instruction... 286 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, 287 Leaf.Low, "SwitchLeaf"); 288 } else { 289 // Make range comparison 290 if (cast<ConstantInt>(Leaf.Low)->isMinValue(true /*isSigned*/)) { 291 // Val >= Min && Val <= Hi --> Val <= Hi 292 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, 293 "SwitchLeaf"); 294 } else if (cast<ConstantInt>(Leaf.Low)->isZero()) { 295 // Val >= 0 && Val <= Hi --> Val <=u Hi 296 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, 297 "SwitchLeaf"); 298 } else { 299 // Emit V-Lo <=u Hi-Lo 300 Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 301 Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 302 Val->getName()+".off", 303 NewLeaf); 304 Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 305 Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, 306 "SwitchLeaf"); 307 } 308 } 309 310 // Make the conditional branch... 311 BasicBlock* Succ = Leaf.BB; 312 BranchInst::Create(Succ, Default, Comp, NewLeaf); 313 314 // If there were any PHI nodes in this successor, rewrite one entry 315 // from OrigBlock to come from NewLeaf. 316 for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 317 PHINode* PN = cast<PHINode>(I); 318 // Remove all but one incoming entries from the cluster 319 uint64_t Range = cast<ConstantInt>(Leaf.High)->getSExtValue() - 320 cast<ConstantInt>(Leaf.Low)->getSExtValue(); 321 for (uint64_t j = 0; j < Range; ++j) { 322 PN->removeIncomingValue(OrigBlock); 323 } 324 325 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 326 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 327 PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 328 } 329 330 return NewLeaf; 331} 332 333// Clusterify - Transform simple list of Cases into list of CaseRange's 334unsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 335 unsigned numCmps = 0; 336 337 // Start with "simple" cases 338 for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); i != e; ++i) 339 Cases.push_back(CaseRange(i.getCaseValue(), i.getCaseValue(), 340 i.getCaseSuccessor())); 341 342 std::sort(Cases.begin(), Cases.end(), CaseCmp()); 343 344 // Merge case into clusters 345 if (Cases.size()>=2) 346 for (CaseItr I = Cases.begin(), J = std::next(Cases.begin()); 347 J != Cases.end();) { 348 int64_t nextValue = cast<ConstantInt>(J->Low)->getSExtValue(); 349 int64_t currentValue = cast<ConstantInt>(I->High)->getSExtValue(); 350 BasicBlock* nextBB = J->BB; 351 BasicBlock* currentBB = I->BB; 352 353 // If the two neighboring cases go to the same destination, merge them 354 // into a single case. 355 if ((nextValue-currentValue==1) && (currentBB == nextBB)) { 356 I->High = J->High; 357 J = Cases.erase(J); 358 } else { 359 I = J++; 360 } 361 } 362 363 for (CaseItr I=Cases.begin(), E=Cases.end(); I!=E; ++I, ++numCmps) { 364 if (I->Low != I->High) 365 // A range counts double, since it requires two compares. 366 ++numCmps; 367 } 368 369 return numCmps; 370} 371 372// processSwitchInst - Replace the specified switch instruction with a sequence 373// of chained if-then insts in a balanced binary search. 374// 375void LowerSwitch::processSwitchInst(SwitchInst *SI) { 376 BasicBlock *CurBlock = SI->getParent(); 377 BasicBlock *OrigBlock = CurBlock; 378 Function *F = CurBlock->getParent(); 379 Value *Val = SI->getCondition(); // The value we are switching on... 380 BasicBlock* Default = SI->getDefaultDest(); 381 382 // If there is only the default destination, don't bother with the code below. 383 if (!SI->getNumCases()) { 384 BranchInst::Create(SI->getDefaultDest(), CurBlock); 385 CurBlock->getInstList().erase(SI); 386 return; 387 } 388 389 const bool DefaultIsUnreachable = 390 Default->size() == 1 && isa<UnreachableInst>(Default->getTerminator()); 391 // Create a new, empty default block so that the new hierarchy of 392 // if-then statements go to this and the PHI nodes are happy. 393 // if the default block is set as an unreachable we avoid creating one 394 // because will never be a valid target. 395 BasicBlock *NewDefault = nullptr; 396 if (!DefaultIsUnreachable) { 397 NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); 398 F->getBasicBlockList().insert(Default, NewDefault); 399 400 BranchInst::Create(Default, NewDefault); 401 } 402 // If there is an entry in any PHI nodes for the default edge, make sure 403 // to update them as well. 404 for (BasicBlock::iterator I = Default->begin(); isa<PHINode>(I); ++I) { 405 PHINode *PN = cast<PHINode>(I); 406 int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 407 assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 408 PN->setIncomingBlock((unsigned)BlockIdx, NewDefault); 409 } 410 411 // Prepare cases vector. 412 CaseVector Cases; 413 unsigned numCmps = Clusterify(Cases, SI); 414 415 DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() 416 << ". Total compares: " << numCmps << "\n"); 417 DEBUG(dbgs() << "Cases: " << Cases << "\n"); 418 (void)numCmps; 419 420 ConstantInt *UpperBound = nullptr; 421 ConstantInt *LowerBound = nullptr; 422 423 // Optimize the condition where Default is an unreachable block. In this case 424 // we can make the bounds tightly fitted around the case value ranges, 425 // because we know that the value passed to the switch should always be 426 // exactly one of the case values. 427 if (DefaultIsUnreachable) { 428 CaseItr LastCase = Cases.begin() + Cases.size() - 1; 429 UpperBound = cast<ConstantInt>(LastCase->High); 430 LowerBound = cast<ConstantInt>(Cases.begin()->Low); 431 } 432 BasicBlock *SwitchBlock = 433 switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, 434 OrigBlock, OrigBlock, NewDefault); 435 436 // Branch to our shiny new if-then stuff... 437 BranchInst::Create(SwitchBlock, OrigBlock); 438 439 // We are now done with the switch instruction, delete it. 440 CurBlock->getInstList().erase(SI); 441 442 pred_iterator PI = pred_begin(Default), E = pred_end(Default); 443 // If the Default block has no more predecessors just remove it 444 if (PI == E) { 445 DeleteDeadBlock(Default); 446 } 447} 448