//===- SwitchLoweringUtils.cpp - Switch Lowering --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains switch inst lowering optimizations and utilities for
// codegen, so that it can be used for both SelectionDAG and GlobalISel.
//
//===----------------------------------------------------------------------===//

#include "llvm/CodeGen/SwitchLoweringUtils.h"
#include "llvm/CodeGen/FunctionLoweringInfo.h"
#include "llvm/CodeGen/MachineJumpTableInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/Target/TargetMachine.h"

using namespace llvm;
using namespace SwitchCG;

uint64_t SwitchCG::getJumpTableRange(const CaseClusterVector &Clusters,
                                     unsigned First, unsigned Last) {
  assert(Last >= First);
  const APInt &LowCase = Clusters[First].Low->getValue();
  const APInt &HighCase = Clusters[Last].High->getValue();
  assert(LowCase.getBitWidth() == HighCase.getBitWidth());

  // FIXME: A range of consecutive cases has 100% density, but only requires one
  // comparison to lower. We should discriminate against such consecutive ranges
  // in jump tables.
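  // For illustration (hypothetical values): clusters covering [0, 3] and
  // [8, 10] give a range of 10 - 0 + 1 = 11 table entries, even though only 7
  // case values are actually handled. The clamp below guards the callers'
  // percent-density arithmetic against unsigned overflow.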
  return (HighCase - LowCase).getLimitedValue((UINT64_MAX - 1) / 100) + 1;
}

uint64_t
SwitchCG::getJumpTableNumCases(const SmallVectorImpl<unsigned> &TotalCases,
                               unsigned First, unsigned Last) {
  assert(Last >= First);
  assert(TotalCases[Last] >= TotalCases[First]);
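  // TotalCases is an inclusive prefix sum over the per-cluster case counts, so
  // the number of cases in Clusters[First..Last] is TotalCases[Last] minus
  // everything accumulated before First (zero when First == 0).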
  uint64_t NumCases =
      TotalCases[Last] - (First == 0 ? 0 : TotalCases[First - 1]);
  return NumCases;
}

void SwitchCG::SwitchLowering::findJumpTables(CaseClusterVector &Clusters,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              ProfileSummaryInfo *PSI,
                                              BlockFrequencyInfo *BFI) {
#ifndef NDEBUG
  // Clusters must be non-empty, sorted, and only contain Range clusters.
  assert(!Clusters.empty());
  for (CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range);
  for (unsigned i = 1, e = Clusters.size(); i < e; ++i)
    assert(Clusters[i - 1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  assert(TLI && "TLI not set!");
  if (!TLI->areJTsAllowed(SI->getParent()->getParent()))
    return;

  const unsigned MinJumpTableEntries = TLI->getMinimumJumpTableEntries();
  const unsigned SmallNumberOfEntries = MinJumpTableEntries / 2;

  // Bail if not enough cases.
  const int64_t N = Clusters.size();
  if (N < 2 || N < MinJumpTableEntries)
    return;

  // Accumulated number of cases in each cluster and those prior to it.
  SmallVector<unsigned, 8> TotalCases(N);
  for (unsigned i = 0; i < N; ++i) {
    const APInt &Hi = Clusters[i].High->getValue();
    const APInt &Lo = Clusters[i].Low->getValue();
    TotalCases[i] = (Hi - Lo).getLimitedValue() + 1;
    if (i != 0)
      TotalCases[i] += TotalCases[i - 1];
  }
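  // For example (hypothetical input): clusters [0,1], [3,3], [6,8] have
  // per-cluster counts {2, 1, 3}, giving running sums {2, 3, 6}, which
  // getJumpTableNumCases() reads below.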

  uint64_t Range = getJumpTableRange(Clusters, 0, N - 1);
  uint64_t NumCases = getJumpTableNumCases(TotalCases, 0, N - 1);
  assert(NumCases < UINT64_MAX / 100);
  assert(Range >= NumCases);

  // Cheap case: the whole range may be suitable for jump table.
  if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
    CaseCluster JTCluster;
    if (buildJumpTable(Clusters, 0, N - 1, SI, DefaultMBB, JTCluster)) {
      Clusters[0] = JTCluster;
      Clusters.resize(1);
      return;
    }
  }

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // Split Clusters into minimum number of dense partitions. The algorithm uses
  // the same idea as Kannan & Proebsting "Correction to 'Producing Good Code
  // for the Case Statement'" (1994), but builds the MinPartitions array in
  // reverse order to make it easier to reconstruct the partitions in ascending
  // order. In the choice between two optimal partitionings, it picks the one
  // which yields more jump tables.
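  //
  // Concretely, the recurrence is
  //   MinPartitions[i] = min over j in [i, N-1] of (1 + MinPartitions[j + 1]),
  // taken over every j for which Clusters[i..j] may form a single partition:
  // either the lone cluster i, or a range that isSuitableForJumpTable()
  // accepts (with MinPartitions[N] read as 0).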

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);
  // PartitionsScore[i] is used to break ties when choosing between two
  // partitionings resulting in the same number of partitions.
  SmallVector<unsigned, 8> PartitionsScore(N);
  // For PartitionsScore, a small number of comparisons is considered as good as
  // a jump table and a single comparison is considered better than a jump
  // table.
  enum PartitionScores : unsigned {
    NoTable = 0,
    Table = 1,
    FewCases = 1,
    SingleCase = 2
  };
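  // For example (hypothetical): a partitioning into one jump table plus one
  // lone case scores Table + SingleCase = 3, beating an equal-sized
  // partitioning of two jump tables, which scores Table + Table = 2.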

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;
  PartitionsScore[N - 1] = PartitionScores::SingleCase;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; i--) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;
    PartitionsScore[i] = PartitionsScore[i + 1] + PartitionScores::SingleCase;

    // Search for a solution that results in fewer partitions.
    for (int64_t j = N - 1; j > i; j--) {
      // Try building a partition from Clusters[i..j].
      Range = getJumpTableRange(Clusters, i, j);
      NumCases = getJumpTableNumCases(TotalCases, i, j);
      assert(NumCases < UINT64_MAX / 100);
      assert(Range >= NumCases);

      if (TLI->isSuitableForJumpTable(SI, NumCases, Range, PSI, BFI)) {
        unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
        unsigned Score = j == N - 1 ? 0 : PartitionsScore[j + 1];
        int64_t NumEntries = j - i + 1;

        if (NumEntries == 1)
          Score += PartitionScores::SingleCase;
        else if (NumEntries <= SmallNumberOfEntries)
          Score += PartitionScores::FewCases;
        else if (NumEntries >= MinJumpTableEntries)
          Score += PartitionScores::Table;

        // If this leads to fewer partitions, or to the same number of
        // partitions with better score, it is a better partitioning.
        if (NumPartitions < MinPartitions[i] ||
            (NumPartitions == MinPartitions[i] && Score > PartitionsScore[i])) {
          MinPartitions[i] = NumPartitions;
          LastElement[i] = j;
          PartitionsScore[i] = Score;
        }
      }
    }
  }
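  // Illustrative trace (hypothetical): if Clusters[2..N-1] is the only range
  // accepted by isSuitableForJumpTable(), the pass at i = 2 records
  // MinPartitions[2] = 1 and LastElement[2] = N - 1, and i = 1 and i = 0 each
  // prepend a singleton partition, so the rebuild below emits two lone
  // clusters followed by one jump table.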

  // Iterate over the partitions, replacing some with jump tables in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(Last >= First);
    assert(DstIndex <= First);
    unsigned NumClusters = Last - First + 1;

    CaseCluster JTCluster;
    if (NumClusters >= MinJumpTableEntries &&
        buildJumpTable(Clusters, First, Last, SI, DefaultMBB, JTCluster)) {
      Clusters[DstIndex++] = JTCluster;
    } else {
      for (unsigned I = First; I <= Last; ++I)
        std::memmove(&Clusters[DstIndex++], &Clusters[I], sizeof(Clusters[I]));
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildJumpTable(const CaseClusterVector &Clusters,
                                              unsigned First, unsigned Last,
                                              const SwitchInst *SI,
                                              MachineBasicBlock *DefaultMBB,
                                              CaseCluster &JTCluster) {
  assert(First <= Last);

  auto Prob = BranchProbability::getZero();
  unsigned NumCmps = 0;
  std::vector<MachineBasicBlock*> Table;
  DenseMap<MachineBasicBlock*, BranchProbability> JTProbs;

  // Initialize probabilities in JTProbs.
  for (unsigned I = First; I <= Last; ++I)
    JTProbs[Clusters[I].MBB] = BranchProbability::getZero();

  for (unsigned I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Prob += Clusters[I].Prob;
    const APInt &Low = Clusters[I].Low->getValue();
    const APInt &High = Clusters[I].High->getValue();
    NumCmps += (Low == High) ? 1 : 2;
    if (I != First) {
      // Fill the gap between this and the previous cluster.
      const APInt &PreviousHigh = Clusters[I - 1].High->getValue();
      assert(PreviousHigh.slt(Low));
      uint64_t Gap = (Low - PreviousHigh).getLimitedValue() - 1;
      for (uint64_t J = 0; J < Gap; J++)
        Table.push_back(DefaultMBB);
    }
    uint64_t ClusterSize = (High - Low).getLimitedValue() + 1;
    for (uint64_t J = 0; J < ClusterSize; ++J)
      Table.push_back(Clusters[I].MBB);
    JTProbs[Clusters[I].MBB] += Clusters[I].Prob;
  }
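  // For example (hypothetical input): clusters [1,2] -> A and [5,5] -> B yield
  // Table = {A, A, DefaultMBB, DefaultMBB, B}, i.e. one entry per value in
  // [1, 5] with the gap 3..4 routed to the default block.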

  unsigned NumDests = JTProbs.size();
  if (TLI->isSuitableForBitTests(NumDests, NumCmps,
                                 Clusters[First].Low->getValue(),
                                 Clusters[Last].High->getValue(), *DL)) {
    // Clusters[First..Last] should be lowered as bit tests instead.
    return false;
  }

  // Create the MBB that will load from and jump through the table.
  // Note: We create it here, but it's not inserted into the function yet.
  MachineFunction *CurMF = FuncInfo.MF;
  MachineBasicBlock *JumpTableMBB =
      CurMF->CreateMachineBasicBlock(SI->getParent());

  // Add successors. Note: use table order for determinism.
  SmallPtrSet<MachineBasicBlock *, 8> Done;
  for (MachineBasicBlock *Succ : Table) {
    if (Done.count(Succ))
      continue;
    addSuccessorWithProb(JumpTableMBB, Succ, JTProbs[Succ]);
    Done.insert(Succ);
  }
  JumpTableMBB->normalizeSuccProbs();

  unsigned JTI = CurMF->getOrCreateJumpTableInfo(TLI->getJumpTableEncoding())
                     ->createJumpTableIndex(Table);

  // Set up the jump table info.
  JumpTable JT(-1U, JTI, JumpTableMBB, nullptr);
  JumpTableHeader JTH(Clusters[First].Low->getValue(),
                      Clusters[Last].High->getValue(), SI->getCondition(),
                      nullptr, false);
  JTCases.emplace_back(std::move(JTH), std::move(JT));

  JTCluster = CaseCluster::jumpTable(Clusters[First].Low, Clusters[Last].High,
                                     JTCases.size() - 1, Prob);
  return true;
}

void SwitchCG::SwitchLowering::findBitTestClusters(CaseClusterVector &Clusters,
                                                   const SwitchInst *SI) {
  // Partition Clusters into as few subsets as possible, where each subset has a
  // range that fits in a machine word and has <= 3 unique destinations.

#ifndef NDEBUG
  // Clusters must be sorted and contain Range or JumpTable clusters.
  assert(!Clusters.empty());
  assert(Clusters[0].Kind == CC_Range || Clusters[0].Kind == CC_JumpTable);
  for (const CaseCluster &C : Clusters)
    assert(C.Kind == CC_Range || C.Kind == CC_JumpTable);
  for (unsigned i = 1; i < Clusters.size(); ++i)
    assert(Clusters[i-1].High->getValue().slt(Clusters[i].Low->getValue()));
#endif

  // The algorithm below is not suitable for -O0.
  if (TM->getOptLevel() == CodeGenOpt::None)
    return;

  // If the target does not have a legal shift left, do not emit bit tests at
  // all.
  EVT PTy = TLI->getPointerTy(*DL);
  if (!TLI->isOperationLegal(ISD::SHL, PTy))
    return;

  int BitWidth = PTy.getSizeInBits();
  const int64_t N = Clusters.size();

  // MinPartitions[i] is the minimum number of partitions of Clusters[i..N-1].
  SmallVector<unsigned, 8> MinPartitions(N);
  // LastElement[i] is the last element of the partition starting at i.
  SmallVector<unsigned, 8> LastElement(N);

  // FIXME: This might not be the best algorithm for finding bit test clusters.

  // Base case: There is only one way to partition Clusters[N-1].
  MinPartitions[N - 1] = 1;
  LastElement[N - 1] = N - 1;

  // Note: loop indexes are signed to avoid underflow.
  for (int64_t i = N - 2; i >= 0; --i) {
    // Find optimal partitioning of Clusters[i..N-1].
    // Baseline: Put Clusters[i] into a partition on its own.
    MinPartitions[i] = MinPartitions[i + 1] + 1;
    LastElement[i] = i;

    // Search for a solution that results in fewer partitions.
    // Note: the search is limited by BitWidth, reducing time complexity.
    for (int64_t j = std::min(N - 1, i + BitWidth - 1); j > i; --j) {
      // Try building a partition from Clusters[i..j].

      // Check the range.
      if (!TLI->rangeFitsInWord(Clusters[i].Low->getValue(),
                                Clusters[j].High->getValue(), *DL))
        continue;

      // Check the number of destinations and cluster types.
      // FIXME: This works, but doesn't seem very efficient.
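      // Each unique destination costs one mask test and branch, so the cap of
      // three destinations below (a heuristic) keeps the emitted compare chain
      // short.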
      bool RangesOnly = true;
      BitVector Dests(FuncInfo.MF->getNumBlockIDs());
      for (int64_t k = i; k <= j; k++) {
        if (Clusters[k].Kind != CC_Range) {
          RangesOnly = false;
          break;
        }
        Dests.set(Clusters[k].MBB->getNumber());
      }
      if (!RangesOnly || Dests.count() > 3)
        break;

      // Check if it's a better partition.
      unsigned NumPartitions = 1 + (j == N - 1 ? 0 : MinPartitions[j + 1]);
      if (NumPartitions < MinPartitions[i]) {
        // Found a better partition.
        MinPartitions[i] = NumPartitions;
        LastElement[i] = j;
      }
    }
  }

  // Iterate over the partitions, replacing with bit-test clusters in-place.
  unsigned DstIndex = 0;
  for (unsigned First = 0, Last; First < N; First = Last + 1) {
    Last = LastElement[First];
    assert(First <= Last);
    assert(DstIndex <= First);

    CaseCluster BitTestCluster;
    if (buildBitTests(Clusters, First, Last, SI, BitTestCluster)) {
      Clusters[DstIndex++] = BitTestCluster;
    } else {
      size_t NumClusters = Last - First + 1;
      std::memmove(&Clusters[DstIndex], &Clusters[First],
                   sizeof(Clusters[0]) * NumClusters);
      DstIndex += NumClusters;
    }
  }
  Clusters.resize(DstIndex);
}

bool SwitchCG::SwitchLowering::buildBitTests(CaseClusterVector &Clusters,
                                             unsigned First, unsigned Last,
                                             const SwitchInst *SI,
                                             CaseCluster &BTCluster) {
  assert(First <= Last);
  if (First == Last)
    return false;

  BitVector Dests(FuncInfo.MF->getNumBlockIDs());
  unsigned NumCmps = 0;
  for (int64_t I = First; I <= Last; ++I) {
    assert(Clusters[I].Kind == CC_Range);
    Dests.set(Clusters[I].MBB->getNumber());
    NumCmps += (Clusters[I].Low == Clusters[I].High) ? 1 : 2;
  }
  unsigned NumDests = Dests.count();

  APInt Low = Clusters[First].Low->getValue();
  APInt High = Clusters[Last].High->getValue();
  assert(Low.slt(High));

  if (!TLI->isSuitableForBitTests(NumDests, NumCmps, Low, High, *DL))
    return false;

  APInt LowBound;
  APInt CmpRange;

  const int BitWidth = TLI->getPointerTy(*DL).getSizeInBits();
  assert(TLI->rangeFitsInWord(Low, High, *DL) &&
         "Case range must fit in bit mask!");

  // Check if the clusters cover a contiguous range such that no value in the
  // range will jump to the default statement.
  bool ContiguousRange = true;
  for (int64_t I = First + 1; I <= Last; ++I) {
    if (Clusters[I].Low->getValue() != Clusters[I - 1].High->getValue() + 1) {
      ContiguousRange = false;
      break;
    }
  }
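  // For example (hypothetical): clusters [1,2] and [3,4] form a contiguous
  // range, while [1,2] and [4,5] leave the value 3 falling through to the
  // default, so ContiguousRange becomes false.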

  if (Low.isStrictlyPositive() && High.slt(BitWidth)) {
    // Optimize the case where all the case values fit in a word without having
    // to subtract minValue. In this case, we can optimize away the subtraction.
    LowBound = APInt::getNullValue(Low.getBitWidth());
    CmpRange = High;
    ContiguousRange = false;
  } else {
    LowBound = Low;
    CmpRange = High - Low;
  }
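  // For example (hypothetical): with Low = 1, High = 10 and a 64-bit pointer
  // width, the branch above picks LowBound = 0 and CmpRange = 10, so the bit
  // index is the case value itself and no subtraction is emitted.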

  CaseBitsVector CBV;
  auto TotalProb = BranchProbability::getZero();
  for (unsigned i = First; i <= Last; ++i) {
    // Find the CaseBits for this destination.
    unsigned j;
    for (j = 0; j < CBV.size(); ++j)
      if (CBV[j].BB == Clusters[i].MBB)
        break;
    if (j == CBV.size())
      CBV.push_back(
          CaseBits(0, Clusters[i].MBB, 0, BranchProbability::getZero()));
    CaseBits *CB = &CBV[j];

    // Update Mask, Bits and ExtraProb.
    uint64_t Lo = (Clusters[i].Low->getValue() - LowBound).getZExtValue();
    uint64_t Hi = (Clusters[i].High->getValue() - LowBound).getZExtValue();
    assert(Hi >= Lo && Hi < 64 && "Invalid bit case!");
    CB->Mask |= (-1ULL >> (63 - (Hi - Lo))) << Lo;
    CB->Bits += Hi - Lo + 1;
    CB->ExtraProb += Clusters[i].Prob;
    TotalProb += Clusters[i].Prob;
  }
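  // For example (hypothetical): a cluster covering [2, 4] with LowBound = 0
  // has Lo = 2 and Hi = 4, so the update above ORs in
  // (-1ULL >> (63 - 2)) << 2 == 0b111 << 2 == 0x1C, setting bits 2 through 4.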

  BitTestInfo BTI;
  llvm::sort(CBV, [](const CaseBits &a, const CaseBits &b) {
    // Sort by probability first, number of bits second, bit mask third.
    if (a.ExtraProb != b.ExtraProb)
      return a.ExtraProb > b.ExtraProb;
    if (a.Bits != b.Bits)
      return a.Bits > b.Bits;
    return a.Mask < b.Mask;
  });

  for (auto &CB : CBV) {
    MachineBasicBlock *BitTestBB =
        FuncInfo.MF->CreateMachineBasicBlock(SI->getParent());
    BTI.push_back(BitTestCase(CB.Mask, BitTestBB, CB.BB, CB.ExtraProb));
  }
  BitTestCases.emplace_back(std::move(LowBound), std::move(CmpRange),
                            SI->getCondition(), -1U, MVT::Other, false,
                            ContiguousRange, nullptr, nullptr, std::move(BTI),
                            TotalProb);

  BTCluster = CaseCluster::bitTests(Clusters[First].Low, Clusters[Last].High,
                                    BitTestCases.size() - 1, TotalProb);
  return true;
}

void SwitchCG::sortAndRangeify(CaseClusterVector &Clusters) {
#ifndef NDEBUG
  for (const CaseCluster &CC : Clusters)
    assert(CC.Low == CC.High && "Input clusters must be single-case");
#endif

  llvm::sort(Clusters, [](const CaseCluster &a, const CaseCluster &b) {
    return a.Low->getValue().slt(b.Low->getValue());
  });

  // Merge adjacent clusters with the same destination.
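  // For example (hypothetical input): single-case clusters {0 -> A, 1 -> A,
  // 2 -> B} collapse to {[0, 1] -> A, [2, 2] -> B}, with the merged cluster
  // accumulating the probabilities of its cases.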
  const unsigned N = Clusters.size();
  unsigned DstIndex = 0;
  for (unsigned SrcIndex = 0; SrcIndex < N; ++SrcIndex) {
    CaseCluster &CC = Clusters[SrcIndex];
    const ConstantInt *CaseVal = CC.Low;
    MachineBasicBlock *Succ = CC.MBB;

    if (DstIndex != 0 && Clusters[DstIndex - 1].MBB == Succ &&
        (CaseVal->getValue() - Clusters[DstIndex - 1].High->getValue()) == 1) {
      // If this case has the same successor and is a neighbour, merge it into
      // the previous cluster.
      Clusters[DstIndex - 1].High = CaseVal;
      Clusters[DstIndex - 1].Prob += CC.Prob;
    } else {
      std::memmove(&Clusters[DstIndex++], &Clusters[SrcIndex],
                   sizeof(Clusters[SrcIndex]));
    }
  }
  Clusters.resize(DstIndex);
}