1193323Sed//===- LowerSwitch.cpp - Eliminate Switch instructions --------------------===// 2193323Sed// 3353358Sdim// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4353358Sdim// See https://llvm.org/LICENSE.txt for license information. 5353358Sdim// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6193323Sed// 7193323Sed//===----------------------------------------------------------------------===// 8193323Sed// 9193323Sed// The LowerSwitch transformation rewrites switch instructions with a sequence 10193323Sed// of branches, which allows targets to get away with not implementing the 11193323Sed// switch instruction until it is convenient. 12193323Sed// 13193323Sed//===----------------------------------------------------------------------===// 14193323Sed 15327952Sdim#include "llvm/ADT/DenseMap.h" 16249423Sdim#include "llvm/ADT/STLExtras.h" 17327952Sdim#include "llvm/ADT/SmallPtrSet.h" 18327952Sdim#include "llvm/ADT/SmallVector.h" 19353358Sdim#include "llvm/Analysis/AssumptionCache.h" 20353358Sdim#include "llvm/Analysis/LazyValueInfo.h" 21353358Sdim#include "llvm/Analysis/ValueTracking.h" 22327952Sdim#include "llvm/IR/BasicBlock.h" 23288943Sdim#include "llvm/IR/CFG.h" 24353358Sdim#include "llvm/IR/ConstantRange.h" 25249423Sdim#include "llvm/IR/Constants.h" 26249423Sdim#include "llvm/IR/Function.h" 27327952Sdim#include "llvm/IR/InstrTypes.h" 28249423Sdim#include "llvm/IR/Instructions.h" 29327952Sdim#include "llvm/IR/Value.h" 30360784Sdim#include "llvm/InitializePasses.h" 31193323Sed#include "llvm/Pass.h" 32327952Sdim#include "llvm/Support/Casting.h" 33198892Srdivacky#include "llvm/Support/Compiler.h" 34193323Sed#include "llvm/Support/Debug.h" 35353358Sdim#include "llvm/Support/KnownBits.h" 36193323Sed#include "llvm/Support/raw_ostream.h" 37341825Sdim#include "llvm/Transforms/Utils.h" 38288943Sdim#include "llvm/Transforms/Utils/BasicBlockUtils.h" 39193323Sed#include <algorithm> 40327952Sdim#include <cassert> 41327952Sdim#include <cstdint> 42327952Sdim#include <iterator> 43327952Sdim#include <limits> 44327952Sdim#include <vector> 45327952Sdim 46193323Sedusing namespace llvm; 47193323Sed 48276479Sdim#define DEBUG_TYPE "lower-switch" 49276479Sdim 50193323Sednamespace { 51327952Sdim 52288943Sdim struct IntRange { 53288943Sdim int64_t Low, High; 54288943Sdim }; 55288943Sdim 56327952Sdim} // end anonymous namespace 57288943Sdim 58327952Sdim// Return true iff R is covered by Ranges. 59327952Sdimstatic bool IsInRanges(const IntRange &R, 60327952Sdim const std::vector<IntRange> &Ranges) { 61327952Sdim // Note: Ranges must be sorted, non-overlapping and non-adjacent. 62327952Sdim 63327952Sdim // Find the first range whose High field is >= R.High, 64327952Sdim // then check if the Low field is <= R.Low. If so, we 65327952Sdim // have a Range that covers R. 66353358Sdim auto I = llvm::lower_bound( 67353358Sdim Ranges, R, [](IntRange A, IntRange B) { return A.High < B.High; }); 68327952Sdim return I != Ranges.end() && I->Low <= R.Low; 69327952Sdim} 70327952Sdim 71327952Sdimnamespace { 72327952Sdim 73296417Sdim /// Replace all SwitchInst instructions with chained branch instructions. 74198892Srdivacky class LowerSwitch : public FunctionPass { 75193323Sed public: 76327952Sdim // Pass identification, replacement for typeid 77327952Sdim static char ID; 78327952Sdim 79218893Sdim LowerSwitch() : FunctionPass(ID) { 80218893Sdim initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); 81341825Sdim } 82193323Sed 83276479Sdim bool runOnFunction(Function &F) override; 84276479Sdim 85353358Sdim void getAnalysisUsage(AnalysisUsage &AU) const override { 86353358Sdim AU.addRequired<LazyValueInfoWrapperPass>(); 87353358Sdim } 88353358Sdim 89193323Sed struct CaseRange { 90288943Sdim ConstantInt* Low; 91288943Sdim ConstantInt* High; 92193323Sed BasicBlock* BB; 93193323Sed 94288943Sdim CaseRange(ConstantInt *low, ConstantInt *high, BasicBlock *bb) 95288943Sdim : Low(low), High(high), BB(bb) {} 96193323Sed }; 97193323Sed 98327952Sdim using CaseVector = std::vector<CaseRange>; 99327952Sdim using CaseItr = std::vector<CaseRange>::iterator; 100327952Sdim 101193323Sed private: 102353358Sdim void processSwitchInst(SwitchInst *SI, 103353358Sdim SmallPtrSetImpl<BasicBlock *> &DeleteList, 104353358Sdim AssumptionCache *AC, LazyValueInfo *LVI); 105193323Sed 106276479Sdim BasicBlock *switchConvert(CaseItr Begin, CaseItr End, 107276479Sdim ConstantInt *LowerBound, ConstantInt *UpperBound, 108276479Sdim Value *Val, BasicBlock *Predecessor, 109288943Sdim BasicBlock *OrigBlock, BasicBlock *Default, 110288943Sdim const std::vector<IntRange> &UnreachableRanges); 111353358Sdim BasicBlock *newLeafBlock(CaseRange &Leaf, Value *Val, 112353358Sdim ConstantInt *LowerBound, ConstantInt *UpperBound, 113353358Sdim BasicBlock *OrigBlock, BasicBlock *Default); 114276479Sdim unsigned Clusterify(CaseVector &Cases, SwitchInst *SI); 115193323Sed }; 116261991Sdim 117261991Sdim /// The comparison function for sorting the switch case values in the vector. 118261991Sdim /// WARNING: Case ranges should be disjoint! 119261991Sdim struct CaseCmp { 120327952Sdim bool operator()(const LowerSwitch::CaseRange& C1, 121327952Sdim const LowerSwitch::CaseRange& C2) { 122261991Sdim const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); 123261991Sdim const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); 124261991Sdim return CI1->getValue().slt(CI2->getValue()); 125261991Sdim } 126261991Sdim }; 127193323Sed 128327952Sdim} // end anonymous namespace 129327952Sdim 130193323Sedchar LowerSwitch::ID = 0; 131327952Sdim 132327952Sdim// Publicly exposed interface to pass... 133327952Sdimchar &llvm::LowerSwitchID = LowerSwitch::ID; 134327952Sdim 135353358SdimINITIALIZE_PASS_BEGIN(LowerSwitch, "lowerswitch", 136353358Sdim "Lower SwitchInst's to branches", false, false) 137353358SdimINITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 138353358SdimINITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) 139353358SdimINITIALIZE_PASS_END(LowerSwitch, "lowerswitch", 140353358Sdim "Lower SwitchInst's to branches", false, false) 141193323Sed 142193323Sed// createLowerSwitchPass - Interface to this file... 143193323SedFunctionPass *llvm::createLowerSwitchPass() { 144193323Sed return new LowerSwitch(); 145193323Sed} 146193323Sed 147193323Sedbool LowerSwitch::runOnFunction(Function &F) { 148353358Sdim LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); 149353358Sdim auto *ACT = getAnalysisIfAvailable<AssumptionCacheTracker>(); 150353358Sdim AssumptionCache *AC = ACT ? &ACT->getAssumptionCache(F) : nullptr; 151353358Sdim // Prevent LazyValueInfo from using the DominatorTree as LowerSwitch does not 152353358Sdim // preserve it and it becomes stale (when available) pretty much immediately. 153353358Sdim // Currently the DominatorTree is only used by LowerSwitch indirectly via LVI 154353358Sdim // and computeKnownBits to refine isValidAssumeForContext's results. Given 155353358Sdim // that the latter can handle some of the simple cases w/o a DominatorTree, 156353358Sdim // it's easier to refrain from using the tree than to keep it up to date. 157353358Sdim LVI->disableDT(); 158353358Sdim 159193323Sed bool Changed = false; 160296417Sdim SmallPtrSet<BasicBlock*, 8> DeleteList; 161193323Sed 162193323Sed for (Function::iterator I = F.begin(), E = F.end(); I != E; ) { 163296417Sdim BasicBlock *Cur = &*I++; // Advance over block so we don't traverse new blocks 164193323Sed 165296417Sdim // If the block is a dead Default block that will be deleted later, don't 166296417Sdim // waste time processing it. 167296417Sdim if (DeleteList.count(Cur)) 168296417Sdim continue; 169296417Sdim 170193323Sed if (SwitchInst *SI = dyn_cast<SwitchInst>(Cur->getTerminator())) { 171193323Sed Changed = true; 172353358Sdim processSwitchInst(SI, DeleteList, AC, LVI); 173193323Sed } 174193323Sed } 175193323Sed 176296417Sdim for (BasicBlock* BB: DeleteList) { 177353358Sdim LVI->eraseBlock(BB); 178296417Sdim DeleteDeadBlock(BB); 179296417Sdim } 180296417Sdim 181193323Sed return Changed; 182193323Sed} 183193323Sed 184296417Sdim/// Used for debugging purposes. 185341825SdimLLVM_ATTRIBUTE_USED 186341825Sdimstatic raw_ostream &operator<<(raw_ostream &O, 187198090Srdivacky const LowerSwitch::CaseVector &C) { 188193323Sed O << "["; 189193323Sed 190353358Sdim for (LowerSwitch::CaseVector::const_iterator B = C.begin(), E = C.end(); 191353358Sdim B != E;) { 192353358Sdim O << "[" << B->Low->getValue() << ", " << B->High->getValue() << "]"; 193353358Sdim if (++B != E) 194353358Sdim O << ", "; 195193323Sed } 196193323Sed 197193323Sed return O << "]"; 198193323Sed} 199193323Sed 200341825Sdim/// Update the first occurrence of the "switch statement" BB in the PHI 201296417Sdim/// node with the "new" BB. The other occurrences will: 202296417Sdim/// 203296417Sdim/// 1) Be updated by subsequent calls to this function. Switch statements may 204296417Sdim/// have more than one outcoming edge into the same BB if they all have the same 205296417Sdim/// value. When the switch statement is converted these incoming edges are now 206296417Sdim/// coming from multiple BBs. 207296417Sdim/// 2) Removed if subsequent incoming values now share the same case, i.e., 208296417Sdim/// multiple outcome edges are condensed into one. This is necessary to keep the 209296417Sdim/// number of phi values equal to the number of branches to SuccBB. 210353358Sdimstatic void 211353358SdimfixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, 212353358Sdim const unsigned NumMergedCases = std::numeric_limits<unsigned>::max()) { 213296417Sdim for (BasicBlock::iterator I = SuccBB->begin(), 214296417Sdim IE = SuccBB->getFirstNonPHI()->getIterator(); 215280031Sdim I != IE; ++I) { 216276479Sdim PHINode *PN = cast<PHINode>(I); 217276479Sdim 218296417Sdim // Only update the first occurrence. 219280031Sdim unsigned Idx = 0, E = PN->getNumIncomingValues(); 220280031Sdim unsigned LocalNumMergedCases = NumMergedCases; 221280031Sdim for (; Idx != E; ++Idx) { 222280031Sdim if (PN->getIncomingBlock(Idx) == OrigBB) { 223280031Sdim PN->setIncomingBlock(Idx, NewBB); 224280031Sdim break; 225280031Sdim } 226276479Sdim } 227280031Sdim 228296417Sdim // Remove additional occurrences coming from condensed cases and keep the 229280031Sdim // number of incoming values equal to the number of branches to SuccBB. 230288943Sdim SmallVector<unsigned, 8> Indices; 231280031Sdim for (++Idx; LocalNumMergedCases > 0 && Idx < E; ++Idx) 232280031Sdim if (PN->getIncomingBlock(Idx) == OrigBB) { 233288943Sdim Indices.push_back(Idx); 234280031Sdim LocalNumMergedCases--; 235280031Sdim } 236288943Sdim // Remove incoming values in the reverse order to prevent invalidating 237288943Sdim // *successive* index. 238327952Sdim for (unsigned III : llvm::reverse(Indices)) 239309124Sdim PN->removeIncomingValue(III); 240276479Sdim } 241276479Sdim} 242276479Sdim 243296417Sdim/// Convert the switch statement into a binary lookup of the case values. 244296417Sdim/// The function recursively builds this tree. LowerBound and UpperBound are 245296417Sdim/// used to keep track of the bounds for Val that have already been checked by 246296417Sdim/// a block emitted by one of the previous calls to switchConvert in the call 247296417Sdim/// stack. 248288943SdimBasicBlock * 249288943SdimLowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, 250288943Sdim ConstantInt *UpperBound, Value *Val, 251288943Sdim BasicBlock *Predecessor, BasicBlock *OrigBlock, 252288943Sdim BasicBlock *Default, 253288943Sdim const std::vector<IntRange> &UnreachableRanges) { 254353358Sdim assert(LowerBound && UpperBound && "Bounds must be initialized"); 255193323Sed unsigned Size = End - Begin; 256193323Sed 257276479Sdim if (Size == 1) { 258276479Sdim // Check if the Case Range is perfectly squeezed in between 259276479Sdim // already checked Upper and Lower bounds. If it is then we can avoid 260276479Sdim // emitting the code that checks if the value actually falls in the range 261276479Sdim // because the bounds already tell us so. 262276479Sdim if (Begin->Low == LowerBound && Begin->High == UpperBound) { 263280031Sdim unsigned NumMergedCases = 0; 264353358Sdim NumMergedCases = UpperBound->getSExtValue() - LowerBound->getSExtValue(); 265280031Sdim fixPhis(Begin->BB, OrigBlock, Predecessor, NumMergedCases); 266276479Sdim return Begin->BB; 267276479Sdim } 268353358Sdim return newLeafBlock(*Begin, Val, LowerBound, UpperBound, OrigBlock, 269353358Sdim Default); 270276479Sdim } 271193323Sed 272193323Sed unsigned Mid = Size / 2; 273193323Sed std::vector<CaseRange> LHS(Begin, Begin + Mid); 274341825Sdim LLVM_DEBUG(dbgs() << "LHS: " << LHS << "\n"); 275193323Sed std::vector<CaseRange> RHS(Begin + Mid, End); 276341825Sdim LLVM_DEBUG(dbgs() << "RHS: " << RHS << "\n"); 277193323Sed 278276479Sdim CaseRange &Pivot = *(Begin + Mid); 279353358Sdim LLVM_DEBUG(dbgs() << "Pivot ==> [" << Pivot.Low->getValue() << ", " 280353358Sdim << Pivot.High->getValue() << "]\n"); 281193323Sed 282276479Sdim // NewLowerBound here should never be the integer minimal value. 283276479Sdim // This is because it is computed from a case range that is never 284276479Sdim // the smallest, so there is always a case range that has at least 285276479Sdim // a smaller value. 286288943Sdim ConstantInt *NewLowerBound = Pivot.Low; 287193323Sed 288288943Sdim // Because NewLowerBound is never the smallest representable integer 289288943Sdim // it is safe here to subtract one. 290288943Sdim ConstantInt *NewUpperBound = ConstantInt::get(NewLowerBound->getContext(), 291288943Sdim NewLowerBound->getValue() - 1); 292288943Sdim 293288943Sdim if (!UnreachableRanges.empty()) { 294288943Sdim // Check if the gap between LHS's highest and NewLowerBound is unreachable. 295288943Sdim int64_t GapLow = LHS.back().High->getSExtValue() + 1; 296288943Sdim int64_t GapHigh = NewLowerBound->getSExtValue() - 1; 297288943Sdim IntRange Gap = { GapLow, GapHigh }; 298288943Sdim if (GapHigh >= GapLow && IsInRanges(Gap, UnreachableRanges)) 299288943Sdim NewUpperBound = LHS.back().High; 300276479Sdim } 301276479Sdim 302353358Sdim LLVM_DEBUG(dbgs() << "LHS Bounds ==> [" << LowerBound->getSExtValue() << ", " 303353358Sdim << NewUpperBound->getSExtValue() << "]\n" 304353358Sdim << "RHS Bounds ==> [" << NewLowerBound->getSExtValue() 305353358Sdim << ", " << UpperBound->getSExtValue() << "]\n"); 306276479Sdim 307193323Sed // Create a new node that checks if the value is < pivot. Go to the 308193323Sed // left branch if it is and right branch if not. 309193323Sed Function* F = OrigBlock->getParent(); 310198090Srdivacky BasicBlock* NewNode = BasicBlock::Create(Val->getContext(), "NodeBlock"); 311193323Sed 312261991Sdim ICmpInst* Comp = new ICmpInst(ICmpInst::ICMP_SLT, 313198090Srdivacky Val, Pivot.Low, "Pivot"); 314276479Sdim 315276479Sdim BasicBlock *LBranch = switchConvert(LHS.begin(), LHS.end(), LowerBound, 316276479Sdim NewUpperBound, Val, NewNode, OrigBlock, 317288943Sdim Default, UnreachableRanges); 318276479Sdim BasicBlock *RBranch = switchConvert(RHS.begin(), RHS.end(), NewLowerBound, 319276479Sdim UpperBound, Val, NewNode, OrigBlock, 320288943Sdim Default, UnreachableRanges); 321276479Sdim 322296417Sdim F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewNode); 323193323Sed NewNode->getInstList().push_back(Comp); 324276479Sdim 325193323Sed BranchInst::Create(LBranch, RBranch, Comp, NewNode); 326193323Sed return NewNode; 327193323Sed} 328193323Sed 329296417Sdim/// Create a new leaf block for the binary lookup tree. It checks if the 330296417Sdim/// switch's value == the case's value. If not, then it jumps to the default 331296417Sdim/// branch. At this point in the tree, the value can't be another valid case 332296417Sdim/// value, so the jump to the "default" branch is warranted. 333353358SdimBasicBlock *LowerSwitch::newLeafBlock(CaseRange &Leaf, Value *Val, 334353358Sdim ConstantInt *LowerBound, 335353358Sdim ConstantInt *UpperBound, 336353358Sdim BasicBlock *OrigBlock, 337353358Sdim BasicBlock *Default) { 338193323Sed Function* F = OrigBlock->getParent(); 339198090Srdivacky BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); 340296417Sdim F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); 341193323Sed 342193323Sed // Emit comparison 343276479Sdim ICmpInst* Comp = nullptr; 344193323Sed if (Leaf.Low == Leaf.High) { 345193323Sed // Make the seteq instruction... 346198090Srdivacky Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_EQ, Val, 347198090Srdivacky Leaf.Low, "SwitchLeaf"); 348193323Sed } else { 349193323Sed // Make range comparison 350353358Sdim if (Leaf.Low == LowerBound) { 351193323Sed // Val >= Min && Val <= Hi --> Val <= Hi 352198090Srdivacky Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SLE, Val, Leaf.High, 353198090Srdivacky "SwitchLeaf"); 354353358Sdim } else if (Leaf.High == UpperBound) { 355353358Sdim // Val <= Max && Val >= Lo --> Val >= Lo 356353358Sdim Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_SGE, Val, Leaf.Low, 357353358Sdim "SwitchLeaf"); 358288943Sdim } else if (Leaf.Low->isZero()) { 359193323Sed // Val >= 0 && Val <= Hi --> Val <=u Hi 360198090Srdivacky Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Val, Leaf.High, 361341825Sdim "SwitchLeaf"); 362193323Sed } else { 363193323Sed // Emit V-Lo <=u Hi-Lo 364193323Sed Constant* NegLo = ConstantExpr::getNeg(Leaf.Low); 365193323Sed Instruction* Add = BinaryOperator::CreateAdd(Val, NegLo, 366193323Sed Val->getName()+".off", 367193323Sed NewLeaf); 368193323Sed Constant *UpperBound = ConstantExpr::getAdd(NegLo, Leaf.High); 369198090Srdivacky Comp = new ICmpInst(*NewLeaf, ICmpInst::ICMP_ULE, Add, UpperBound, 370198090Srdivacky "SwitchLeaf"); 371193323Sed } 372193323Sed } 373193323Sed 374193323Sed // Make the conditional branch... 375193323Sed BasicBlock* Succ = Leaf.BB; 376193323Sed BranchInst::Create(Succ, Default, Comp, NewLeaf); 377193323Sed 378193323Sed // If there were any PHI nodes in this successor, rewrite one entry 379193323Sed // from OrigBlock to come from NewLeaf. 380193323Sed for (BasicBlock::iterator I = Succ->begin(); isa<PHINode>(I); ++I) { 381193323Sed PHINode* PN = cast<PHINode>(I); 382193323Sed // Remove all but one incoming entries from the cluster 383288943Sdim uint64_t Range = Leaf.High->getSExtValue() - 384288943Sdim Leaf.Low->getSExtValue(); 385193323Sed for (uint64_t j = 0; j < Range; ++j) { 386193323Sed PN->removeIncomingValue(OrigBlock); 387193323Sed } 388341825Sdim 389193323Sed int BlockIdx = PN->getBasicBlockIndex(OrigBlock); 390193323Sed assert(BlockIdx != -1 && "Switch didn't go to this successor??"); 391193323Sed PN->setIncomingBlock((unsigned)BlockIdx, NewLeaf); 392193323Sed } 393193323Sed 394193323Sed return NewLeaf; 395193323Sed} 396193323Sed 397353358Sdim/// Transform simple list of \p SI's cases into list of CaseRange's \p Cases. 398353358Sdim/// \post \p Cases wouldn't contain references to \p SI's default BB. 399353358Sdim/// \returns Number of \p SI's cases that do not reference \p SI's default BB. 400193323Sedunsigned LowerSwitch::Clusterify(CaseVector& Cases, SwitchInst *SI) { 401353358Sdim unsigned NumSimpleCases = 0; 402193323Sed 403193323Sed // Start with "simple" cases 404353358Sdim for (auto Case : SI->cases()) { 405353358Sdim if (Case.getCaseSuccessor() == SI->getDefaultDest()) 406353358Sdim continue; 407321369Sdim Cases.push_back(CaseRange(Case.getCaseValue(), Case.getCaseValue(), 408321369Sdim Case.getCaseSuccessor())); 409353358Sdim ++NumSimpleCases; 410353358Sdim } 411321369Sdim 412344779Sdim llvm::sort(Cases, CaseCmp()); 413261991Sdim 414261991Sdim // Merge case into clusters 415288943Sdim if (Cases.size() >= 2) { 416288943Sdim CaseItr I = Cases.begin(); 417288943Sdim for (CaseItr J = std::next(I), E = Cases.end(); J != E; ++J) { 418288943Sdim int64_t nextValue = J->Low->getSExtValue(); 419288943Sdim int64_t currentValue = I->High->getSExtValue(); 420261991Sdim BasicBlock* nextBB = J->BB; 421261991Sdim BasicBlock* currentBB = I->BB; 422261991Sdim 423261991Sdim // If the two neighboring cases go to the same destination, merge them 424261991Sdim // into a single case. 425288943Sdim assert(nextValue > currentValue && "Cases should be strictly ascending"); 426288943Sdim if ((nextValue == currentValue + 1) && (currentBB == nextBB)) { 427261991Sdim I->High = J->High; 428288943Sdim // FIXME: Combine branch weights. 429288943Sdim } else if (++I != J) { 430288943Sdim *I = *J; 431261991Sdim } 432261991Sdim } 433288943Sdim Cases.erase(std::next(I), Cases.end()); 434288943Sdim } 435261991Sdim 436353358Sdim return NumSimpleCases; 437193323Sed} 438193323Sed 439296417Sdim/// Replace the specified switch instruction with a sequence of chained if-then 440296417Sdim/// insts in a balanced binary search. 441296417Sdimvoid LowerSwitch::processSwitchInst(SwitchInst *SI, 442353358Sdim SmallPtrSetImpl<BasicBlock *> &DeleteList, 443353358Sdim AssumptionCache *AC, LazyValueInfo *LVI) { 444353358Sdim BasicBlock *OrigBlock = SI->getParent(); 445353358Sdim Function *F = OrigBlock->getParent(); 446226633Sdim Value *Val = SI->getCondition(); // The value we are switching on... 447193323Sed BasicBlock* Default = SI->getDefaultDest(); 448193323Sed 449321369Sdim // Don't handle unreachable blocks. If there are successors with phis, this 450321369Sdim // would leave them behind with missing predecessors. 451353358Sdim if ((OrigBlock != &F->getEntryBlock() && pred_empty(OrigBlock)) || 452353358Sdim OrigBlock->getSinglePredecessor() == OrigBlock) { 453353358Sdim DeleteList.insert(OrigBlock); 454321369Sdim return; 455321369Sdim } 456321369Sdim 457353358Sdim // Prepare cases vector. 458353358Sdim CaseVector Cases; 459353358Sdim const unsigned NumSimpleCases = Clusterify(Cases, SI); 460353358Sdim LLVM_DEBUG(dbgs() << "Clusterify finished. Total clusters: " << Cases.size() 461353358Sdim << ". Total non-default cases: " << NumSimpleCases 462353358Sdim << "\nCase clusters: " << Cases << "\n"); 463353358Sdim 464288943Sdim // If there is only the default destination, just branch. 465353358Sdim if (Cases.empty()) { 466353358Sdim BranchInst::Create(Default, OrigBlock); 467353358Sdim // Remove all the references from Default's PHIs to OrigBlock, but one. 468353358Sdim fixPhis(Default, OrigBlock, OrigBlock); 469288943Sdim SI->eraseFromParent(); 470193323Sed return; 471193323Sed } 472193323Sed 473288943Sdim ConstantInt *LowerBound = nullptr; 474288943Sdim ConstantInt *UpperBound = nullptr; 475353358Sdim bool DefaultIsUnreachableFromSwitch = false; 476288943Sdim 477288943Sdim if (isa<UnreachableInst>(Default->getFirstNonPHIOrDbg())) { 478296417Sdim // Make the bounds tightly fitted around the case value range, because we 479288943Sdim // know that the value passed to the switch must be exactly one of the case 480288943Sdim // values. 481288943Sdim LowerBound = Cases.front().Low; 482288943Sdim UpperBound = Cases.back().High; 483353358Sdim DefaultIsUnreachableFromSwitch = true; 484353358Sdim } else { 485353358Sdim // Constraining the range of the value being switched over helps eliminating 486353358Sdim // unreachable BBs and minimizing the number of `add` instructions 487353358Sdim // newLeafBlock ends up emitting. Running CorrelatedValuePropagation after 488353358Sdim // LowerSwitch isn't as good, and also much more expensive in terms of 489353358Sdim // compile time for the following reasons: 490353358Sdim // 1. it processes many kinds of instructions, not just switches; 491353358Sdim // 2. even if limited to icmp instructions only, it will have to process 492353358Sdim // roughly C icmp's per switch, where C is the number of cases in the 493353358Sdim // switch, while LowerSwitch only needs to call LVI once per switch. 494353358Sdim const DataLayout &DL = F->getParent()->getDataLayout(); 495353358Sdim KnownBits Known = computeKnownBits(Val, DL, /*Depth=*/0, AC, SI); 496353358Sdim // TODO Shouldn't this create a signed range? 497353358Sdim ConstantRange KnownBitsRange = 498353358Sdim ConstantRange::fromKnownBits(Known, /*IsSigned=*/false); 499353358Sdim const ConstantRange LVIRange = LVI->getConstantRange(Val, OrigBlock, SI); 500353358Sdim ConstantRange ValRange = KnownBitsRange.intersectWith(LVIRange); 501353358Sdim // We delegate removal of unreachable non-default cases to other passes. In 502353358Sdim // the unlikely event that some of them survived, we just conservatively 503353358Sdim // maintain the invariant that all the cases lie between the bounds. This 504353358Sdim // may, however, still render the default case effectively unreachable. 505353358Sdim APInt Low = Cases.front().Low->getValue(); 506353358Sdim APInt High = Cases.back().High->getValue(); 507353358Sdim APInt Min = APIntOps::smin(ValRange.getSignedMin(), Low); 508353358Sdim APInt Max = APIntOps::smax(ValRange.getSignedMax(), High); 509288943Sdim 510353358Sdim LowerBound = ConstantInt::get(SI->getContext(), Min); 511353358Sdim UpperBound = ConstantInt::get(SI->getContext(), Max); 512353358Sdim DefaultIsUnreachableFromSwitch = (Min + (NumSimpleCases - 1) == Max); 513353358Sdim } 514353358Sdim 515353358Sdim std::vector<IntRange> UnreachableRanges; 516353358Sdim 517353358Sdim if (DefaultIsUnreachableFromSwitch) { 518288943Sdim DenseMap<BasicBlock *, unsigned> Popularity; 519288943Sdim unsigned MaxPop = 0; 520288943Sdim BasicBlock *PopSucc = nullptr; 521288943Sdim 522327952Sdim IntRange R = {std::numeric_limits<int64_t>::min(), 523327952Sdim std::numeric_limits<int64_t>::max()}; 524288943Sdim UnreachableRanges.push_back(R); 525288943Sdim for (const auto &I : Cases) { 526288943Sdim int64_t Low = I.Low->getSExtValue(); 527288943Sdim int64_t High = I.High->getSExtValue(); 528288943Sdim 529288943Sdim IntRange &LastRange = UnreachableRanges.back(); 530288943Sdim if (LastRange.Low == Low) { 531288943Sdim // There is nothing left of the previous range. 532288943Sdim UnreachableRanges.pop_back(); 533288943Sdim } else { 534288943Sdim // Terminate the previous range. 535288943Sdim assert(Low > LastRange.Low); 536288943Sdim LastRange.High = Low - 1; 537288943Sdim } 538327952Sdim if (High != std::numeric_limits<int64_t>::max()) { 539327952Sdim IntRange R = { High + 1, std::numeric_limits<int64_t>::max() }; 540288943Sdim UnreachableRanges.push_back(R); 541288943Sdim } 542288943Sdim 543288943Sdim // Count popularity. 544288943Sdim int64_t N = High - Low + 1; 545288943Sdim unsigned &Pop = Popularity[I.BB]; 546288943Sdim if ((Pop += N) > MaxPop) { 547288943Sdim MaxPop = Pop; 548288943Sdim PopSucc = I.BB; 549288943Sdim } 550288943Sdim } 551288943Sdim#ifndef NDEBUG 552288943Sdim /* UnreachableRanges should be sorted and the ranges non-adjacent. */ 553288943Sdim for (auto I = UnreachableRanges.begin(), E = UnreachableRanges.end(); 554288943Sdim I != E; ++I) { 555288943Sdim assert(I->Low <= I->High); 556288943Sdim auto Next = I + 1; 557288943Sdim if (Next != E) { 558288943Sdim assert(Next->Low > I->High); 559288943Sdim } 560288943Sdim } 561288943Sdim#endif 562288943Sdim 563341825Sdim // As the default block in the switch is unreachable, update the PHI nodes 564353358Sdim // (remove all of the references to the default block) to reflect this. 565353358Sdim const unsigned NumDefaultEdges = SI->getNumCases() + 1 - NumSimpleCases; 566353358Sdim for (unsigned I = 0; I < NumDefaultEdges; ++I) 567353358Sdim Default->removePredecessor(OrigBlock); 568341825Sdim 569288943Sdim // Use the most popular block as the new default, reducing the number of 570288943Sdim // cases. 571288943Sdim assert(MaxPop > 0 && PopSucc); 572288943Sdim Default = PopSucc; 573314564Sdim Cases.erase( 574327952Sdim llvm::remove_if( 575327952Sdim Cases, [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }), 576314564Sdim Cases.end()); 577288943Sdim 578288943Sdim // If there are no cases left, just branch. 579288943Sdim if (Cases.empty()) { 580353358Sdim BranchInst::Create(Default, OrigBlock); 581288943Sdim SI->eraseFromParent(); 582341825Sdim // As all the cases have been replaced with a single branch, only keep 583341825Sdim // one entry in the PHI nodes. 584341825Sdim for (unsigned I = 0 ; I < (MaxPop - 1) ; ++I) 585341825Sdim PopSucc->removePredecessor(OrigBlock); 586288943Sdim return; 587288943Sdim } 588353358Sdim 589353358Sdim // If the condition was a PHI node with the switch block as a predecessor 590353358Sdim // removing predecessors may have caused the condition to be erased. 591353358Sdim // Getting the condition value again here protects against that. 592353358Sdim Val = SI->getCondition(); 593288943Sdim } 594288943Sdim 595193323Sed // Create a new, empty default block so that the new hierarchy of 596193323Sed // if-then statements go to this and the PHI nodes are happy. 597288943Sdim BasicBlock *NewDefault = BasicBlock::Create(SI->getContext(), "NewDefault"); 598296417Sdim F->getBasicBlockList().insert(Default->getIterator(), NewDefault); 599288943Sdim BranchInst::Create(Default, NewDefault); 600193323Sed 601276479Sdim BasicBlock *SwitchBlock = 602276479Sdim switchConvert(Cases.begin(), Cases.end(), LowerBound, UpperBound, Val, 603288943Sdim OrigBlock, OrigBlock, NewDefault, UnreachableRanges); 604276479Sdim 605341825Sdim // If there are entries in any PHI nodes for the default edge, make sure 606341825Sdim // to update them as well. 607353358Sdim fixPhis(Default, OrigBlock, NewDefault); 608341825Sdim 609193323Sed // Branch to our shiny new if-then stuff... 610193323Sed BranchInst::Create(SwitchBlock, OrigBlock); 611193323Sed 612193323Sed // We are now done with the switch instruction, delete it. 613288943Sdim BasicBlock *OldDefault = SI->getDefaultDest(); 614353358Sdim OrigBlock->getInstList().erase(SI); 615276479Sdim 616296417Sdim // If the Default block has no more predecessors just add it to DeleteList. 617288943Sdim if (pred_begin(OldDefault) == pred_end(OldDefault)) 618296417Sdim DeleteList.insert(OldDefault); 619193323Sed} 620