SwitchLoweringUtils.h revision 360784
1//===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8 9#ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H 10#define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H 11 12#include "llvm/ADT/SmallVector.h" 13#include "llvm/CodeGen/SelectionDAGNodes.h" 14#include "llvm/CodeGen/TargetLowering.h" 15#include "llvm/IR/Constants.h" 16#include "llvm/Support/BranchProbability.h" 17 18namespace llvm { 19 20class FunctionLoweringInfo; 21class MachineBasicBlock; 22class BlockFrequencyInfo; 23 24namespace SwitchCG { 25 26enum CaseClusterKind { 27 /// A cluster of adjacent case labels with the same destination, or just one 28 /// case. 29 CC_Range, 30 /// A cluster of cases suitable for jump table lowering. 31 CC_JumpTable, 32 /// A cluster of cases suitable for bit test lowering. 33 CC_BitTests 34}; 35 36/// A cluster of case labels. 37struct CaseCluster { 38 CaseClusterKind Kind; 39 const ConstantInt *Low, *High; 40 union { 41 MachineBasicBlock *MBB; 42 unsigned JTCasesIndex; 43 unsigned BTCasesIndex; 44 }; 45 BranchProbability Prob; 46 47 static CaseCluster range(const ConstantInt *Low, const ConstantInt *High, 48 MachineBasicBlock *MBB, BranchProbability Prob) { 49 CaseCluster C; 50 C.Kind = CC_Range; 51 C.Low = Low; 52 C.High = High; 53 C.MBB = MBB; 54 C.Prob = Prob; 55 return C; 56 } 57 58 static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High, 59 unsigned JTCasesIndex, BranchProbability Prob) { 60 CaseCluster C; 61 C.Kind = CC_JumpTable; 62 C.Low = Low; 63 C.High = High; 64 C.JTCasesIndex = JTCasesIndex; 65 C.Prob = Prob; 66 return C; 67 } 68 69 static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High, 70 unsigned BTCasesIndex, BranchProbability Prob) { 71 CaseCluster C; 72 C.Kind = CC_BitTests; 73 C.Low = Low; 74 C.High = High; 75 C.BTCasesIndex = BTCasesIndex; 76 C.Prob = Prob; 77 return C; 78 } 79}; 80 81using CaseClusterVector = std::vector<CaseCluster>; 82using CaseClusterIt = CaseClusterVector::iterator; 83 84/// Sort Clusters and merge adjacent cases. 85void sortAndRangeify(CaseClusterVector &Clusters); 86 87struct CaseBits { 88 uint64_t Mask = 0; 89 MachineBasicBlock *BB = nullptr; 90 unsigned Bits = 0; 91 BranchProbability ExtraProb; 92 93 CaseBits() = default; 94 CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits, 95 BranchProbability Prob) 96 : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {} 97}; 98 99using CaseBitsVector = std::vector<CaseBits>; 100 101/// This structure is used to communicate between SelectionDAGBuilder and 102/// SDISel for the code generation of additional basic blocks needed by 103/// multi-case switch statements. 104struct CaseBlock { 105 // For the GISel interface. 106 struct PredInfoPair { 107 CmpInst::Predicate Pred; 108 // Set when no comparison should be emitted. 109 bool NoCmp; 110 }; 111 union { 112 // The condition code to use for the case block's setcc node. 113 // Besides the integer condition codes, this can also be SETTRUE, in which 114 // case no comparison gets emitted. 115 ISD::CondCode CC; 116 struct PredInfoPair PredInfo; 117 }; 118 119 // The LHS/MHS/RHS of the comparison to emit. 120 // Emit by default LHS op RHS. MHS is used for range comparisons: 121 // If MHS is not null: (LHS <= MHS) and (MHS <= RHS). 122 const Value *CmpLHS, *CmpMHS, *CmpRHS; 123 124 // The block to branch to if the setcc is true/false. 125 MachineBasicBlock *TrueBB, *FalseBB; 126 127 // The block into which to emit the code for the setcc and branches. 128 MachineBasicBlock *ThisBB; 129 130 /// The debug location of the instruction this CaseBlock was 131 /// produced from. 132 SDLoc DL; 133 DebugLoc DbgLoc; 134 135 // Branch weights. 136 BranchProbability TrueProb, FalseProb; 137 138 // Constructor for SelectionDAG. 139 CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs, 140 const Value *cmpmiddle, MachineBasicBlock *truebb, 141 MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl, 142 BranchProbability trueprob = BranchProbability::getUnknown(), 143 BranchProbability falseprob = BranchProbability::getUnknown()) 144 : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs), 145 TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl), 146 TrueProb(trueprob), FalseProb(falseprob) {} 147 148 // Constructor for GISel. 149 CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs, 150 const Value *cmprhs, const Value *cmpmiddle, 151 MachineBasicBlock *truebb, MachineBasicBlock *falsebb, 152 MachineBasicBlock *me, DebugLoc dl, 153 BranchProbability trueprob = BranchProbability::getUnknown(), 154 BranchProbability falseprob = BranchProbability::getUnknown()) 155 : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle), 156 CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me), 157 DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {} 158}; 159 160struct JumpTable { 161 /// The virtual register containing the index of the jump table entry 162 /// to jump to. 163 unsigned Reg; 164 /// The JumpTableIndex for this jump table in the function. 165 unsigned JTI; 166 /// The MBB into which to emit the code for the indirect jump. 167 MachineBasicBlock *MBB; 168 /// The MBB of the default bb, which is a successor of the range 169 /// check MBB. This is when updating PHI nodes in successors. 170 MachineBasicBlock *Default; 171 172 JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D) 173 : Reg(R), JTI(J), MBB(M), Default(D) {} 174}; 175struct JumpTableHeader { 176 APInt First; 177 APInt Last; 178 const Value *SValue; 179 MachineBasicBlock *HeaderBB; 180 bool Emitted; 181 bool OmitRangeCheck; 182 183 JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H, 184 bool E = false) 185 : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H), 186 Emitted(E), OmitRangeCheck(false) {} 187}; 188using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>; 189 190struct BitTestCase { 191 uint64_t Mask; 192 MachineBasicBlock *ThisBB; 193 MachineBasicBlock *TargetBB; 194 BranchProbability ExtraProb; 195 196 BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr, 197 BranchProbability Prob) 198 : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {} 199}; 200 201using BitTestInfo = SmallVector<BitTestCase, 3>; 202 203struct BitTestBlock { 204 APInt First; 205 APInt Range; 206 const Value *SValue; 207 unsigned Reg; 208 MVT RegVT; 209 bool Emitted; 210 bool ContiguousRange; 211 MachineBasicBlock *Parent; 212 MachineBasicBlock *Default; 213 BitTestInfo Cases; 214 BranchProbability Prob; 215 BranchProbability DefaultProb; 216 bool OmitRangeCheck; 217 218 BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E, 219 bool CR, MachineBasicBlock *P, MachineBasicBlock *D, 220 BitTestInfo C, BranchProbability Pr) 221 : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg), 222 RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D), 223 Cases(std::move(C)), Prob(Pr), OmitRangeCheck(false) {} 224}; 225 226/// Return the range of values within a range. 227uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First, 228 unsigned Last); 229 230/// Return the number of cases within a range. 231uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases, 232 unsigned First, unsigned Last); 233 234struct SwitchWorkListItem { 235 MachineBasicBlock *MBB; 236 CaseClusterIt FirstCluster; 237 CaseClusterIt LastCluster; 238 const ConstantInt *GE; 239 const ConstantInt *LT; 240 BranchProbability DefaultProb; 241}; 242using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>; 243 244class SwitchLowering { 245public: 246 SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {} 247 248 void init(const TargetLowering &tli, const TargetMachine &tm, 249 const DataLayout &dl) { 250 TLI = &tli; 251 TM = &tm; 252 DL = &dl; 253 } 254 255 /// Vector of CaseBlock structures used to communicate SwitchInst code 256 /// generation information. 257 std::vector<CaseBlock> SwitchCases; 258 259 /// Vector of JumpTable structures used to communicate SwitchInst code 260 /// generation information. 261 std::vector<JumpTableBlock> JTCases; 262 263 /// Vector of BitTestBlock structures used to communicate SwitchInst code 264 /// generation information. 265 std::vector<BitTestBlock> BitTestCases; 266 267 void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI, 268 MachineBasicBlock *DefaultMBB, 269 ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI); 270 271 bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First, 272 unsigned Last, const SwitchInst *SI, 273 MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster); 274 275 276 void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI); 277 278 /// Build a bit test cluster from Clusters[First..Last]. Returns false if it 279 /// decides it's not a good idea. 280 bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last, 281 const SwitchInst *SI, CaseCluster &BTCluster); 282 283 virtual void addSuccessorWithProb( 284 MachineBasicBlock *Src, MachineBasicBlock *Dst, 285 BranchProbability Prob = BranchProbability::getUnknown()) = 0; 286 287 virtual ~SwitchLowering() = default; 288 289private: 290 const TargetLowering *TLI; 291 const TargetMachine *TM; 292 const DataLayout *DL; 293 FunctionLoweringInfo &FuncInfo; 294}; 295 296} // namespace SwitchCG 297} // namespace llvm 298 299#endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H 300