1//===- SwitchLoweringUtils.h - Switch Lowering ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
10#define LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
11
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/CodeGen/SelectionDAGNodes.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/Support/BranchProbability.h"
#include <optional>
#include <utility>
#include <vector>
18
19namespace llvm {
20
// Types used in this header only through pointers or references. DataLayout
// and SwitchInst were previously only available through transitive includes.
class BlockFrequencyInfo;
class ConstantInt;
class DataLayout;
class FunctionLoweringInfo;
class MachineBasicBlock;
class ProfileSummaryInfo;
class SwitchInst;
class TargetLowering;
class TargetMachine;
28
29namespace SwitchCG {
30
/// The lowering strategy a CaseCluster stands for. The enumerator values are
/// spelled out but match the implicit numbering (0, 1, 2).
enum CaseClusterKind {
  /// One case, or a run of adjacent case labels sharing a destination.
  CC_Range = 0,
  /// Cases that will be lowered through a jump table.
  CC_JumpTable = 1,
  /// Cases that will be lowered through bit tests.
  CC_BitTests = 2
};
40
41/// A cluster of case labels.
42struct CaseCluster {
43  CaseClusterKind Kind;
44  const ConstantInt *Low, *High;
45  union {
46    MachineBasicBlock *MBB;
47    unsigned JTCasesIndex;
48    unsigned BTCasesIndex;
49  };
50  BranchProbability Prob;
51
52  static CaseCluster range(const ConstantInt *Low, const ConstantInt *High,
53                           MachineBasicBlock *MBB, BranchProbability Prob) {
54    CaseCluster C;
55    C.Kind = CC_Range;
56    C.Low = Low;
57    C.High = High;
58    C.MBB = MBB;
59    C.Prob = Prob;
60    return C;
61  }
62
63  static CaseCluster jumpTable(const ConstantInt *Low, const ConstantInt *High,
64                               unsigned JTCasesIndex, BranchProbability Prob) {
65    CaseCluster C;
66    C.Kind = CC_JumpTable;
67    C.Low = Low;
68    C.High = High;
69    C.JTCasesIndex = JTCasesIndex;
70    C.Prob = Prob;
71    return C;
72  }
73
74  static CaseCluster bitTests(const ConstantInt *Low, const ConstantInt *High,
75                              unsigned BTCasesIndex, BranchProbability Prob) {
76    CaseCluster C;
77    C.Kind = CC_BitTests;
78    C.Low = Low;
79    C.High = High;
80    C.BTCasesIndex = BTCasesIndex;
81    C.Prob = Prob;
82    return C;
83  }
84};
85
86using CaseClusterVector = std::vector<CaseCluster>;
87using CaseClusterIt = CaseClusterVector::iterator;
88
89/// Sort Clusters and merge adjacent cases.
90void sortAndRangeify(CaseClusterVector &Clusters);
91
92struct CaseBits {
93  uint64_t Mask = 0;
94  MachineBasicBlock *BB = nullptr;
95  unsigned Bits = 0;
96  BranchProbability ExtraProb;
97
98  CaseBits() = default;
99  CaseBits(uint64_t mask, MachineBasicBlock *bb, unsigned bits,
100           BranchProbability Prob)
101      : Mask(mask), BB(bb), Bits(bits), ExtraProb(Prob) {}
102};
103
104using CaseBitsVector = std::vector<CaseBits>;
105
106/// This structure is used to communicate between SelectionDAGBuilder and
107/// SDISel for the code generation of additional basic blocks needed by
108/// multi-case switch statements.
109struct CaseBlock {
110  // For the GISel interface.
111  struct PredInfoPair {
112    CmpInst::Predicate Pred;
113    // Set when no comparison should be emitted.
114    bool NoCmp;
115  };
116  union {
117    // The condition code to use for the case block's setcc node.
118    // Besides the integer condition codes, this can also be SETTRUE, in which
119    // case no comparison gets emitted.
120    ISD::CondCode CC;
121    struct PredInfoPair PredInfo;
122  };
123
124  // The LHS/MHS/RHS of the comparison to emit.
125  // Emit by default LHS op RHS. MHS is used for range comparisons:
126  // If MHS is not null: (LHS <= MHS) and (MHS <= RHS).
127  const Value *CmpLHS, *CmpMHS, *CmpRHS;
128
129  // The block to branch to if the setcc is true/false.
130  MachineBasicBlock *TrueBB, *FalseBB;
131
132  // The block into which to emit the code for the setcc and branches.
133  MachineBasicBlock *ThisBB;
134
135  /// The debug location of the instruction this CaseBlock was
136  /// produced from.
137  SDLoc DL;
138  DebugLoc DbgLoc;
139
140  // Branch weights.
141  BranchProbability TrueProb, FalseProb;
142
143  // Constructor for SelectionDAG.
144  CaseBlock(ISD::CondCode cc, const Value *cmplhs, const Value *cmprhs,
145            const Value *cmpmiddle, MachineBasicBlock *truebb,
146            MachineBasicBlock *falsebb, MachineBasicBlock *me, SDLoc dl,
147            BranchProbability trueprob = BranchProbability::getUnknown(),
148            BranchProbability falseprob = BranchProbability::getUnknown())
149      : CC(cc), CmpLHS(cmplhs), CmpMHS(cmpmiddle), CmpRHS(cmprhs),
150        TrueBB(truebb), FalseBB(falsebb), ThisBB(me), DL(dl),
151        TrueProb(trueprob), FalseProb(falseprob) {}
152
153  // Constructor for GISel.
154  CaseBlock(CmpInst::Predicate pred, bool nocmp, const Value *cmplhs,
155            const Value *cmprhs, const Value *cmpmiddle,
156            MachineBasicBlock *truebb, MachineBasicBlock *falsebb,
157            MachineBasicBlock *me, DebugLoc dl,
158            BranchProbability trueprob = BranchProbability::getUnknown(),
159            BranchProbability falseprob = BranchProbability::getUnknown())
160      : PredInfo({pred, nocmp}), CmpLHS(cmplhs), CmpMHS(cmpmiddle),
161        CmpRHS(cmprhs), TrueBB(truebb), FalseBB(falsebb), ThisBB(me),
162        DbgLoc(dl), TrueProb(trueprob), FalseProb(falseprob) {}
163};
164
165struct JumpTable {
166  /// The virtual register containing the index of the jump table entry
167  /// to jump to.
168  unsigned Reg;
169  /// The JumpTableIndex for this jump table in the function.
170  unsigned JTI;
171  /// The MBB into which to emit the code for the indirect jump.
172  MachineBasicBlock *MBB;
173  /// The MBB of the default bb, which is a successor of the range
174  /// check MBB.  This is when updating PHI nodes in successors.
175  MachineBasicBlock *Default;
176
177  /// The debug location of the instruction this JumpTable was produced from.
178  std::optional<SDLoc> SL; // For SelectionDAG
179
180  JumpTable(unsigned R, unsigned J, MachineBasicBlock *M, MachineBasicBlock *D,
181            std::optional<SDLoc> SL)
182      : Reg(R), JTI(J), MBB(M), Default(D), SL(SL) {}
183};
184struct JumpTableHeader {
185  APInt First;
186  APInt Last;
187  const Value *SValue;
188  MachineBasicBlock *HeaderBB;
189  bool Emitted;
190  bool FallthroughUnreachable = false;
191
192  JumpTableHeader(APInt F, APInt L, const Value *SV, MachineBasicBlock *H,
193                  bool E = false)
194      : First(std::move(F)), Last(std::move(L)), SValue(SV), HeaderBB(H),
195        Emitted(E) {}
196};
197using JumpTableBlock = std::pair<JumpTableHeader, JumpTable>;
198
199struct BitTestCase {
200  uint64_t Mask;
201  MachineBasicBlock *ThisBB;
202  MachineBasicBlock *TargetBB;
203  BranchProbability ExtraProb;
204
205  BitTestCase(uint64_t M, MachineBasicBlock *T, MachineBasicBlock *Tr,
206              BranchProbability Prob)
207      : Mask(M), ThisBB(T), TargetBB(Tr), ExtraProb(Prob) {}
208};
209
210using BitTestInfo = SmallVector<BitTestCase, 3>;
211
212struct BitTestBlock {
213  APInt First;
214  APInt Range;
215  const Value *SValue;
216  unsigned Reg;
217  MVT RegVT;
218  bool Emitted;
219  bool ContiguousRange;
220  MachineBasicBlock *Parent;
221  MachineBasicBlock *Default;
222  BitTestInfo Cases;
223  BranchProbability Prob;
224  BranchProbability DefaultProb;
225  bool FallthroughUnreachable = false;
226
227  BitTestBlock(APInt F, APInt R, const Value *SV, unsigned Rg, MVT RgVT, bool E,
228               bool CR, MachineBasicBlock *P, MachineBasicBlock *D,
229               BitTestInfo C, BranchProbability Pr)
230      : First(std::move(F)), Range(std::move(R)), SValue(SV), Reg(Rg),
231        RegVT(RgVT), Emitted(E), ContiguousRange(CR), Parent(P), Default(D),
232        Cases(std::move(C)), Prob(Pr) {}
233};
234
235/// Return the range of values within a range.
236uint64_t getJumpTableRange(const CaseClusterVector &Clusters, unsigned First,
237                           unsigned Last);
238
239/// Return the number of cases within a range.
240uint64_t getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
241                              unsigned First, unsigned Last);
242
243struct SwitchWorkListItem {
244  MachineBasicBlock *MBB = nullptr;
245  CaseClusterIt FirstCluster;
246  CaseClusterIt LastCluster;
247  const ConstantInt *GE = nullptr;
248  const ConstantInt *LT = nullptr;
249  BranchProbability DefaultProb;
250};
251using SwitchWorkList = SmallVector<SwitchWorkListItem, 4>;
252
253class SwitchLowering {
254public:
255  SwitchLowering(FunctionLoweringInfo &funcinfo) : FuncInfo(funcinfo) {}
256
257  void init(const TargetLowering &tli, const TargetMachine &tm,
258            const DataLayout &dl) {
259    TLI = &tli;
260    TM = &tm;
261    DL = &dl;
262  }
263
264  /// Vector of CaseBlock structures used to communicate SwitchInst code
265  /// generation information.
266  std::vector<CaseBlock> SwitchCases;
267
268  /// Vector of JumpTable structures used to communicate SwitchInst code
269  /// generation information.
270  std::vector<JumpTableBlock> JTCases;
271
272  /// Vector of BitTestBlock structures used to communicate SwitchInst code
273  /// generation information.
274  std::vector<BitTestBlock> BitTestCases;
275
276  void findJumpTables(CaseClusterVector &Clusters, const SwitchInst *SI,
277                      std::optional<SDLoc> SL, MachineBasicBlock *DefaultMBB,
278                      ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI);
279
280  bool buildJumpTable(const CaseClusterVector &Clusters, unsigned First,
281                      unsigned Last, const SwitchInst *SI,
282                      const std::optional<SDLoc> &SL,
283                      MachineBasicBlock *DefaultMBB, CaseCluster &JTCluster);
284
285  void findBitTestClusters(CaseClusterVector &Clusters, const SwitchInst *SI);
286
287  /// Build a bit test cluster from Clusters[First..Last]. Returns false if it
288  /// decides it's not a good idea.
289  bool buildBitTests(CaseClusterVector &Clusters, unsigned First, unsigned Last,
290                     const SwitchInst *SI, CaseCluster &BTCluster);
291
292  virtual void addSuccessorWithProb(
293      MachineBasicBlock *Src, MachineBasicBlock *Dst,
294      BranchProbability Prob = BranchProbability::getUnknown()) = 0;
295
296  /// Determine the rank by weight of CC in [First,Last]. If CC has more weight
297  /// than each cluster in the range, its rank is 0.
298  unsigned caseClusterRank(const CaseCluster &CC, CaseClusterIt First,
299                           CaseClusterIt Last);
300
301  struct SplitWorkItemInfo {
302    CaseClusterIt LastLeft;
303    CaseClusterIt FirstRight;
304    BranchProbability LeftProb;
305    BranchProbability RightProb;
306  };
307  /// Compute information to balance the tree based on branch probabilities to
308  /// create a near-optimal (in terms of search time given key frequency) binary
309  /// search tree. See e.g. Kurt Mehlhorn "Nearly Optimal Binary Search Trees"
310  /// (1975).
311  SplitWorkItemInfo computeSplitWorkItemInfo(const SwitchWorkListItem &W);
312  virtual ~SwitchLowering() = default;
313
314private:
315  const TargetLowering *TLI = nullptr;
316  const TargetMachine *TM = nullptr;
317  const DataLayout *DL = nullptr;
318  FunctionLoweringInfo &FuncInfo;
319};
320
321} // namespace SwitchCG
322} // namespace llvm
323
324#endif // LLVM_CODEGEN_SWITCHLOWERINGUTILS_H
325