SwitchLoweringUtils.h revision 360784
1//===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
10#define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
11
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/BranchProbability.h"
#include <cstdint>
#include <utility>
#include <vector>
17
18namespace llvm {
19
// Forward declarations for types this header only names through pointers or
// references. ProfileSummaryInfo is used by SwitchLowering::findJumpTables, so
// it is declared here instead of depending on a transitive include.
class BlockFrequencyInfo;
class FunctionLoweringInfo;
class MachineBasicBlock;
class ProfileSummaryInfo;
23
24namespace SwitchCG {
25
/// The kinds of case cluster that switch lowering distinguishes.
enum CaseClusterKind {
  /// Either a single case, or a run of adjacent case labels that all share
  /// the same destination.
  CC_Range,

  /// A group of cases that is suitable for jump table lowering.
  CC_JumpTable,

  /// A group of cases that is suitable for bit test lowering.
  CC_BitTests
};
35
36/// A cluster of case labels.
37struct CaseCluster {
38  CaseClusterKind Kind;
39  const ConstantInt *Low, *High;
40  union {
41    MachineBasicBlock *MBB;
42    unsigned JTCasesIndex;
43    unsigned BTCasesIndex;
44  };
45  BranchProbability Prob;
46
47  static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
48                           MachineBasicBlock *MBB, BranchProbability Prob) {
49    CaseCluster C;
50    C.Kind = CC_Range;
51    C.Low = Low;
52    C.High = High;
53    C.MBB = MBB;
54    C.Prob = Prob;
55    return C;
56  }
57
58  static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
59                               unsigned JTCasesIndex, BranchProbability Prob) {
60    CaseCluster C;
61    C.Kind = CC_JumpTable;
62    C.Low = Low;
63    C.High = High;
64    C.JTCasesIndex = JTCasesIndex;
65    C.Prob = Prob;
66    return C;
67  }
68
69  static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
70                              unsigned BTCasesIndex, BranchProbability Prob) {
71    CaseCluster C;
72    C.Kind = CC_BitTests;
73    C.Low = Low;
74    C.High = High;
75    C.BTCasesIndex = BTCasesIndex;
76    C.Prob = Prob;
77    return C;
78  }
79};
80
81using CaseClusterVector = std::vector<CaseCluster>;
82using CaseClusterIt = CaseClusterVector::iterator;
83
84/// Sort Clusters and merge adjacent cases.
85void sortAndRangeify(CaseClusterVector &Clusters);
86
87struct CaseBits {
88  uint64_t Mask = 0;
89  MachineBasicBlock *BB = nullptr;
90  unsigned Bits = 0;
91  BranchProbability ExtraProb;
92
93  CaseBits() = default;
94  CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
95           BranchProbability Prob)
96      : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
97};
98
99using CaseBitsVector = std::vector<CaseBits>;
100
101/// This structure is used to communicate between SelectionDAGBuilder and
102/// SDISel for the code generation of additional basic blocks needed by
103/// multi-case switch statements.
104struct CaseBlock {
105  // For the GISel interface.
106  struct PredInfoPair {
107    CmpInst::Predicate Pred;
108    // Set when no comparison should be emitted.
109    bool NoCmp;
110  };
111  union {
112    // The condition code to use for the case block's setcc node.
113    // Besides the integer condition codes, this can also be SETTRUE, in which
114    // case no comparison gets emitted.
115    ISD::CondCode CC;
116    struct PredInfoPair PredInfo;
117  };
118
119  // The LHS/MHS/RHS of the comparison to emit.
120  // Emit by default LHS op RHS. MHS is used for range comparisons:
121  // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
122  const Value *CmpLHS, *CmpMHS, *CmpRHS;
123
124  // The block to branch to if the setcc is true/false.
125  MachineBasicBlock *TrueBB, *FalseBB;
126
127  // The block into which to emit the code for the setcc and branches.
128  MachineBasicBlock *ThisBB;
129
130  /// The debug location of the instruction this CaseBlock was
131  /// produced from.
132  SDLoc DL;
133  DebugLoc DbgLoc;
134
135  // Branch weights.
136  BranchProbability TrueProb, FalseProb;
137
138  // Constructor for SelectionDAG.
139  CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
140            const Value *cmpmiddle, MachineBasicBlock *truebb,
141            MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
142            BranchProbability trueprob = BranchProbability::getUnknown(),
143            BranchProbability falseprob = BranchProbability::getUnknown())
144      : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
145        TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
146        TrueProb(trueprob), FalseProb(falseprob) {}
147
148  // Constructor for GISel.
149  CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
150            const Value *cmprhs, const Value *cmpmiddle,
151            MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
152            MachineBasicBlock *me, DebugLoc dl,
153            BranchProbability trueprob = BranchProbability::getUnknown(),
154            BranchProbability falseprob = BranchProbability::getUnknown())
155      : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
156        CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
157        DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
158};
159
160struct JumpTable {
161  /// The virtual register containing the index of the jump table entry
162  /// to jump to.
163  unsigned Reg;
164  /// The JumpTableIndex for this jump table in the function.
165  unsigned JTI;
166  /// The MBB into which to emit the code for the indirect jump.
167  MachineBasicBlock *MBB;
168  /// The MBB of the default bb, which is a successor of the range
169  /// check MBB.  This is when updating PHI nodes in successors.
170  MachineBasicBlock *Default;
171
172  JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D)
173      : Reg(R), JTI(J), MBB(M), Default(D) {}
174};
175struct JumpTableHeader {
176  APInt First;
177  APInt Last;
178  const Value *SValue;
179  MachineBasicBlock *HeaderBB;
180  bool Emitted;
181  bool OmitRangeCheck;
182
183  JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
184                  bool E = false)
185      : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
186        Emitted(E), OmitRangeCheck(false) {}
187};
188using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
189
190struct BitTestCase {
191  uint64_t Mask;
192  MachineBasicBlock *ThisBB;
193  MachineBasicBlock *TargetBB;
194  BranchProbability ExtraProb;
195
196  BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
197              BranchProbability Prob)
198      : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
199};
200
201using BitTestInfo = SmallVector<BitTestCase, 3>;
202
203struct BitTestBlock {
204  APInt First;
205  APInt Range;
206  const Value *SValue;
207  unsigned Reg;
208  MVT RegVT;
209  bool Emitted;
210  bool ContiguousRange;
211  MachineBasicBlock *Parent;
212  MachineBasicBlock *Default;
213  BitTestInfo Cases;
214  BranchProbability Prob;
215  BranchProbability DefaultProb;
216  bool OmitRangeCheck;
217
218  BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
219               bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
220               BitTestInfo C, BranchProbability Pr)
221      : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
222        RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
223        Cases(std::move(C)), Prob(Pr), OmitRangeCheck(false) {}
224};
225
226/// Return the range of values within a range.
227uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
228                           unsigned Last);
229
230/// Return the number of cases within a range.
231uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
232                              unsigned First, unsigned Last);
233
234struct SwitchWorkListItem {
235  MachineBasicBlock *MBB;
236  CaseClusterIt FirstCluster;
237  CaseClusterIt LastCluster;
238  const ConstantInt *GE;
239  const ConstantInt *LT;
240  BranchProbability DefaultProb;
241};
242using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
243
244class SwitchLowering {
245public:
246  SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
247
248  void init(const TargetLowering &tli, const TargetMachine &tm,
249            const DataLayout &dl) {
250    TLI = &tli;
251    TM = &tm;
252    DL = &dl;
253  }
254
255  /// Vector of CaseBlock structures used to communicate SwitchInst code
256  /// generation information.
257  std::vector<CaseBlock> SwitchCases;
258
259  /// Vector of JumpTable structures used to communicate SwitchInst code
260  /// generation information.
261  std::vector<JumpTableBlock> JTCases;
262
263  /// Vector of BitTestBlock structures used to communicate SwitchInst code
264  /// generation information.
265  std::vector<BitTestBlock> BitTestCases;
266
267  void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
268                      MachineBasicBlock *DefaultMBB,
269                      ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
270
271  bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
272                      unsigned Last, const SwitchInst *SI,
273                      MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
274
275
276  void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
277
278  /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
279  /// decides it's not a good idea.
280  bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
281                     const SwitchInst *SI, CaseCluster &BTCluster);
282
283  virtual void addSuccessorWithProb(
284      MachineBasicBlock *Src, MachineBasicBlock *Dst,
285      BranchProbability Prob = BranchProbability::getUnknown()) = 0;
286
287  virtual ~SwitchLowering() = default;
288
289private:
290  const TargetLowering *TLI;
291  const TargetMachine *TM;
292  const DataLayout *DL;
293  FunctionLoweringInfo &FuncInfo;
294};
295
296} // namespace SwitchCG
297} // namespace llvm
298
299#endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
300