//===- InductiveRangeCheckElimination.cpp - -------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// The InductiveRangeCheckElimination pass splits a loop's iteration space into
// three disjoint ranges.  It does so in such a way that the loop running over
// the middle range provably does not need range checks. As an example, it will
12// convert
13//
14//   len = < known positive >
15//   for (i = 0; i < n; i++) {
16//     if (0 <= i && i < len) {
17//       do_something();
18//     } else {
19//       throw_out_of_bounds();
20//     }
21//   }
22//
23// to
24//
25//   len = < known positive >
26//   limit = smin(n, len)
27//   // no first segment
28//   for (i = 0; i < limit; i++) {
29//     if (0 <= i && i < len) { // this check is fully redundant
30//       do_something();
31//     } else {
32//       throw_out_of_bounds();
33//     }
34//   }
35//   for (i = limit; i < n; i++) {
36//     if (0 <= i && i < len) {
37//       do_something();
38//     } else {
39//       throw_out_of_bounds();
40//     }
41//   }
42//
43//===----------------------------------------------------------------------===//
44
45#include "llvm/Transforms/Scalar/InductiveRangeCheckElimination.h"
46#include "llvm/ADT/APInt.h"
47#include "llvm/ADT/ArrayRef.h"
48#include "llvm/ADT/PriorityWorklist.h"
49#include "llvm/ADT/SmallPtrSet.h"
50#include "llvm/ADT/SmallVector.h"
51#include "llvm/ADT/StringRef.h"
52#include "llvm/ADT/Twine.h"
53#include "llvm/Analysis/BlockFrequencyInfo.h"
54#include "llvm/Analysis/BranchProbabilityInfo.h"
55#include "llvm/Analysis/LoopAnalysisManager.h"
56#include "llvm/Analysis/LoopInfo.h"
57#include "llvm/Analysis/ScalarEvolution.h"
58#include "llvm/Analysis/ScalarEvolutionExpressions.h"
59#include "llvm/IR/BasicBlock.h"
60#include "llvm/IR/CFG.h"
61#include "llvm/IR/Constants.h"
62#include "llvm/IR/DerivedTypes.h"
63#include "llvm/IR/Dominators.h"
64#include "llvm/IR/Function.h"
65#include "llvm/IR/IRBuilder.h"
66#include "llvm/IR/InstrTypes.h"
67#include "llvm/IR/Instructions.h"
68#include "llvm/IR/Metadata.h"
69#include "llvm/IR/Module.h"
70#include "llvm/IR/PatternMatch.h"
71#include "llvm/IR/Type.h"
72#include "llvm/IR/Use.h"
73#include "llvm/IR/User.h"
74#include "llvm/IR/Value.h"
75#include "llvm/InitializePasses.h"
76#include "llvm/Pass.h"
77#include "llvm/Support/BranchProbability.h"
78#include "llvm/Support/Casting.h"
79#include "llvm/Support/CommandLine.h"
80#include "llvm/Support/Compiler.h"
81#include "llvm/Support/Debug.h"
82#include "llvm/Support/ErrorHandling.h"
83#include "llvm/Support/raw_ostream.h"
84#include "llvm/Transforms/Scalar.h"
85#include "llvm/Transforms/Utils/Cloning.h"
86#include "llvm/Transforms/Utils/LoopSimplify.h"
87#include "llvm/Transforms/Utils/LoopUtils.h"
88#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
89#include "llvm/Transforms/Utils/ValueMapper.h"
90#include <algorithm>
91#include <cassert>
92#include <iterator>
93#include <limits>
94#include <optional>
95#include <utility>
96#include <vector>
97
98using namespace llvm;
99using namespace llvm::PatternMatch;
100
// Hidden option "irce-loop-size-cutoff" (default 64).
// NOTE(review): its consumer is outside this part of the file; presumably an
// upper bound on the size of loops the pass will transform -- confirm.
static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
                                        cl::init(64));

// Hidden option "irce-print-changed-loops" (default off).
// NOTE(review): presumably enables printing of transformed loops -- confirm at
// the point of use.
static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
                                       cl::init(false));

// Hidden option "irce-print-range-checks" (default off).
// NOTE(review): presumably enables printing of the recognized range checks --
// confirm at the point of use.
static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
                                      cl::init(false));

// Hidden option "irce-skip-profitability-checks" (default off).  When set,
// InductiveRangeCheck::extractRangeChecksFromBranch no longer rejects branches
// based on their edge probability.
static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks",
                                             cl::Hidden, cl::init(false));

// Hidden option "irce-min-runtime-iterations" (default 10).
// NOTE(review): presumably the minimum estimated iteration count used by the
// profitability estimate -- confirm at the point of use.
static cl::opt<unsigned> MinRuntimeIterations("irce-min-runtime-iterations",
                                              cl::Hidden, cl::init(10));

// Hidden option "irce-allow-unsigned-latch" (default on).
// NOTE(review): presumably permits loops whose latch comparison is unsigned --
// confirm at the point of use.
static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch",
                                                 cl::Hidden, cl::init(true));

// Hidden option "irce-allow-narrow-latch" (default on); see the description
// string below.
static cl::opt<bool> AllowNarrowLatchCondition(
    "irce-allow-narrow-latch", cl::Hidden, cl::init(true),
    cl::desc("If set to true, IRCE may eliminate wide range checks in loops "
             "with narrow latch condition."));

// Name of the metadata kind looked up on a loop latch's terminator; a loop
// whose latch terminator carries it is reported as "loop has already been
// cloned" by LoopStructure::parseLoopStructure.
static const char *ClonedLoopTag = "irce.loop.clone";

// Debug category used by the LLVM_DEBUG output in this file.
#define DEBUG_TYPE "irce"
127
128namespace {
129
130/// An inductive range check is conditional branch in a loop with
131///
132///  1. a very cold successor (i.e. the branch jumps to that successor very
133///     rarely)
134///
135///  and
136///
137///  2. a condition that is provably true for some contiguous range of values
138///     taken by the containing loop's induction variable.
139///
140class InductiveRangeCheck {
141
142  const SCEV *Begin = nullptr;
143  const SCEV *Step = nullptr;
144  const SCEV *End = nullptr;
145  Use *CheckUse = nullptr;
146
147  static bool parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE,
148                                  Value *&Index, Value *&Length,
149                                  bool &IsSigned);
150
151  static void
152  extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse,
153                             SmallVectorImpl<InductiveRangeCheck> &Checks,
154                             SmallPtrSetImpl<Value *> &Visited);
155
156public:
157  const SCEV *getBegin() const { return Begin; }
158  const SCEV *getStep() const { return Step; }
159  const SCEV *getEnd() const { return End; }
160
161  void print(raw_ostream &OS) const {
162    OS << "InductiveRangeCheck:\n";
163    OS << "  Begin: ";
164    Begin->print(OS);
165    OS << "  Step: ";
166    Step->print(OS);
167    OS << "  End: ";
168    End->print(OS);
169    OS << "\n  CheckUse: ";
170    getCheckUse()->getUser()->print(OS);
171    OS << " Operand: " << getCheckUse()->getOperandNo() << "\n";
172  }
173
174  LLVM_DUMP_METHOD
175  void dump() {
176    print(dbgs());
177  }
178
179  Use *getCheckUse() const { return CheckUse; }
180
181  /// Represents an signed integer range [Range.getBegin(), Range.getEnd()).  If
182  /// R.getEnd() le R.getBegin(), then R denotes the empty range.
183
184  class Range {
185    const SCEV *Begin;
186    const SCEV *End;
187
188  public:
189    Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) {
190      assert(Begin->getType() == End->getType() && "ill-typed range!");
191    }
192
193    Type *getType() const { return Begin->getType(); }
194    const SCEV *getBegin() const { return Begin; }
195    const SCEV *getEnd() const { return End; }
196    bool isEmpty(ScalarEvolution &SE, bool IsSigned) const {
197      if (Begin == End)
198        return true;
199      if (IsSigned)
200        return SE.isKnownPredicate(ICmpInst::ICMP_SGE, Begin, End);
201      else
202        return SE.isKnownPredicate(ICmpInst::ICMP_UGE, Begin, End);
203    }
204  };
205
206  /// This is the value the condition of the branch needs to evaluate to for the
207  /// branch to take the hot successor (see (1) above).
208  bool getPassingDirection() { return true; }
209
210  /// Computes a range for the induction variable (IndVar) in which the range
211  /// check is redundant and can be constant-folded away.  The induction
212  /// variable is not required to be the canonical {0,+,1} induction variable.
213  std::optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
214                                                 const SCEVAddRecExpr *IndVar,
215                                                 bool IsLatchSigned) const;
216
217  /// Parse out a set of inductive range checks from \p BI and append them to \p
218  /// Checks.
219  ///
220  /// NB! There may be conditions feeding into \p BI that aren't inductive range
221  /// checks, and hence don't end up in \p Checks.
222  static void
223  extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE,
224                               BranchProbabilityInfo *BPI,
225                               SmallVectorImpl<InductiveRangeCheck> &Checks);
226};
227
228struct LoopStructure;
229
/// Holds the analyses needed by the transformation and exposes run() as the
/// per-loop entry point.
class InductiveRangeCheckElimination {
  ScalarEvolution &SE;
  // Held as a pointer rather than a reference.
  // NOTE(review): presumably may be null when no branch probabilities are
  // available -- confirm against the callers.
  BranchProbabilityInfo *BPI;
  DominatorTree &DT;
  LoopInfo &LI;

  // Optional callback producing a BlockFrequencyInfo on demand; std::nullopt
  // when the caller did not supply one.
  using GetBFIFunc =
      std::optional<llvm::function_ref<llvm::BlockFrequencyInfo &()>>;
  GetBFIFunc GetBFI;

  // Returns true if it is profitable to do a transform basing on estimation of
  // number of iterations.
  bool isProfitableToTransform(const Loop &L, LoopStructure &LS);

public:
  /// Stores the given analyses; \p GetBFI defaults to "not provided".
  InductiveRangeCheckElimination(ScalarEvolution &SE,
                                 BranchProbabilityInfo *BPI, DominatorTree &DT,
                                 LoopInfo &LI, GetBFIFunc GetBFI = std::nullopt)
      : SE(SE), BPI(BPI), DT(DT), LI(LI), GetBFI(GetBFI) {}

  /// Runs the transformation on \p L.
  /// NOTE(review): defined outside this part of the file; presumably returns
  /// true when the IR was changed and reports newly created loops through
  /// \p LPMAddNewLoop -- confirm at the definition.
  bool run(Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop);
};
252
/// FunctionPass wrapper that exposes the transformation to the legacy pass
/// manager.
class IRCELegacyPass : public FunctionPass {
public:
  // Pass identifier; its storage is defined at file scope.
  static char ID;

  // Registers the pass with the global pass registry on construction.
  IRCELegacyPass() : FunctionPass(ID) {
    initializeIRCELegacyPassPass(*PassRegistry::getPassRegistry());
  }

  // Requires branch probabilities, the dominator tree, loop info and scalar
  // evolution.  The latter three are also declared preserved; branch
  // probability info is not.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<BranchProbabilityInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addRequired<LoopInfoWrapperPass>();
    AU.addPreserved<LoopInfoWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
  }

  // Defined outside the class body.
  bool runOnFunction(Function &F) override;
};
273
274} // end anonymous namespace
275
// Storage for the pass identifier declared in IRCELegacyPass.
char IRCELegacyPass::ID = 0;

// Register the legacy pass under the name "irce" and declare its dependencies
// on the same four analyses that IRCELegacyPass::getAnalysisUsage requires.
INITIALIZE_PASS_BEGIN(IRCELegacyPass, "irce",
                      "Inductive range check elimination", false, false)
INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_END(IRCELegacyPass, "irce", "Inductive range check elimination",
                    false, false)
286
287/// Parse a single ICmp instruction, `ICI`, into a range check.  If `ICI` cannot
288/// be interpreted as a range check, return false and set `Index` and `Length`
289/// to `nullptr`.  Otherwise set `Index` to the value being range checked, and
290/// set `Length` to the upper limit `Index` is being range checked.
291bool
292InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
293                                         ScalarEvolution &SE, Value *&Index,
294                                         Value *&Length, bool &IsSigned) {
295  auto IsLoopInvariant = [&SE, L](Value *V) {
296    return SE.isLoopInvariant(SE.getSCEV(V), L);
297  };
298
299  ICmpInst::Predicate Pred = ICI->getPredicate();
300  Value *LHS = ICI->getOperand(0);
301  Value *RHS = ICI->getOperand(1);
302
303  switch (Pred) {
304  default:
305    return false;
306
307  case ICmpInst::ICMP_SLE:
308    std::swap(LHS, RHS);
309    [[fallthrough]];
310  case ICmpInst::ICMP_SGE:
311    IsSigned = true;
312    if (match(RHS, m_ConstantInt<0>())) {
313      Index = LHS;
314      return true; // Lower.
315    }
316    return false;
317
318  case ICmpInst::ICMP_SLT:
319    std::swap(LHS, RHS);
320    [[fallthrough]];
321  case ICmpInst::ICMP_SGT:
322    IsSigned = true;
323    if (match(RHS, m_ConstantInt<-1>())) {
324      Index = LHS;
325      return true; // Lower.
326    }
327
328    if (IsLoopInvariant(LHS)) {
329      Index = RHS;
330      Length = LHS;
331      return true; // Upper.
332    }
333    return false;
334
335  case ICmpInst::ICMP_ULT:
336    std::swap(LHS, RHS);
337    [[fallthrough]];
338  case ICmpInst::ICMP_UGT:
339    IsSigned = false;
340    if (IsLoopInvariant(LHS)) {
341      Index = RHS;
342      Length = LHS;
343      return true; // Both lower and upper.
344    }
345    return false;
346  }
347
348  llvm_unreachable("default clause returns!");
349}
350
351void InductiveRangeCheck::extractRangeChecksFromCond(
352    Loop *L, ScalarEvolution &SE, Use &ConditionUse,
353    SmallVectorImpl<InductiveRangeCheck> &Checks,
354    SmallPtrSetImpl<Value *> &Visited) {
355  Value *Condition = ConditionUse.get();
356  if (!Visited.insert(Condition).second)
357    return;
358
359  // TODO: Do the same for OR, XOR, NOT etc?
360  if (match(Condition, m_LogicalAnd(m_Value(), m_Value()))) {
361    extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0),
362                               Checks, Visited);
363    extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1),
364                               Checks, Visited);
365    return;
366  }
367
368  ICmpInst *ICI = dyn_cast<ICmpInst>(Condition);
369  if (!ICI)
370    return;
371
372  Value *Length = nullptr, *Index;
373  bool IsSigned;
374  if (!parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned))
375    return;
376
377  const auto *IndexAddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Index));
378  bool IsAffineIndex =
379      IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine();
380
381  if (!IsAffineIndex)
382    return;
383
384  const SCEV *End = nullptr;
385  // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
386  // We can potentially do much better here.
387  if (Length)
388    End = SE.getSCEV(Length);
389  else {
390    // So far we can only reach this point for Signed range check. This may
391    // change in future. In this case we will need to pick Unsigned max for the
392    // unsigned range check.
393    unsigned BitWidth = cast<IntegerType>(IndexAddRec->getType())->getBitWidth();
394    const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
395    End = SIntMax;
396  }
397
398  InductiveRangeCheck IRC;
399  IRC.End = End;
400  IRC.Begin = IndexAddRec->getStart();
401  IRC.Step = IndexAddRec->getStepRecurrence(SE);
402  IRC.CheckUse = &ConditionUse;
403  Checks.push_back(IRC);
404}
405
406void InductiveRangeCheck::extractRangeChecksFromBranch(
407    BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo *BPI,
408    SmallVectorImpl<InductiveRangeCheck> &Checks) {
409  if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch())
410    return;
411
412  BranchProbability LikelyTaken(15, 16);
413
414  if (!SkipProfitabilityChecks && BPI &&
415      BPI->getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken)
416    return;
417
418  SmallPtrSet<Value *, 8> Visited;
419  InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0),
420                                                  Checks, Visited);
421}
422
423// Add metadata to the loop L to disable loop optimizations. Callers need to
424// confirm that optimizing loop L is not beneficial.
425static void DisableAllLoopOptsOnLoop(Loop &L) {
426  // We do not care about any existing loopID related metadata for L, since we
427  // are setting all loop metadata to false.
428  LLVMContext &Context = L.getHeader()->getContext();
429  // Reserve first location for self reference to the LoopID metadata node.
430  MDNode *Dummy = MDNode::get(Context, {});
431  MDNode *DisableUnroll = MDNode::get(
432      Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
433  Metadata *FalseVal =
434      ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
435  MDNode *DisableVectorize = MDNode::get(
436      Context,
437      {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
438  MDNode *DisableLICMVersioning = MDNode::get(
439      Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
440  MDNode *DisableDistribution= MDNode::get(
441      Context,
442      {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
443  MDNode *NewLoopID =
444      MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
445                            DisableLICMVersioning, DisableDistribution});
446  // Set operand 0 to refer to the loop id itself.
447  NewLoopID->replaceOperandWith(0, NewLoopID);
448  L.setLoopID(NewLoopID);
449}
450
451namespace {
452
453// Keeps track of the structure of a loop.  This is similar to llvm::Loop,
454// except that it is more lightweight and can track the state of a loop through
455// changing and potentially invalid IR.  This structure also formalizes the
456// kinds of loops we can deal with -- ones that have a single latch that is also
457// an exiting block *and* have a canonical induction variable.
458struct LoopStructure {
459  const char *Tag = "";
460
461  BasicBlock *Header = nullptr;
462  BasicBlock *Latch = nullptr;
463
464  // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th
465  // successor is `LatchExit', the exit block of the loop.
466  BranchInst *LatchBr = nullptr;
467  BasicBlock *LatchExit = nullptr;
468  unsigned LatchBrExitIdx = std::numeric_limits<unsigned>::max();
469
470  // The loop represented by this instance of LoopStructure is semantically
471  // equivalent to:
472  //
473  // intN_ty inc = IndVarIncreasing ? 1 : -1;
474  // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT;
475  //
476  // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarBase)
477  //   ... body ...
478
479  Value *IndVarBase = nullptr;
480  Value *IndVarStart = nullptr;
481  Value *IndVarStep = nullptr;
482  Value *LoopExitAt = nullptr;
483  bool IndVarIncreasing = false;
484  bool IsSignedPredicate = true;
485
486  LoopStructure() = default;
487
488  template <typename M> LoopStructure map(M Map) const {
489    LoopStructure Result;
490    Result.Tag = Tag;
491    Result.Header = cast<BasicBlock>(Map(Header));
492    Result.Latch = cast<BasicBlock>(Map(Latch));
493    Result.LatchBr = cast<BranchInst>(Map(LatchBr));
494    Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
495    Result.LatchBrExitIdx = LatchBrExitIdx;
496    Result.IndVarBase = Map(IndVarBase);
497    Result.IndVarStart = Map(IndVarStart);
498    Result.IndVarStep = Map(IndVarStep);
499    Result.LoopExitAt = Map(LoopExitAt);
500    Result.IndVarIncreasing = IndVarIncreasing;
501    Result.IsSignedPredicate = IsSignedPredicate;
502    return Result;
503  }
504
505  static std::optional<LoopStructure> parseLoopStructure(ScalarEvolution &,
506                                                         Loop &, const char *&);
507};
508
/// This class is used to constrain loops to run within a given iteration space.
/// The algorithm this class implements is given a Loop and a range [Begin,
/// End).  The algorithm then tries to break out a "main loop" out of the loop
/// it is given in a way that the "main loop" runs with the induction variable
/// in a subset of [Begin, End).  The algorithm emits appropriate pre and post
/// loops to run any remaining iterations.  The pre loop runs any iterations in
/// which the induction variable is < Begin, and the post loop runs any
/// iterations in which the induction variable is >= End.
class LoopConstrainer {
  // The representation of a clone of the original loop we started out with.
  struct ClonedLoop {
    // The cloned blocks
    std::vector<BasicBlock *> Blocks;

    // `Map` maps values in the clonee into values in the cloned version
    ValueToValueMapTy Map;

    // An instance of `LoopStructure` for the cloned loop
    LoopStructure Structure;
  };

  // Result of rewriting the range of a loop.  See changeIterationSpaceEnd for
  // more details on what these fields mean.
  struct RewrittenRangeInfo {
    BasicBlock *PseudoExit = nullptr;
    BasicBlock *ExitSelector = nullptr;
    std::vector<PHINode *> PHIValuesAtPseudoExit;
    // NOTE(review): not described in the changeIterationSpaceEnd comment;
    // presumably the induction variable's value on taking the pseudo exit --
    // confirm at the definition.
    PHINode *IndVarEnd = nullptr;

    RewrittenRangeInfo() = default;
  };

  // Calculated subranges we restrict the iteration space of the main loop to.
  // See the implementation of `calculateSubRanges' for more details on how
  // these fields are computed.  `LowLimit` is std::nullopt if there is no
  // restriction on low end of the restricted iteration space of the main loop.
  // `HighLimit` is std::nullopt if there is no restriction on high end of the
  // restricted iteration space of the main loop.

  struct SubRanges {
    std::optional<const SCEV *> LowLimit;
    std::optional<const SCEV *> HighLimit;
  };

  // Compute a safe set of limits for the main loop to run in -- effectively the
  // intersection of `Range' and the iteration space of the original loop.
  // Return std::nullopt if unable to compute the set of subranges.
  std::optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const;

  // Clone `OriginalLoop' and return the result in CLResult.  The IR after
  // running `cloneLoop' is well formed except for the PHI nodes in CLResult --
  // the PHI nodes say that there is an incoming edge from `OriginalPreheader`
  // but there is no such edge.
  void cloneLoop(ClonedLoop &CLResult, const char *Tag) const;

  // Create the appropriate loop structure needed to describe a cloned copy of
  // `Original`.  The clone is described by `VM`.
  Loop *createClonedLoopStructure(Loop *Original, Loop *Parent,
                                  ValueToValueMapTy &VM, bool IsSubloop);

  // Rewrite the iteration space of the loop denoted by (LS, Preheader). The
  // iteration space of the rewritten loop ends at ExitLoopAt.  The start of the
  // iteration space is not changed.  `ExitLoopAt' is assumed to be slt
  // `OriginalHeaderCount'.
  //
  // If there are iterations left to execute, control is made to jump to
  // `ContinuationBlock', otherwise they take the normal loop exit.  The
  // returned `RewrittenRangeInfo' object is populated as follows:
  //
  //  .PseudoExit is a basic block that unconditionally branches to
  //      `ContinuationBlock'.
  //
  //  .ExitSelector is a basic block that decides, on exit from the loop,
  //      whether to branch to the "true" exit or to `PseudoExit'.
  //
  //  .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value
  //      for each PHINode in the loop header on taking the pseudo exit.
  //
  // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate
  // preheader because it is made to branch to the loop header only
  // conditionally.
  RewrittenRangeInfo
  changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader,
                          Value *ExitLoopAt,
                          BasicBlock *ContinuationBlock) const;

  // The loop denoted by `LS' has `OldPreheader' as its preheader.  This
  // function creates a new preheader for `LS' and returns it.
  BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader,
                              const char *Tag) const;

  // `ContinuationBlockAndPreheader' was the continuation block for some call to
  // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'.
  // This function rewrites the PHI nodes in `LS.Header' to start with the
  // correct value.
  void rewriteIncomingValuesForPHIs(
      LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader,
      const LoopConstrainer::RewrittenRangeInfo &RRI) const;

  // Even though we do not preserve any passes at this time, we at least need to
  // keep the parent loop structure consistent.  The `LPPassManager' seems to
  // verify this after running a loop pass.  This function adds the list of
  // blocks denoted by BBs to this loop's parent loop if required.
  void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs);

  // Some global state.
  Function &F;
  LLVMContext &Ctx;
  ScalarEvolution &SE;
  DominatorTree &DT;
  LoopInfo &LI;
  // Callback supplied by the creator of this object.
  // NOTE(review): presumably used to announce newly created loops to the pass
  // manager -- confirm at the call sites.
  function_ref<void(Loop *, bool)> LPMAddNewLoop;

  // Information about the original loop we started out with.
  Loop &OriginalLoop;

  // The next two members are not set by the constructor and start out null.
  const SCEV *LatchTakenCount = nullptr;
  BasicBlock *OriginalPreheader = nullptr;

  // The preheader of the main loop.  This may or may not be different from
  // `OriginalPreheader'.
  BasicBlock *MainLoopPreheader = nullptr;

  // The range we need to run the main loop in.
  InductiveRangeCheck::Range Range;

  // The structure of the main loop (see comment at the beginning of this class
  // for a definition)
  LoopStructure MainLoopStructure;

public:
  // Captures the loop to constrain, the analyses to use, the structure `LS' of
  // the main loop and the range `R' the main loop has to run in.  The function
  // and the LLVM context are derived from the header of `L'.
  LoopConstrainer(Loop &L, LoopInfo &LI,
                  function_ref<void(Loop *, bool)> LPMAddNewLoop,
                  const LoopStructure &LS, ScalarEvolution &SE,
                  DominatorTree &DT, InductiveRangeCheck::Range R)
      : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()),
        SE(SE), DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L),
        Range(R), MainLoopStructure(LS) {}

  // Entry point for the algorithm.  Returns true on success.
  bool run();
};
651
652} // end anonymous namespace
653
654/// Given a loop with an deccreasing induction variable, is it possible to
655/// safely calculate the bounds of a new loop using the given Predicate.
656static bool isSafeDecreasingBound(const SCEV *Start,
657                                  const SCEV *BoundSCEV, const SCEV *Step,
658                                  ICmpInst::Predicate Pred,
659                                  unsigned LatchBrExitIdx,
660                                  Loop *L, ScalarEvolution &SE) {
661  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
662      Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
663    return false;
664
665  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
666    return false;
667
668  assert(SE.isKnownNegative(Step) && "expecting negative step");
669
670  LLVM_DEBUG(dbgs() << "irce: isSafeDecreasingBound with:\n");
671  LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
672  LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
673  LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
674  LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred)
675                    << "\n");
676  LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
677
678  bool IsSigned = ICmpInst::isSigned(Pred);
679  // The predicate that we need to check that the induction variable lies
680  // within bounds.
681  ICmpInst::Predicate BoundPred =
682    IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
683
684  if (LatchBrExitIdx == 1)
685    return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
686
687  assert(LatchBrExitIdx == 0 &&
688         "LatchBrExitIdx should be either 0 or 1");
689
690  const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
691  unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
692  APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth) :
693    APInt::getMinValue(BitWidth);
694  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);
695
696  const SCEV *MinusOne =
697    SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));
698
699  return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
700         SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
701
702}
703
704/// Given a loop with an increasing induction variable, is it possible to
705/// safely calculate the bounds of a new loop using the given Predicate.
706static bool isSafeIncreasingBound(const SCEV *Start,
707                                  const SCEV *BoundSCEV, const SCEV *Step,
708                                  ICmpInst::Predicate Pred,
709                                  unsigned LatchBrExitIdx,
710                                  Loop *L, ScalarEvolution &SE) {
711  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
712      Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
713    return false;
714
715  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
716    return false;
717
718  LLVM_DEBUG(dbgs() << "irce: isSafeIncreasingBound with:\n");
719  LLVM_DEBUG(dbgs() << "irce: Start: " << *Start << "\n");
720  LLVM_DEBUG(dbgs() << "irce: Step: " << *Step << "\n");
721  LLVM_DEBUG(dbgs() << "irce: BoundSCEV: " << *BoundSCEV << "\n");
722  LLVM_DEBUG(dbgs() << "irce: Pred: " << ICmpInst::getPredicateName(Pred)
723                    << "\n");
724  LLVM_DEBUG(dbgs() << "irce: LatchExitBrIdx: " << LatchBrExitIdx << "\n");
725
726  bool IsSigned = ICmpInst::isSigned(Pred);
727  // The predicate that we need to check that the induction variable lies
728  // within bounds.
729  ICmpInst::Predicate BoundPred =
730      IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
731
732  if (LatchBrExitIdx == 1)
733    return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);
734
735  assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
736
737  const SCEV *StepMinusOne =
738    SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
739  unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
740  APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth) :
741    APInt::getMaxValue(BitWidth);
742  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);
743
744  return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
745                                      SE.getAddExpr(BoundSCEV, Step)) &&
746          SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
747}
748
/// Try to recognize \p L as a counted loop of the shape this pass can split,
/// and describe it as a LoopStructure.
///
/// The checks performed below require that:
///  * \p L is in LoopSimplify form, has a preheader, and its latch is an
///    exiting block whose terminator does not carry ClonedLoopTag metadata;
///  * the latch ends in a conditional branch on an integer icmp, and the exit
///    count of the latch is computable;
///  * one icmp operand is an affine add recurrence of \p L with a constant
///    step (the operands are swapped if needed so that it ends up on the
///    left);
///  * after rewriting eq/ne into an ordered comparison where possible, the
///    predicate has the direction expected for the sign of the step, and the
///    bound is accepted by isSafeIncreasingBound / isSafeDecreasingBound.
///
/// On success, code for the induction variable's start value (and, when
/// needed, for an adjusted right-hand side) is expanded before the preheader's
/// terminator, \p FailureReason is set to nullptr and the filled-in
/// LoopStructure is returned.  On failure, \p FailureReason is set to a string
/// literal naming the failed check and std::nullopt is returned.
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
                                  const char *&FailureReason) {
  if (!L.isLoopSimplifyForm()) {
    FailureReason = "loop not in LoopSimplify form";
    return std::nullopt;
  }

  BasicBlock *Latch = L.getLoopLatch();
  assert(Latch && "Simplified loops only have one latch!");

  // Loops produced by cloneLoop carry this tag on their latch terminator and
  // are not processed again.
  if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
    FailureReason = "loop has already been cloned";
    return std::nullopt;
  }

  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return std::nullopt;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return std::nullopt;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return std::nullopt;
  }

  // If successor 0 of the latch branch is the header, successor 1 is treated
  // as the exit, and vice versa.
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return std::nullopt;
  }

  const SCEV *LatchCount = SE.getExitCount(&L, Latch);
  if (isa<SCEVCouldNotCompute>(LatchCount)) {
    FailureReason = "could not compute latch count";
    return std::nullopt;
  }

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return std::nullopt;
    }
  }

  // Returns true if `AR` is known not to wrap in the signed sense: either the
  // NSW flag is already set, or sign-extending the whole recurrence to twice
  // its width yields a recurrence whose start and step equal the separately
  // sign-extended start and step.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
  if (IndVarBase->getLoop() != &L) {
    FailureReason = "LHS in cmp is not an AddRec for this loop";
    return std::nullopt;
  }
  if (!IndVarBase->isAffine()) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  // Only constant steps are supported; note that this failure shares its
  // message with the non-affine case above.
  const SCEV* StepRec = IndVarBase->getStepRecurrence(SE);
  if (!isa<SCEVConstant>(StepRec)) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();

  if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
    FailureReason = "LHS in icmp needs nsw for equality predicates";
    return std::nullopt;
  }

  assert(!StepCI->isZero() && "Zero step?");
  bool IsIncreasing = !StepCI->isNegative();
  // Assigned in both branches of the `IsIncreasing` test below.
  bool IsSignedPredicate;
  // The recurrence's start is the first value the latch compares (the *next*
  // value mentioned above); stepping back once gives `IndVarStart`.
  const SCEV *StartNext = IndVarBase->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
  const SCEV *Step = SE.getSCEV(StepCI);

  // When non-null, this expression is expanded before the preheader's
  // terminator further down and the result replaces `RightValue`.
  const SCEV *FixedRightSCEV = nullptr;

  // If RightValue resides within the loop (while still being loop invariant),
  // regenerate it in the preheader.
  if (auto *I = dyn_cast<Instruction>(RightValue))
    if (L.contains(I->getParent()))
      FixedRightSCEV = RightSCEV;

  if (IsIncreasing) {
    bool DecreasedRightValueByOne = false;
    if (StepCI->isOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (++i != len) {         while (++i < len) {
        //   ...                 --->     ...
        // }                            }
        // If both parts are known non-negative, it is profitable to use
        // unsigned comparison in increasing loop. This allows us to make the
        // comparison check against "RightSCEV + 1" more optimistic.
        if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
            isKnownNonNegativeInLoop(RightSCEV, &L, SE))
          Pred = ICmpInst::ICMP_ULT;
        else
          Pred = ICmpInst::ICMP_SLT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (++i == len)     --->     if (++i > len - 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/false)) {
          Pred = ICmpInst::ICMP_UGT;
          RightSCEV = SE.getMinusSCEV(RightSCEV,
                                      SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/true)) {
          Pred = ICmpInst::ICMP_SGT;
          RightSCEV = SE.getMinusSCEV(RightSCEV,
                                      SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        }
      }
    }

    // An increasing loop must stay in the loop on "less than" (exit on the
    // false edge) or leave it on "greater than" (exit on the true edge).
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
    bool FoundExpectedPred =
        (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate = ICmpInst::isSigned(Pred);
    if (!IsSignedPredicate && !AllowUnsignedLatchCondition) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe loop bounds";
      return std::nullopt;
    }
    if (LatchBrExitIdx == 0) {
      // We need to increase the right value unless we have already decreased
      // it virtually when we replaced EQ with SGT.
      if (!DecreasedRightValueByOne)
        FixedRightSCEV =
            SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!DecreasedRightValueByOne &&
             "Right value can be decreased only for LatchBrExitIdx == 0!");
    }
  } else {
    bool IncreasedRightValueByOne = false;
    if (StepCI->isMinusOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (--i != len) {         while (--i > len) {
        //   ...                 --->     ...
        // }                            }
        // We intentionally don't turn the predicate into UGT even if we know
        // that both operands are non-negative, because it will only pessimize
        // our check against "RightSCEV - 1".
        Pred = ICmpInst::ICMP_SGT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (--i == len)     --->     if (--i < len + 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
          Pred = ICmpInst::ICMP_ULT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
          Pred = ICmpInst::ICMP_SLT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        }
      }
    }

    // Mirror image of the increasing case: stay in the loop on "greater than"
    // or leave it on "less than".
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);

    bool FoundExpectedPred =
        (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate =
        Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;

    if (!IsSignedPredicate && !AllowUnsignedLatchCondition) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe bounds";
      return std::nullopt;
    }

    if (LatchBrExitIdx == 0) {
      // We need to decrease the right value unless we have already increased
      // it virtually when we replaced EQ with SLT.
      if (!IncreasedRightValueByOne)
        FixedRightSCEV =
            SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
    }
  }
  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(SE.getLoopDisposition(LatchCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  assert(!L.contains(LatchExit) && "expected an exit block!");
  const DataLayout &DL = Preheader->getModule()->getDataLayout();
  SCEVExpander Expander(SE, DL, "irce");
  Instruction *Ins = Preheader->getTerminator();

  // Materialize the (possibly adjusted) right-hand side and the induction
  // variable's start value before the preheader's terminator.
  if (FixedRightSCEV)
    RightValue =
        Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);

  Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarStep = StepCI;
  Result.IndVarBase = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;
  Result.IsSignedPredicate = IsSignedPredicate;

  FailureReason = nullptr;

  return Result;
}
1055
1056/// If the type of \p S matches with \p Ty, return \p S. Otherwise, return
1057/// signed or unsigned extension of \p S to type \p Ty.
1058static const SCEV *NoopOrExtend(const SCEV *S, Type *Ty, ScalarEvolution &SE,
1059                                bool Signed) {
1060  return Signed ? SE.getNoopOrSignExtend(S, Ty) : SE.getNoopOrZeroExtend(S, Ty);
1061}
1062
1063std::optional<LoopConstrainer::SubRanges>
1064LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const {
1065  IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType());
1066
1067  auto *RTy = cast<IntegerType>(Range.getType());
1068
1069  // We only support wide range checks and narrow latches.
1070  if (!AllowNarrowLatchCondition && RTy != Ty)
1071    return std::nullopt;
1072  if (RTy->getBitWidth() < Ty->getBitWidth())
1073    return std::nullopt;
1074
1075  LoopConstrainer::SubRanges Result;
1076
1077  // I think we can be more aggressive here and make this nuw / nsw if the
1078  // addition that feeds into the icmp for the latch's terminating branch is nuw
1079  // / nsw.  In any case, a wrapping 2's complement addition is safe.
1080  const SCEV *Start = NoopOrExtend(SE.getSCEV(MainLoopStructure.IndVarStart),
1081                                   RTy, SE, IsSignedPredicate);
1082  const SCEV *End = NoopOrExtend(SE.getSCEV(MainLoopStructure.LoopExitAt), RTy,
1083                                 SE, IsSignedPredicate);
1084
1085  bool Increasing = MainLoopStructure.IndVarIncreasing;
1086
1087  // We compute `Smallest` and `Greatest` such that [Smallest, Greatest), or
1088  // [Smallest, GreatestSeen] is the range of values the induction variable
1089  // takes.
1090
1091  const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr;
1092
1093  const SCEV *One = SE.getOne(RTy);
1094  if (Increasing) {
1095    Smallest = Start;
1096    Greatest = End;
1097    // No overflow, because the range [Smallest, GreatestSeen] is not empty.
1098    GreatestSeen = SE.getMinusSCEV(End, One);
1099  } else {
1100    // These two computations may sign-overflow.  Here is why that is okay:
1101    //
1102    // We know that the induction variable does not sign-overflow on any
1103    // iteration except the last one, and it starts at `Start` and ends at
1104    // `End`, decrementing by one every time.
1105    //
1106    //  * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the
1107    //    induction variable is decreasing we know that that the smallest value
1108    //    the loop body is actually executed with is `INT_SMIN` == `Smallest`.
1109    //
1110    //  * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`.  In
1111    //    that case, `Clamp` will always return `Smallest` and
1112    //    [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`)
1113    //    will be an empty range.  Returning an empty range is always safe.
1114
1115    Smallest = SE.getAddExpr(End, One);
1116    Greatest = SE.getAddExpr(Start, One);
1117    GreatestSeen = Start;
1118  }
1119
1120  auto Clamp = [this, Smallest, Greatest, IsSignedPredicate](const SCEV *S) {
1121    return IsSignedPredicate
1122               ? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S))
1123               : SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S));
1124  };
1125
1126  // In some cases we can prove that we don't need a pre or post loop.
1127  ICmpInst::Predicate PredLE =
1128      IsSignedPredicate ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
1129  ICmpInst::Predicate PredLT =
1130      IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT;
1131
1132  bool ProvablyNoPreloop =
1133      SE.isKnownPredicate(PredLE, Range.getBegin(), Smallest);
1134  if (!ProvablyNoPreloop)
1135    Result.LowLimit = Clamp(Range.getBegin());
1136
1137  bool ProvablyNoPostLoop =
1138      SE.isKnownPredicate(PredLT, GreatestSeen, Range.getEnd());
1139  if (!ProvablyNoPostLoop)
1140    Result.HighLimit = Clamp(Range.getEnd());
1141
1142  return Result;
1143}
1144
1145void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
1146                                const char *Tag) const {
1147  for (BasicBlock *BB : OriginalLoop.getBlocks()) {
1148    BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
1149    Result.Blocks.push_back(Clone);
1150    Result.Map[BB] = Clone;
1151  }
1152
1153  auto GetClonedValue = [&Result](Value *V) {
1154    assert(V && "null values not in domain!");
1155    auto It = Result.Map.find(V);
1156    if (It == Result.Map.end())
1157      return V;
1158    return static_cast<Value *>(It->second);
1159  };
1160
1161  auto *ClonedLatch =
1162      cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
1163  ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
1164                                            MDNode::get(Ctx, {}));
1165
1166  Result.Structure = MainLoopStructure.map(GetClonedValue);
1167  Result.Structure.Tag = Tag;
1168
1169  for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
1170    BasicBlock *ClonedBB = Result.Blocks[i];
1171    BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];
1172
1173    assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");
1174
1175    for (Instruction &I : *ClonedBB)
1176      RemapInstruction(&I, Result.Map,
1177                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
1178
1179    // Exit blocks will now have one more predecessor and their PHI nodes need
1180    // to be edited to reflect that.  No phi nodes need to be introduced because
1181    // the loop is in LCSSA.
1182
1183    for (auto *SBB : successors(OriginalBB)) {
1184      if (OriginalLoop.contains(SBB))
1185        continue; // not an exit block
1186
1187      for (PHINode &PN : SBB->phis()) {
1188        Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
1189        PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
1190        SE.forgetValue(&PN);
1191      }
1192    }
1193  }
1194}
1195
/// Rewrite the loop described by \p LS so that it stops iterating once its
/// induction variable reaches \p ExitSubloopAt, and continues in
/// \p ContinuationBlock if iterations of the original space are left.
///
/// Two blocks are created, named after `LS.Tag`: ".exit.selector" and
/// ".pseudo.exit".  The terminator of \p Preheader (which must be a
/// BranchInst) is replaced by a conditional branch that enters the loop
/// header or goes straight to the pseudo exit; the latch's exit edge is
/// redirected to the exit selector and its condition is replaced by a
/// comparison against \p ExitSubloopAt.  The exit selector compares against
/// `LS.LoopExitAt` and goes either to the pseudo exit or to the original
/// latch exit.  The pseudo exit branches to \p ContinuationBlock and holds
/// one PHI per header PHI plus an "indvar.end" PHI.
///
/// \returns the new blocks, the PHIs created in the pseudo exit and the
/// "indvar.end" PHI, packaged as a RewrittenRangeInfo.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {
  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+

  RewrittenRangeInfo RRI;

  // Both new blocks are placed right after the latch in the function's block
  // list.
  BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
  bool Increasing = LS.IndVarIncreasing;
  bool IsSignedPredicate = LS.IsSignedPredicate;

  IRBuilder<> B(PreheaderJump);
  auto *RangeTy = Range.getBegin()->getType();
  // Values whose type differs from the range's type are sign- or
  // zero-extended to it, depending on the latch predicate's signedness.  The
  // extension is emitted at the builder's current insertion point.
  auto NoopOrExt = [&](Value *V) {
    if (V->getType() == RangeTy)
      return V;
    return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
                             : B.CreateZExt(V, RangeTy, "wide." + V->getName());
  };

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = nullptr;
  // "Still inside the iteration space" predicate: less-than for an increasing
  // induction variable, greater-than for a decreasing one.
  auto Pred =
      Increasing
          ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
          : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
  Value *IndVarStart = NoopOrExt(LS.IndVarStart);
  EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);

  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  // Redirect the latch's exit edge to the exit selector and make the latch
  // test against `ExitSubloopAt` instead of the original bound.
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *IndVarBase = NoopOrExt(LS.IndVarBase);
  Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);

  // If the exit is on the true edge (index 0), the branch condition must be
  // the negation of "take the backedge".
  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable?  If not, we branch to the "real"
  // exit.
  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
  Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header.  This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  for (PHINode &PN : LS.Header->phis()) {
    PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
                                      BranchToContinuation);

    NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  // The induction variable's value on reaching the pseudo exit: the start
  // value if the loop was skipped, otherwise the value compared in the latch.
  RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
                                  BranchToContinuation);
  RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'.  The PHI nodes need to be updated to reflect that.
  LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);

  return RRI;
}
1350
1351void LoopConstrainer::rewriteIncomingValuesForPHIs(
1352    LoopStructure &LS, BasicBlock *ContinuationBlock,
1353    const LoopConstrainer::RewrittenRangeInfo &RRI) const {
1354  unsigned PHIIndex = 0;
1355  for (PHINode &PN : LS.Header->phis())
1356    PN.setIncomingValueForBlock(ContinuationBlock,
1357                                RRI.PHIValuesAtPseudoExit[PHIIndex++]);
1358
1359  LS.IndVarStart = RRI.IndVarEnd;
1360}
1361
1362BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
1363                                             BasicBlock *OldPreheader,
1364                                             const char *Tag) const {
1365  BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
1366  BranchInst::Create(LS.Header, Preheader);
1367
1368  LS.Header->replacePhiUsesWith(OldPreheader, Preheader);
1369
1370  return Preheader;
1371}
1372
1373void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
1374  Loop *ParentLoop = OriginalLoop.getParentLoop();
1375  if (!ParentLoop)
1376    return;
1377
1378  for (BasicBlock *BB : BBs)
1379    ParentLoop->addBasicBlockToLoop(BB, LI);
1380}
1381
1382Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
1383                                                 ValueToValueMapTy &VM,
1384                                                 bool IsSubloop) {
1385  Loop &New = *LI.AllocateLoop();
1386  if (Parent)
1387    Parent->addChildLoop(&New);
1388  else
1389    LI.addTopLevelLoop(&New);
1390  LPMAddNewLoop(&New, IsSubloop);
1391
1392  // Add all of the blocks in Original to the new loop.
1393  for (auto *BB : Original->blocks())
1394    if (LI.getLoopFor(BB) == Original)
1395      New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);
1396
1397  // Add all of the subloops to the new loop.
1398  for (Loop *SubLoop : *Original)
1399    createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);
1400
1401  return &New;
1402}
1403
/// Carry out the transformation on `OriginalLoop`.
///
/// Steps, in order:
///  1. Compute the sub-ranges with calculateSubRanges and decide whether a
///     pre-loop and/or a post-loop is required.
///  2. Expand the exit limits of the pre-loop and of the main loop before the
///     original preheader's terminator.
///  3. Clone the loop once per required extra loop, then rewire preheaders,
///     latches and header PHIs via createPreheader, changeIterationSpaceEnd
///     and rewriteIncomingValuesForPHIs.
///  4. Register the new blocks with the parent loop (if any), recalculate the
///     dominator tree, create Loop objects for the clones and bring all loops
///     into LCSSA and LoopSimplify form.
///
/// \returns false if the sub-ranges cannot be computed, or if an exit limit
/// cannot be shown to be overflow-free or safe to expand; every such return
/// happens before any loop is cloned.  Returns true once the loops have been
/// rewritten.
bool LoopConstrainer::run() {
  BasicBlock *Preheader = nullptr;
  LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch);
  Preheader = OriginalLoop.getLoopPreheader();
  assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr &&
         "preconditions!");

  OriginalPreheader = Preheader;
  MainLoopPreheader = Preheader;

  bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
  std::optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate);
  if (!MaybeSR) {
    LLVM_DEBUG(dbgs() << "irce: could not compute subranges\n");
    return false;
  }

  SubRanges SR = *MaybeSR;
  bool Increasing = MainLoopStructure.IndVarIncreasing;
  IntegerType *IVTy =
      cast<IntegerType>(Range.getBegin()->getType());

  SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce");
  Instruction *InsertPt = OriginalPreheader->getTerminator();

  // It would have been better to make `PreLoop' and `PostLoop'
  // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
  // constructor.
  ClonedLoop PreLoop, PostLoop;
  // For an increasing induction variable the pre-loop is governed by LowLimit
  // and the post-loop by HighLimit; for a decreasing one the two limits swap
  // roles.
  bool NeedsPreLoop =
      Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
  bool NeedsPostLoop =
      Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();

  Value *ExitPreLoopAt = nullptr;
  Value *ExitMainLoopAt = nullptr;
  const SCEVConstant *MinusOneS =
      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));

  if (NeedsPreLoop) {
    const SCEV *ExitPreLoopAtSCEV = nullptr;

    // For a decreasing loop the exit limit is HighLimit - 1, which is only
    // formed when HighLimit is known not to be the minimum value in the loop.
    if (Increasing)
      ExitPreLoopAtSCEV = *SR.LowLimit;
    else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
                        << "preloop exit limit.  HighLimit = "
                        << *(*SR.HighLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the"
                        << " preloop exit limit " << *ExitPreLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
    ExitPreLoopAt->setName("exit.preloop.at");
  }

  if (NeedsPostLoop) {
    const SCEV *ExitMainLoopAtSCEV = nullptr;

    // Same scheme as for the pre-loop, with LowLimit - 1 as the limit of a
    // decreasing main loop.
    if (Increasing)
      ExitMainLoopAtSCEV = *SR.HighLimit;
    else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "irce: could not prove no-overflow when computing "
                        << "mainloop exit limit.  LowLimit = "
                        << *(*SR.LowLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "irce: could not prove that it is safe to expand the"
                        << " main loop exit limit " << *ExitMainLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
    ExitMainLoopAt->setName("exit.mainloop.at");
  }

  // We clone these ahead of time so that we don't have to deal with changing
  // and temporarily invalid IR as we transform the loops.
  if (NeedsPreLoop)
    cloneLoop(PreLoop, "preloop");
  if (NeedsPostLoop)
    cloneLoop(PostLoop, "postloop");

  RewrittenRangeInfo PreLoopRRI;

  if (NeedsPreLoop) {
    // The original preheader now enters the pre-loop instead of the main
    // loop; the main loop gets a fresh preheader that the pre-loop continues
    // into.
    Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
                                                  PreLoop.Structure.Header);

    MainLoopPreheader =
        createPreheader(MainLoopStructure, Preheader, "mainloop");
    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
                                         ExitPreLoopAt, MainLoopPreheader);
    rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
                                 PreLoopRRI);
  }

  BasicBlock *PostLoopPreheader = nullptr;
  RewrittenRangeInfo PostLoopRRI;

  if (NeedsPostLoop) {
    // The main loop is cut short at ExitMainLoopAt and continues into the
    // post-loop through a newly created preheader.
    PostLoopPreheader =
        createPreheader(PostLoop.Structure, Preheader, "postloop");
    PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
                                          ExitMainLoopAt, PostLoopPreheader);
    rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
                                 PostLoopRRI);
  }

  BasicBlock *NewMainLoopPreheader =
      MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
  BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
                             PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
                             PostLoopRRI.ExitSelector, NewMainLoopPreheader};

  // Some of the above may be nullptr, filter them out before passing to
  // addToParentLoopIfNeeded.
  auto NewBlocksEnd =
      std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);

  addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));

  DT.recalculate(F);

  // We need to first add all the pre and post loop blocks into the loop
  // structures (as part of createClonedLoopStructure), and then update the
  // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
  // LI when LoopSimplifyForm is generated.
  Loop *PreL = nullptr, *PostL = nullptr;
  if (!PreLoop.Blocks.empty()) {
    PreL = createClonedLoopStructure(&OriginalLoop,
                                     OriginalLoop.getParentLoop(), PreLoop.Map,
                                     /* IsSubLoop */ false);
  }

  if (!PostLoop.Blocks.empty()) {
    PostL =
        createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
                                  PostLoop.Map, /* IsSubLoop */ false);
  }

  // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
  auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) {
    formLCSSARecursively(*L, DT, &LI, &SE);
    simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
    // Pre/post loops are slow paths, we do not need to perform any loop
    // optimizations on them.
    if (!IsOriginalLoop)
      DisableAllLoopOptsOnLoop(*L);
  };
  if (PreL)
    CanonicalizeLoop(PreL, false);
  if (PostL)
    CanonicalizeLoop(PostL, false);
  CanonicalizeLoop(&OriginalLoop, true);

  return true;
}
1579
1580/// Computes and returns a range of values for the induction variable (IndVar)
1581/// in which the range check can be safely elided.  If it cannot compute such a
1582/// range, returns std::nullopt.
1583std::optional<InductiveRangeCheck::Range>
1584InductiveRangeCheck::computeSafeIterationSpace(ScalarEvolution &SE,
1585                                               const SCEVAddRecExpr *IndVar,
1586                                               bool IsLatchSigned) const {
1587  // We can deal when types of latch check and range checks don't match in case
1588  // if latch check is more narrow.
1589  auto *IVType = dyn_cast<IntegerType>(IndVar->getType());
1590  auto *RCType = dyn_cast<IntegerType>(getBegin()->getType());
1591  // Do not work with pointer types.
1592  if (!IVType || !RCType)
1593    return std::nullopt;
1594  if (IVType->getBitWidth() > RCType->getBitWidth())
1595    return std::nullopt;
1596  // IndVar is of the form "A + B * I" (where "I" is the canonical induction
1597  // variable, that may or may not exist as a real llvm::Value in the loop) and
1598  // this inductive range check is a range check on the "C + D * I" ("C" is
1599  // getBegin() and "D" is getStep()).  We rewrite the value being range
1600  // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA".
1601  //
1602  // The actual inequalities we solve are of the form
1603  //
1604  //   0 <= M + 1 * IndVar < L given L >= 0  (i.e. N == 1)
1605  //
1606  // Here L stands for upper limit of the safe iteration space.
1607  // The inequality is satisfied by (0 - M) <= IndVar < (L - M). To avoid
1608  // overflows when calculating (0 - M) and (L - M) we, depending on type of
1609  // IV's iteration space, limit the calculations by borders of the iteration
1610  // space. For example, if IndVar is unsigned, (0 - M) overflows for any M > 0.
1611  // If we figured out that "anything greater than (-M) is safe", we strengthen
1612  // this to "everything greater than 0 is safe", assuming that values between
1613  // -M and 0 just do not exist in unsigned iteration space, and we don't want
1614  // to deal with overflown values.
1615
1616  if (!IndVar->isAffine())
1617    return std::nullopt;
1618
1619  const SCEV *A = NoopOrExtend(IndVar->getStart(), RCType, SE, IsLatchSigned);
1620  const SCEVConstant *B = dyn_cast<SCEVConstant>(
1621      NoopOrExtend(IndVar->getStepRecurrence(SE), RCType, SE, IsLatchSigned));
1622  if (!B)
1623    return std::nullopt;
1624  assert(!B->isZero() && "Recurrence with zero step?");
1625
1626  const SCEV *C = getBegin();
1627  const SCEVConstant *D = dyn_cast<SCEVConstant>(getStep());
1628  if (D != B)
1629    return std::nullopt;
1630
1631  assert(!D->getValue()->isZero() && "Recurrence with zero step?");
1632  unsigned BitWidth = RCType->getBitWidth();
1633  const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
1634
1635  // Subtract Y from X so that it does not go through border of the IV
1636  // iteration space. Mathematically, it is equivalent to:
1637  //
1638  //    ClampedSubtract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX).        [1]
1639  //
1640  // In [1], 'X - Y' is a mathematical subtraction (result is not bounded to
1641  // any width of bit grid). But after we take min/max, the result is
1642  // guaranteed to be within [INT_MIN, INT_MAX].
1643  //
1644  // In [1], INT_MAX and INT_MIN are respectively signed and unsigned max/min
1645  // values, depending on type of latch condition that defines IV iteration
1646  // space.
1647  auto ClampedSubtract = [&](const SCEV *X, const SCEV *Y) {
1648    // FIXME: The current implementation assumes that X is in [0, SINT_MAX].
1649    // This is required to ensure that SINT_MAX - X does not overflow signed and
1650    // that X - Y does not overflow unsigned if Y is negative. Can we lift this
1651    // restriction and make it work for negative X either?
1652    if (IsLatchSigned) {
1653      // X is a number from signed range, Y is interpreted as signed.
1654      // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. So the only
1655      // thing we should care about is that we didn't cross SINT_MAX.
1656      // So, if Y is positive, we subtract Y safely.
1657      //   Rule 1: Y > 0 ---> Y.
1658      // If 0 <= -Y <= (SINT_MAX - X), we subtract Y safely.
1659      //   Rule 2: Y >=s (X - SINT_MAX) ---> Y.
1660      // If 0 <= (SINT_MAX - X) < -Y, we can only subtract (X - SINT_MAX).
1661      //   Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX).
1662      // It gives us smax(Y, X - SINT_MAX) to subtract in all cases.
1663      const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax);
1664      return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax),
1665                             SCEV::FlagNSW);
1666    } else
1667      // X is a number from unsigned range, Y is interpreted as signed.
1668      // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only
1669      // thing we should care about is that we didn't cross zero.
1670      // So, if Y is negative, we subtract Y safely.
1671      //   Rule 1: Y <s 0 ---> Y.
1672      // If 0 <= Y <= X, we subtract Y safely.
1673      //   Rule 2: Y <=s X ---> Y.
1674      // If 0 <= X < Y, we should stop at 0 and can only subtract X.
1675      //   Rule 3: Y >s X ---> X.
1676      // It gives us smin(X, Y) to subtract in all cases.
1677      return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW);
1678  };
1679  const SCEV *M = SE.getMinusSCEV(C, A);
1680  const SCEV *Zero = SE.getZero(M->getType());
1681
1682  // This function returns SCEV equal to 1 if X is non-negative 0 otherwise.
1683  auto SCEVCheckNonNegative = [&](const SCEV *X) {
1684    const Loop *L = IndVar->getLoop();
1685    const SCEV *One = SE.getOne(X->getType());
1686    // Can we trivially prove that X is a non-negative or negative value?
1687    if (isKnownNonNegativeInLoop(X, L, SE))
1688      return One;
1689    else if (isKnownNegativeInLoop(X, L, SE))
1690      return Zero;
1691    // If not, we will have to figure it out during the execution.
1692    // Function smax(smin(X, 0), -1) + 1 equals to 1 if X >= 0 and 0 if X < 0.
1693    const SCEV *NegOne = SE.getNegativeSCEV(One);
1694    return SE.getAddExpr(SE.getSMaxExpr(SE.getSMinExpr(X, Zero), NegOne), One);
1695  };
1696  // FIXME: Current implementation of ClampedSubtract implicitly assumes that
1697  // X is non-negative (in sense of a signed value). We need to re-implement
1698  // this function in a way that it will correctly handle negative X as well.
1699  // We use it twice: for X = 0 everything is fine, but for X = getEnd() we can
1700  // end up with a negative X and produce wrong results. So currently we ensure
1701  // that if getEnd() is negative then both ends of the safe range are zero.
1702  // Note that this may pessimize elimination of unsigned range checks against
1703  // negative values.
1704  const SCEV *REnd = getEnd();
1705  const SCEV *EndIsNonNegative = SCEVCheckNonNegative(REnd);
1706
1707  const SCEV *Begin = SE.getMulExpr(ClampedSubtract(Zero, M), EndIsNonNegative);
1708  const SCEV *End = SE.getMulExpr(ClampedSubtract(REnd, M), EndIsNonNegative);
1709  return InductiveRangeCheck::Range(Begin, End);
1710}
1711
1712static std::optional<InductiveRangeCheck::Range>
1713IntersectSignedRange(ScalarEvolution &SE,
1714                     const std::optional<InductiveRangeCheck::Range> &R1,
1715                     const InductiveRangeCheck::Range &R2) {
1716  if (R2.isEmpty(SE, /* IsSigned */ true))
1717    return std::nullopt;
1718  if (!R1)
1719    return R2;
1720  auto &R1Value = *R1;
1721  // We never return empty ranges from this function, and R1 is supposed to be
1722  // a result of intersection. Thus, R1 is never empty.
1723  assert(!R1Value.isEmpty(SE, /* IsSigned */ true) &&
1724         "We should never have empty R1!");
1725
1726  // TODO: we could widen the smaller range and have this work; but for now we
1727  // bail out to keep things simple.
1728  if (R1Value.getType() != R2.getType())
1729    return std::nullopt;
1730
1731  const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin());
1732  const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd());
1733
1734  // If the resulting range is empty, just return std::nullopt.
1735  auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
1736  if (Ret.isEmpty(SE, /* IsSigned */ true))
1737    return std::nullopt;
1738  return Ret;
1739}
1740
1741static std::optional<InductiveRangeCheck::Range>
1742IntersectUnsignedRange(ScalarEvolution &SE,
1743                       const std::optional<InductiveRangeCheck::Range> &R1,
1744                       const InductiveRangeCheck::Range &R2) {
1745  if (R2.isEmpty(SE, /* IsSigned */ false))
1746    return std::nullopt;
1747  if (!R1)
1748    return R2;
1749  auto &R1Value = *R1;
1750  // We never return empty ranges from this function, and R1 is supposed to be
1751  // a result of intersection. Thus, R1 is never empty.
1752  assert(!R1Value.isEmpty(SE, /* IsSigned */ false) &&
1753         "We should never have empty R1!");
1754
1755  // TODO: we could widen the smaller range and have this work; but for now we
1756  // bail out to keep things simple.
1757  if (R1Value.getType() != R2.getType())
1758    return std::nullopt;
1759
1760  const SCEV *NewBegin = SE.getUMaxExpr(R1Value.getBegin(), R2.getBegin());
1761  const SCEV *NewEnd = SE.getUMinExpr(R1Value.getEnd(), R2.getEnd());
1762
1763  // If the resulting range is empty, just return std::nullopt.
1764  auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
1765  if (Ret.isEmpty(SE, /* IsSigned */ false))
1766    return std::nullopt;
1767  return Ret;
1768}
1769
1770PreservedAnalyses IRCEPass::run(Function &F, FunctionAnalysisManager &AM) {
1771  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
1772  LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
1773  // There are no loops in the function. Return before computing other expensive
1774  // analyses.
1775  if (LI.empty())
1776    return PreservedAnalyses::all();
1777  auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
1778  auto &BPI = AM.getResult<BranchProbabilityAnalysis>(F);
1779
1780  // Get BFI analysis result on demand. Please note that modification of
1781  // CFG invalidates this analysis and we should handle it.
1782  auto getBFI = [&F, &AM ]()->BlockFrequencyInfo & {
1783    return AM.getResult<BlockFrequencyAnalysis>(F);
1784  };
1785  InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI, { getBFI });
1786
1787  bool Changed = false;
1788  {
1789    bool CFGChanged = false;
1790    for (const auto &L : LI) {
1791      CFGChanged |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr,
1792                                 /*PreserveLCSSA=*/false);
1793      Changed |= formLCSSARecursively(*L, DT, &LI, &SE);
1794    }
1795    Changed |= CFGChanged;
1796
1797    if (CFGChanged && !SkipProfitabilityChecks) {
1798      PreservedAnalyses PA = PreservedAnalyses::all();
1799      PA.abandon<BlockFrequencyAnalysis>();
1800      AM.invalidate(F, PA);
1801    }
1802  }
1803
1804  SmallPriorityWorklist<Loop *, 4> Worklist;
1805  appendLoopsToWorklist(LI, Worklist);
1806  auto LPMAddNewLoop = [&Worklist](Loop *NL, bool IsSubloop) {
1807    if (!IsSubloop)
1808      appendLoopsToWorklist(*NL, Worklist);
1809  };
1810
1811  while (!Worklist.empty()) {
1812    Loop *L = Worklist.pop_back_val();
1813    if (IRCE.run(L, LPMAddNewLoop)) {
1814      Changed = true;
1815      if (!SkipProfitabilityChecks) {
1816        PreservedAnalyses PA = PreservedAnalyses::all();
1817        PA.abandon<BlockFrequencyAnalysis>();
1818        AM.invalidate(F, PA);
1819      }
1820    }
1821  }
1822
1823  if (!Changed)
1824    return PreservedAnalyses::all();
1825  return getLoopPassPreservedAnalyses();
1826}
1827
1828bool IRCELegacyPass::runOnFunction(Function &F) {
1829  if (skipFunction(F))
1830    return false;
1831
1832  ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1833  BranchProbabilityInfo &BPI =
1834      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
1835  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
1836  auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1837  InductiveRangeCheckElimination IRCE(SE, &BPI, DT, LI);
1838
1839  bool Changed = false;
1840
1841  for (const auto &L : LI) {
1842    Changed |= simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr,
1843                            /*PreserveLCSSA=*/false);
1844    Changed |= formLCSSARecursively(*L, DT, &LI, &SE);
1845  }
1846
1847  SmallPriorityWorklist<Loop *, 4> Worklist;
1848  appendLoopsToWorklist(LI, Worklist);
1849  auto LPMAddNewLoop = [&](Loop *NL, bool IsSubloop) {
1850    if (!IsSubloop)
1851      appendLoopsToWorklist(*NL, Worklist);
1852  };
1853
1854  while (!Worklist.empty()) {
1855    Loop *L = Worklist.pop_back_val();
1856    Changed |= IRCE.run(L, LPMAddNewLoop);
1857  }
1858  return Changed;
1859}
1860
1861bool
1862InductiveRangeCheckElimination::isProfitableToTransform(const Loop &L,
1863                                                        LoopStructure &LS) {
1864  if (SkipProfitabilityChecks)
1865    return true;
1866  if (GetBFI) {
1867    BlockFrequencyInfo &BFI = (*GetBFI)();
1868    uint64_t hFreq = BFI.getBlockFreq(LS.Header).getFrequency();
1869    uint64_t phFreq = BFI.getBlockFreq(L.getLoopPreheader()).getFrequency();
1870    if (phFreq != 0 && hFreq != 0 && (hFreq / phFreq < MinRuntimeIterations)) {
1871      LLVM_DEBUG(dbgs() << "irce: could not prove profitability: "
1872                        << "the estimated number of iterations basing on "
1873                           "frequency info is " << (hFreq / phFreq) << "\n";);
1874      return false;
1875    }
1876    return true;
1877  }
1878
1879  if (!BPI)
1880    return true;
1881  BranchProbability ExitProbability =
1882      BPI->getEdgeProbability(LS.Latch, LS.LatchBrExitIdx);
1883  if (ExitProbability > BranchProbability(1, MinRuntimeIterations)) {
1884    LLVM_DEBUG(dbgs() << "irce: could not prove profitability: "
1885                      << "the exit probability is too big " << ExitProbability
1886                      << "\n";);
1887    return false;
1888  }
1889  return true;
1890}
1891
1892bool InductiveRangeCheckElimination::run(
1893    Loop *L, function_ref<void(Loop *, bool)> LPMAddNewLoop) {
1894  if (L->getBlocks().size() >= LoopSizeCutoff) {
1895    LLVM_DEBUG(dbgs() << "irce: giving up constraining loop, too large\n");
1896    return false;
1897  }
1898
1899  BasicBlock *Preheader = L->getLoopPreheader();
1900  if (!Preheader) {
1901    LLVM_DEBUG(dbgs() << "irce: loop has no preheader, leaving\n");
1902    return false;
1903  }
1904
1905  LLVMContext &Context = Preheader->getContext();
1906  SmallVector<InductiveRangeCheck, 16> RangeChecks;
1907
1908  for (auto *BBI : L->getBlocks())
1909    if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
1910      InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI,
1911                                                        RangeChecks);
1912
1913  if (RangeChecks.empty())
1914    return false;
1915
1916  auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
1917    OS << "irce: looking at loop "; L->print(OS);
1918    OS << "irce: loop has " << RangeChecks.size()
1919       << " inductive range checks: \n";
1920    for (InductiveRangeCheck &IRC : RangeChecks)
1921      IRC.print(OS);
1922  };
1923
1924  LLVM_DEBUG(PrintRecognizedRangeChecks(dbgs()));
1925
1926  if (PrintRangeChecks)
1927    PrintRecognizedRangeChecks(errs());
1928
1929  const char *FailureReason = nullptr;
1930  std::optional<LoopStructure> MaybeLoopStructure =
1931      LoopStructure::parseLoopStructure(SE, *L, FailureReason);
1932  if (!MaybeLoopStructure) {
1933    LLVM_DEBUG(dbgs() << "irce: could not parse loop structure: "
1934                      << FailureReason << "\n";);
1935    return false;
1936  }
1937  LoopStructure LS = *MaybeLoopStructure;
1938  if (!isProfitableToTransform(*L, LS))
1939    return false;
1940  const SCEVAddRecExpr *IndVar =
1941      cast<SCEVAddRecExpr>(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep)));
1942
1943  std::optional<InductiveRangeCheck::Range> SafeIterRange;
1944  Instruction *ExprInsertPt = Preheader->getTerminator();
1945
1946  SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate;
1947  // Basing on the type of latch predicate, we interpret the IV iteration range
1948  // as signed or unsigned range. We use different min/max functions (signed or
1949  // unsigned) when intersecting this range with safe iteration ranges implied
1950  // by range checks.
1951  auto IntersectRange =
1952      LS.IsSignedPredicate ? IntersectSignedRange : IntersectUnsignedRange;
1953
1954  IRBuilder<> B(ExprInsertPt);
1955  for (InductiveRangeCheck &IRC : RangeChecks) {
1956    auto Result = IRC.computeSafeIterationSpace(SE, IndVar,
1957                                                LS.IsSignedPredicate);
1958    if (Result) {
1959      auto MaybeSafeIterRange = IntersectRange(SE, SafeIterRange, *Result);
1960      if (MaybeSafeIterRange) {
1961        assert(!MaybeSafeIterRange->isEmpty(SE, LS.IsSignedPredicate) &&
1962               "We should never return empty ranges!");
1963        RangeChecksToEliminate.push_back(IRC);
1964        SafeIterRange = *MaybeSafeIterRange;
1965      }
1966    }
1967  }
1968
1969  if (!SafeIterRange)
1970    return false;
1971
1972  LoopConstrainer LC(*L, LI, LPMAddNewLoop, LS, SE, DT, *SafeIterRange);
1973  bool Changed = LC.run();
1974
1975  if (Changed) {
1976    auto PrintConstrainedLoopInfo = [L]() {
1977      dbgs() << "irce: in function ";
1978      dbgs() << L->getHeader()->getParent()->getName() << ": ";
1979      dbgs() << "constrained ";
1980      L->print(dbgs());
1981    };
1982
1983    LLVM_DEBUG(PrintConstrainedLoopInfo());
1984
1985    if (PrintChangedLoops)
1986      PrintConstrainedLoopInfo();
1987
1988    // Optimize away the now-redundant range checks.
1989
1990    for (InductiveRangeCheck &IRC : RangeChecksToEliminate) {
1991      ConstantInt *FoldedRangeCheck = IRC.getPassingDirection()
1992                                          ? ConstantInt::getTrue(Context)
1993                                          : ConstantInt::getFalse(Context);
1994      IRC.getCheckUse()->set(FoldedRangeCheck);
1995    }
1996  }
1997
1998  return Changed;
1999}
2000
2001Pass *llvm::createInductiveRangeCheckEliminationPass() {
2002  return new IRCELegacyPass();
2003}
2004