1//===- llvm/Analysis/ScalarEvolutionExpressions.h - SCEV Exprs --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the classes used to represent and build scalar expressions.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
14#define LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
15
16#include "llvm/ADT/DenseMap.h"
17#include "llvm/ADT/SmallPtrSet.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/Analysis/ScalarEvolution.h"
20#include "llvm/IR/Constants.h"
21#include "llvm/IR/ValueHandle.h"
22#include "llvm/Support/Casting.h"
23#include "llvm/Support/ErrorHandling.h"
24#include <cassert>
25#include <cstddef>
26
27namespace llvm {
28
29class APInt;
30class Constant;
31class ConstantInt;
32class ConstantRange;
33class Loop;
34class Type;
35class Value;
36
/// Discriminator stored in every SCEV node and returned by
/// SCEV::getSCEVType(). Each node class's classof() tests for one of these
/// values (or, for the abstract base classes, for a set of them).
///
/// The enumerator order is significant: kinds are listed from least to most
/// complex, as the comment inside the enum states.
/// NOTE(review): scPtrToInt is listed after the min/max kinds rather than next
/// to the other casts. Moving it would change the relative order of kinds, so
/// confirm how the ordering is consumed before touching it.
enum SCEVTypes : unsigned short {
  // These should be ordered in terms of increasing complexity to make the
  // folders simpler.
  scConstant,           // SCEVConstant: integer constant leaf.
  scVScale,             // SCEVVScale: vscale leaf.
  scTruncate,           // SCEVTruncateExpr.
  scZeroExtend,         // SCEVZeroExtendExpr.
  scSignExtend,         // SCEVSignExtendExpr.
  scAddExpr,            // SCEVAddExpr.
  scMulExpr,            // SCEVMulExpr.
  scUDivExpr,           // SCEVUDivExpr.
  scAddRecExpr,         // SCEVAddRecExpr.
  scUMaxExpr,           // SCEVUMaxExpr.
  scSMaxExpr,           // SCEVSMaxExpr.
  scUMinExpr,           // SCEVUMinExpr.
  scSMinExpr,           // SCEVSMinExpr.
  scSequentialUMinExpr, // SCEVSequentialUMinExpr.
  scPtrToInt,           // SCEVPtrToIntExpr.
  scUnknown,            // SCEVUnknown: opaque IR value leaf.
  scCouldNotCompute     // SCEVCouldNotCompute: visitors/traversals in this
                        // file treat reaching it as an error.
};
58
59/// This class represents a constant integer value.
60class SCEVConstant : public SCEV {
61  friend class ScalarEvolution;
62
63  ConstantInt *V;
64
65  SCEVConstant(const FoldingSetNodeIDRef ID, ConstantInt *v)
66      : SCEV(ID, scConstant, 1), V(v) {}
67
68public:
69  ConstantInt *getValue() const { return V; }
70  const APInt &getAPInt() const { return getValue()->getValue(); }
71
72  Type *getType() const { return V->getType(); }
73
74  /// Methods for support type inquiry through isa, cast, and dyn_cast:
75  static bool classof(const SCEV *S) { return S->getSCEVType() == scConstant; }
76};
77
78/// This class represents the value of vscale, as used when defining the length
79/// of a scalable vector or returned by the llvm.vscale() intrinsic.
80class SCEVVScale : public SCEV {
81  friend class ScalarEvolution;
82
83  SCEVVScale(const FoldingSetNodeIDRef ID, Type *ty)
84      : SCEV(ID, scVScale, 0), Ty(ty) {}
85
86  Type *Ty;
87
88public:
89  Type *getType() const { return Ty; }
90
91  /// Methods for support type inquiry through isa, cast, and dyn_cast:
92  static bool classof(const SCEV *S) { return S->getSCEVType() == scVScale; }
93};
94
95inline unsigned short computeExpressionSize(ArrayRef<const SCEV *> Args) {
96  APInt Size(16, 1);
97  for (const auto *Arg : Args)
98    Size = Size.uadd_sat(APInt(16, Arg->getExpressionSize()));
99  return (unsigned short)Size.getZExtValue();
100}
101
102/// This is the base class for unary cast operator classes.
103class SCEVCastExpr : public SCEV {
104protected:
105  const SCEV *Op;
106  Type *Ty;
107
108  SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy, const SCEV *op,
109               Type *ty);
110
111public:
112  const SCEV *getOperand() const { return Op; }
113  const SCEV *getOperand(unsigned i) const {
114    assert(i == 0 && "Operand index out of range!");
115    return Op;
116  }
117  ArrayRef<const SCEV *> operands() const { return Op; }
118  size_t getNumOperands() const { return 1; }
119  Type *getType() const { return Ty; }
120
121  /// Methods for support type inquiry through isa, cast, and dyn_cast:
122  static bool classof(const SCEV *S) {
123    return S->getSCEVType() == scPtrToInt || S->getSCEVType() == scTruncate ||
124           S->getSCEVType() == scZeroExtend || S->getSCEVType() == scSignExtend;
125  }
126};
127
128/// This class represents a cast from a pointer to a pointer-sized integer
129/// value.
130class SCEVPtrToIntExpr : public SCEVCastExpr {
131  friend class ScalarEvolution;
132
133  SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op, Type *ITy);
134
135public:
136  /// Methods for support type inquiry through isa, cast, and dyn_cast:
137  static bool classof(const SCEV *S) { return S->getSCEVType() == scPtrToInt; }
138};
139
140/// This is the base class for unary integral cast operator classes.
141class SCEVIntegralCastExpr : public SCEVCastExpr {
142protected:
143  SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
144                       const SCEV *op, Type *ty);
145
146public:
147  /// Methods for support type inquiry through isa, cast, and dyn_cast:
148  static bool classof(const SCEV *S) {
149    return S->getSCEVType() == scTruncate || S->getSCEVType() == scZeroExtend ||
150           S->getSCEVType() == scSignExtend;
151  }
152};
153
154/// This class represents a truncation of an integer value to a
155/// smaller integer value.
156class SCEVTruncateExpr : public SCEVIntegralCastExpr {
157  friend class ScalarEvolution;
158
159  SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
160
161public:
162  /// Methods for support type inquiry through isa, cast, and dyn_cast:
163  static bool classof(const SCEV *S) { return S->getSCEVType() == scTruncate; }
164};
165
166/// This class represents a zero extension of a small integer value
167/// to a larger integer value.
168class SCEVZeroExtendExpr : public SCEVIntegralCastExpr {
169  friend class ScalarEvolution;
170
171  SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
172
173public:
174  /// Methods for support type inquiry through isa, cast, and dyn_cast:
175  static bool classof(const SCEV *S) {
176    return S->getSCEVType() == scZeroExtend;
177  }
178};
179
180/// This class represents a sign extension of a small integer value
181/// to a larger integer value.
182class SCEVSignExtendExpr : public SCEVIntegralCastExpr {
183  friend class ScalarEvolution;
184
185  SCEVSignExtendExpr(const FoldingSetNodeIDRef ID, const SCEV *op, Type *ty);
186
187public:
188  /// Methods for support type inquiry through isa, cast, and dyn_cast:
189  static bool classof(const SCEV *S) {
190    return S->getSCEVType() == scSignExtend;
191  }
192};
193
194/// This node is a base class providing common functionality for
195/// n'ary operators.
196class SCEVNAryExpr : public SCEV {
197protected:
198  // Since SCEVs are immutable, ScalarEvolution allocates operand
199  // arrays with its SCEVAllocator, so this class just needs a simple
200  // pointer rather than a more elaborate vector-like data structure.
201  // This also avoids the need for a non-trivial destructor.
202  const SCEV *const *Operands;
203  size_t NumOperands;
204
205  SCEVNAryExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
206               const SCEV *const *O, size_t N)
207      : SCEV(ID, T, computeExpressionSize(ArrayRef(O, N))), Operands(O),
208        NumOperands(N) {}
209
210public:
211  size_t getNumOperands() const { return NumOperands; }
212
213  const SCEV *getOperand(unsigned i) const {
214    assert(i < NumOperands && "Operand index out of range!");
215    return Operands[i];
216  }
217
218  ArrayRef<const SCEV *> operands() const {
219    return ArrayRef(Operands, NumOperands);
220  }
221
222  NoWrapFlags getNoWrapFlags(NoWrapFlags Mask = NoWrapMask) const {
223    return (NoWrapFlags)(SubclassData & Mask);
224  }
225
226  bool hasNoUnsignedWrap() const {
227    return getNoWrapFlags(FlagNUW) != FlagAnyWrap;
228  }
229
230  bool hasNoSignedWrap() const {
231    return getNoWrapFlags(FlagNSW) != FlagAnyWrap;
232  }
233
234  bool hasNoSelfWrap() const { return getNoWrapFlags(FlagNW) != FlagAnyWrap; }
235
236  /// Methods for support type inquiry through isa, cast, and dyn_cast:
237  static bool classof(const SCEV *S) {
238    return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
239           S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
240           S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr ||
241           S->getSCEVType() == scSequentialUMinExpr ||
242           S->getSCEVType() == scAddRecExpr;
243  }
244};
245
246/// This node is the base class for n'ary commutative operators.
247class SCEVCommutativeExpr : public SCEVNAryExpr {
248protected:
249  SCEVCommutativeExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
250                      const SCEV *const *O, size_t N)
251      : SCEVNAryExpr(ID, T, O, N) {}
252
253public:
254  /// Methods for support type inquiry through isa, cast, and dyn_cast:
255  static bool classof(const SCEV *S) {
256    return S->getSCEVType() == scAddExpr || S->getSCEVType() == scMulExpr ||
257           S->getSCEVType() == scSMaxExpr || S->getSCEVType() == scUMaxExpr ||
258           S->getSCEVType() == scSMinExpr || S->getSCEVType() == scUMinExpr;
259  }
260
261  /// Set flags for a non-recurrence without clearing previously set flags.
262  void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
263};
264
265/// This node represents an addition of some number of SCEVs.
266class SCEVAddExpr : public SCEVCommutativeExpr {
267  friend class ScalarEvolution;
268
269  Type *Ty;
270
271  SCEVAddExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
272      : SCEVCommutativeExpr(ID, scAddExpr, O, N) {
273    auto *FirstPointerTypedOp = find_if(operands(), [](const SCEV *Op) {
274      return Op->getType()->isPointerTy();
275    });
276    if (FirstPointerTypedOp != operands().end())
277      Ty = (*FirstPointerTypedOp)->getType();
278    else
279      Ty = getOperand(0)->getType();
280  }
281
282public:
283  Type *getType() const { return Ty; }
284
285  /// Methods for support type inquiry through isa, cast, and dyn_cast:
286  static bool classof(const SCEV *S) { return S->getSCEVType() == scAddExpr; }
287};
288
289/// This node represents multiplication of some number of SCEVs.
290class SCEVMulExpr : public SCEVCommutativeExpr {
291  friend class ScalarEvolution;
292
293  SCEVMulExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
294      : SCEVCommutativeExpr(ID, scMulExpr, O, N) {}
295
296public:
297  Type *getType() const { return getOperand(0)->getType(); }
298
299  /// Methods for support type inquiry through isa, cast, and dyn_cast:
300  static bool classof(const SCEV *S) { return S->getSCEVType() == scMulExpr; }
301};
302
303/// This class represents a binary unsigned division operation.
304class SCEVUDivExpr : public SCEV {
305  friend class ScalarEvolution;
306
307  std::array<const SCEV *, 2> Operands;
308
309  SCEVUDivExpr(const FoldingSetNodeIDRef ID, const SCEV *lhs, const SCEV *rhs)
310      : SCEV(ID, scUDivExpr, computeExpressionSize({lhs, rhs})) {
311    Operands[0] = lhs;
312    Operands[1] = rhs;
313  }
314
315public:
316  const SCEV *getLHS() const { return Operands[0]; }
317  const SCEV *getRHS() const { return Operands[1]; }
318  size_t getNumOperands() const { return 2; }
319  const SCEV *getOperand(unsigned i) const {
320    assert((i == 0 || i == 1) && "Operand index out of range!");
321    return i == 0 ? getLHS() : getRHS();
322  }
323
324  ArrayRef<const SCEV *> operands() const { return Operands; }
325
326  Type *getType() const {
327    // In most cases the types of LHS and RHS will be the same, but in some
328    // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
329    // depend on the type for correctness, but handling types carefully can
330    // avoid extra casts in the SCEVExpander. The LHS is more likely to be
331    // a pointer type than the RHS, so use the RHS' type here.
332    return getRHS()->getType();
333  }
334
335  /// Methods for support type inquiry through isa, cast, and dyn_cast:
336  static bool classof(const SCEV *S) { return S->getSCEVType() == scUDivExpr; }
337};
338
339/// This node represents a polynomial recurrence on the trip count
340/// of the specified loop.  This is the primary focus of the
341/// ScalarEvolution framework; all the other SCEV subclasses are
342/// mostly just supporting infrastructure to allow SCEVAddRecExpr
343/// expressions to be created and analyzed.
344///
345/// All operands of an AddRec are required to be loop invariant.
346///
347class SCEVAddRecExpr : public SCEVNAryExpr {
348  friend class ScalarEvolution;
349
350  const Loop *L;
351
352  SCEVAddRecExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N,
353                 const Loop *l)
354      : SCEVNAryExpr(ID, scAddRecExpr, O, N), L(l) {}
355
356public:
357  Type *getType() const { return getStart()->getType(); }
358  const SCEV *getStart() const { return Operands[0]; }
359  const Loop *getLoop() const { return L; }
360
361  /// Constructs and returns the recurrence indicating how much this
362  /// expression steps by.  If this is a polynomial of degree N, it
363  /// returns a chrec of degree N-1.  We cannot determine whether
364  /// the step recurrence has self-wraparound.
365  const SCEV *getStepRecurrence(ScalarEvolution &SE) const {
366    if (isAffine())
367      return getOperand(1);
368    return SE.getAddRecExpr(
369        SmallVector<const SCEV *, 3>(operands().drop_front()), getLoop(),
370        FlagAnyWrap);
371  }
372
373  /// Return true if this represents an expression A + B*x where A
374  /// and B are loop invariant values.
375  bool isAffine() const {
376    // We know that the start value is invariant.  This expression is thus
377    // affine iff the step is also invariant.
378    return getNumOperands() == 2;
379  }
380
381  /// Return true if this represents an expression A + B*x + C*x^2
382  /// where A, B and C are loop invariant values.  This corresponds
383  /// to an addrec of the form {L,+,M,+,N}
384  bool isQuadratic() const { return getNumOperands() == 3; }
385
386  /// Set flags for a recurrence without clearing any previously set flags.
387  /// For AddRec, either NUW or NSW implies NW. Keep track of this fact here
388  /// to make it easier to propagate flags.
389  void setNoWrapFlags(NoWrapFlags Flags) {
390    if (Flags & (FlagNUW | FlagNSW))
391      Flags = ScalarEvolution::setFlags(Flags, FlagNW);
392    SubclassData |= Flags;
393  }
394
395  /// Return the value of this chain of recurrences at the specified
396  /// iteration number.
397  const SCEV *evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const;
398
399  /// Return the value of this chain of recurrences at the specified iteration
400  /// number. Takes an explicit list of operands to represent an AddRec.
401  static const SCEV *evaluateAtIteration(ArrayRef<const SCEV *> Operands,
402                                         const SCEV *It, ScalarEvolution &SE);
403
404  /// Return the number of iterations of this loop that produce
405  /// values in the specified constant range.  Another way of
406  /// looking at this is that it returns the first iteration number
407  /// where the value is not in the condition, thus computing the
408  /// exit count.  If the iteration count can't be computed, an
409  /// instance of SCEVCouldNotCompute is returned.
410  const SCEV *getNumIterationsInRange(const ConstantRange &Range,
411                                      ScalarEvolution &SE) const;
412
413  /// Return an expression representing the value of this expression
414  /// one iteration of the loop ahead.
415  const SCEVAddRecExpr *getPostIncExpr(ScalarEvolution &SE) const;
416
417  /// Methods for support type inquiry through isa, cast, and dyn_cast:
418  static bool classof(const SCEV *S) {
419    return S->getSCEVType() == scAddRecExpr;
420  }
421};
422
423/// This node is the base class min/max selections.
424class SCEVMinMaxExpr : public SCEVCommutativeExpr {
425  friend class ScalarEvolution;
426
427  static bool isMinMaxType(enum SCEVTypes T) {
428    return T == scSMaxExpr || T == scUMaxExpr || T == scSMinExpr ||
429           T == scUMinExpr;
430  }
431
432protected:
433  /// Note: Constructing subclasses via this constructor is allowed
434  SCEVMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
435                 const SCEV *const *O, size_t N)
436      : SCEVCommutativeExpr(ID, T, O, N) {
437    assert(isMinMaxType(T));
438    // Min and max never overflow
439    setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
440  }
441
442public:
443  Type *getType() const { return getOperand(0)->getType(); }
444
445  static bool classof(const SCEV *S) { return isMinMaxType(S->getSCEVType()); }
446
447  static enum SCEVTypes negate(enum SCEVTypes T) {
448    switch (T) {
449    case scSMaxExpr:
450      return scSMinExpr;
451    case scSMinExpr:
452      return scSMaxExpr;
453    case scUMaxExpr:
454      return scUMinExpr;
455    case scUMinExpr:
456      return scUMaxExpr;
457    default:
458      llvm_unreachable("Not a min or max SCEV type!");
459    }
460  }
461};
462
463/// This class represents a signed maximum selection.
464class SCEVSMaxExpr : public SCEVMinMaxExpr {
465  friend class ScalarEvolution;
466
467  SCEVSMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
468      : SCEVMinMaxExpr(ID, scSMaxExpr, O, N) {}
469
470public:
471  /// Methods for support type inquiry through isa, cast, and dyn_cast:
472  static bool classof(const SCEV *S) { return S->getSCEVType() == scSMaxExpr; }
473};
474
475/// This class represents an unsigned maximum selection.
476class SCEVUMaxExpr : public SCEVMinMaxExpr {
477  friend class ScalarEvolution;
478
479  SCEVUMaxExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
480      : SCEVMinMaxExpr(ID, scUMaxExpr, O, N) {}
481
482public:
483  /// Methods for support type inquiry through isa, cast, and dyn_cast:
484  static bool classof(const SCEV *S) { return S->getSCEVType() == scUMaxExpr; }
485};
486
487/// This class represents a signed minimum selection.
488class SCEVSMinExpr : public SCEVMinMaxExpr {
489  friend class ScalarEvolution;
490
491  SCEVSMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
492      : SCEVMinMaxExpr(ID, scSMinExpr, O, N) {}
493
494public:
495  /// Methods for support type inquiry through isa, cast, and dyn_cast:
496  static bool classof(const SCEV *S) { return S->getSCEVType() == scSMinExpr; }
497};
498
499/// This class represents an unsigned minimum selection.
500class SCEVUMinExpr : public SCEVMinMaxExpr {
501  friend class ScalarEvolution;
502
503  SCEVUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O, size_t N)
504      : SCEVMinMaxExpr(ID, scUMinExpr, O, N) {}
505
506public:
507  /// Methods for support type inquiry through isa, cast, and dyn_cast:
508  static bool classof(const SCEV *S) { return S->getSCEVType() == scUMinExpr; }
509};
510
511/// This node is the base class for sequential/in-order min/max selections.
512/// Note that their fundamental difference from SCEVMinMaxExpr's is that they
513/// are early-returning upon reaching saturation point.
514/// I.e. given `0 umin_seq poison`, the result will be `0`,
515/// while the result of `0 umin poison` is `poison`.
516class SCEVSequentialMinMaxExpr : public SCEVNAryExpr {
517  friend class ScalarEvolution;
518
519  static bool isSequentialMinMaxType(enum SCEVTypes T) {
520    return T == scSequentialUMinExpr;
521  }
522
523  /// Set flags for a non-recurrence without clearing previously set flags.
524  void setNoWrapFlags(NoWrapFlags Flags) { SubclassData |= Flags; }
525
526protected:
527  /// Note: Constructing subclasses via this constructor is allowed
528  SCEVSequentialMinMaxExpr(const FoldingSetNodeIDRef ID, enum SCEVTypes T,
529                           const SCEV *const *O, size_t N)
530      : SCEVNAryExpr(ID, T, O, N) {
531    assert(isSequentialMinMaxType(T));
532    // Min and max never overflow
533    setNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW));
534  }
535
536public:
537  Type *getType() const { return getOperand(0)->getType(); }
538
539  static SCEVTypes getEquivalentNonSequentialSCEVType(SCEVTypes Ty) {
540    assert(isSequentialMinMaxType(Ty));
541    switch (Ty) {
542    case scSequentialUMinExpr:
543      return scUMinExpr;
544    default:
545      llvm_unreachable("Not a sequential min/max type.");
546    }
547  }
548
549  SCEVTypes getEquivalentNonSequentialSCEVType() const {
550    return getEquivalentNonSequentialSCEVType(getSCEVType());
551  }
552
553  static bool classof(const SCEV *S) {
554    return isSequentialMinMaxType(S->getSCEVType());
555  }
556};
557
558/// This class represents a sequential/in-order unsigned minimum selection.
559class SCEVSequentialUMinExpr : public SCEVSequentialMinMaxExpr {
560  friend class ScalarEvolution;
561
562  SCEVSequentialUMinExpr(const FoldingSetNodeIDRef ID, const SCEV *const *O,
563                         size_t N)
564      : SCEVSequentialMinMaxExpr(ID, scSequentialUMinExpr, O, N) {}
565
566public:
567  /// Methods for support type inquiry through isa, cast, and dyn_cast:
568  static bool classof(const SCEV *S) {
569    return S->getSCEVType() == scSequentialUMinExpr;
570  }
571};
572
573/// This means that we are dealing with an entirely unknown SCEV
574/// value, and only represent it as its LLVM Value.  This is the
575/// "bottom" value for the analysis.
576class SCEVUnknown final : public SCEV, private CallbackVH {
577  friend class ScalarEvolution;
578
579  /// The parent ScalarEvolution value. This is used to update the
580  /// parent's maps when the value associated with a SCEVUnknown is
581  /// deleted or RAUW'd.
582  ScalarEvolution *SE;
583
584  /// The next pointer in the linked list of all SCEVUnknown
585  /// instances owned by a ScalarEvolution.
586  SCEVUnknown *Next;
587
588  SCEVUnknown(const FoldingSetNodeIDRef ID, Value *V, ScalarEvolution *se,
589              SCEVUnknown *next)
590      : SCEV(ID, scUnknown, 1), CallbackVH(V), SE(se), Next(next) {}
591
592  // Implement CallbackVH.
593  void deleted() override;
594  void allUsesReplacedWith(Value *New) override;
595
596public:
597  Value *getValue() const { return getValPtr(); }
598
599  Type *getType() const { return getValPtr()->getType(); }
600
601  /// Methods for support type inquiry through isa, cast, and dyn_cast:
602  static bool classof(const SCEV *S) { return S->getSCEVType() == scUnknown; }
603};
604
605/// This class defines a simple visitor class that may be used for
606/// various SCEV analysis purposes.
607template <typename SC, typename RetVal = void> struct SCEVVisitor {
608  RetVal visit(const SCEV *S) {
609    switch (S->getSCEVType()) {
610    case scConstant:
611      return ((SC *)this)->visitConstant((const SCEVConstant *)S);
612    case scVScale:
613      return ((SC *)this)->visitVScale((const SCEVVScale *)S);
614    case scPtrToInt:
615      return ((SC *)this)->visitPtrToIntExpr((const SCEVPtrToIntExpr *)S);
616    case scTruncate:
617      return ((SC *)this)->visitTruncateExpr((const SCEVTruncateExpr *)S);
618    case scZeroExtend:
619      return ((SC *)this)->visitZeroExtendExpr((const SCEVZeroExtendExpr *)S);
620    case scSignExtend:
621      return ((SC *)this)->visitSignExtendExpr((const SCEVSignExtendExpr *)S);
622    case scAddExpr:
623      return ((SC *)this)->visitAddExpr((const SCEVAddExpr *)S);
624    case scMulExpr:
625      return ((SC *)this)->visitMulExpr((const SCEVMulExpr *)S);
626    case scUDivExpr:
627      return ((SC *)this)->visitUDivExpr((const SCEVUDivExpr *)S);
628    case scAddRecExpr:
629      return ((SC *)this)->visitAddRecExpr((const SCEVAddRecExpr *)S);
630    case scSMaxExpr:
631      return ((SC *)this)->visitSMaxExpr((const SCEVSMaxExpr *)S);
632    case scUMaxExpr:
633      return ((SC *)this)->visitUMaxExpr((const SCEVUMaxExpr *)S);
634    case scSMinExpr:
635      return ((SC *)this)->visitSMinExpr((const SCEVSMinExpr *)S);
636    case scUMinExpr:
637      return ((SC *)this)->visitUMinExpr((const SCEVUMinExpr *)S);
638    case scSequentialUMinExpr:
639      return ((SC *)this)
640          ->visitSequentialUMinExpr((const SCEVSequentialUMinExpr *)S);
641    case scUnknown:
642      return ((SC *)this)->visitUnknown((const SCEVUnknown *)S);
643    case scCouldNotCompute:
644      return ((SC *)this)->visitCouldNotCompute((const SCEVCouldNotCompute *)S);
645    }
646    llvm_unreachable("Unknown SCEV kind!");
647  }
648
649  RetVal visitCouldNotCompute(const SCEVCouldNotCompute *S) {
650    llvm_unreachable("Invalid use of SCEVCouldNotCompute!");
651  }
652};
653
654/// Visit all nodes in the expression tree using worklist traversal.
655///
656/// Visitor implements:
657///   // return true to follow this node.
658///   bool follow(const SCEV *S);
659///   // return true to terminate the search.
660///   bool isDone();
661template <typename SV> class SCEVTraversal {
662  SV &Visitor;
663  SmallVector<const SCEV *, 8> Worklist;
664  SmallPtrSet<const SCEV *, 8> Visited;
665
666  void push(const SCEV *S) {
667    if (Visited.insert(S).second && Visitor.follow(S))
668      Worklist.push_back(S);
669  }
670
671public:
672  SCEVTraversal(SV &V) : Visitor(V) {}
673
674  void visitAll(const SCEV *Root) {
675    push(Root);
676    while (!Worklist.empty() && !Visitor.isDone()) {
677      const SCEV *S = Worklist.pop_back_val();
678
679      switch (S->getSCEVType()) {
680      case scConstant:
681      case scVScale:
682      case scUnknown:
683        continue;
684      case scPtrToInt:
685      case scTruncate:
686      case scZeroExtend:
687      case scSignExtend:
688      case scAddExpr:
689      case scMulExpr:
690      case scUDivExpr:
691      case scSMaxExpr:
692      case scUMaxExpr:
693      case scSMinExpr:
694      case scUMinExpr:
695      case scSequentialUMinExpr:
696      case scAddRecExpr:
697        for (const auto *Op : S->operands()) {
698          push(Op);
699          if (Visitor.isDone())
700            break;
701        }
702        continue;
703      case scCouldNotCompute:
704        llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
705      }
706      llvm_unreachable("Unknown SCEV kind!");
707    }
708  }
709};
710
711/// Use SCEVTraversal to visit all nodes in the given expression tree.
712template <typename SV> void visitAll(const SCEV *Root, SV &Visitor) {
713  SCEVTraversal<SV> T(Visitor);
714  T.visitAll(Root);
715}
716
717/// Return true if any node in \p Root satisfies the predicate \p Pred.
718template <typename PredTy>
719bool SCEVExprContains(const SCEV *Root, PredTy Pred) {
720  struct FindClosure {
721    bool Found = false;
722    PredTy Pred;
723
724    FindClosure(PredTy Pred) : Pred(Pred) {}
725
726    bool follow(const SCEV *S) {
727      if (!Pred(S))
728        return true;
729
730      Found = true;
731      return false;
732    }
733
734    bool isDone() const { return Found; }
735  };
736
737  FindClosure FC(Pred);
738  visitAll(Root, FC);
739  return FC.Found;
740}
741
742/// This visitor recursively visits a SCEV expression and re-writes it.
743/// The result from each visit is cached, so it will return the same
744/// SCEV for the same input.
745template <typename SC>
746class SCEVRewriteVisitor : public SCEVVisitor<SC, const SCEV *> {
747protected:
748  ScalarEvolution &SE;
749  // Memoize the result of each visit so that we only compute once for
750  // the same input SCEV. This is to avoid redundant computations when
751  // a SCEV is referenced by multiple SCEVs. Without memoization, this
752  // visit algorithm would have exponential time complexity in the worst
753  // case, causing the compiler to hang on certain tests.
754  SmallDenseMap<const SCEV *, const SCEV *> RewriteResults;
755
756public:
757  SCEVRewriteVisitor(ScalarEvolution &SE) : SE(SE) {}
758
759  const SCEV *visit(const SCEV *S) {
760    auto It = RewriteResults.find(S);
761    if (It != RewriteResults.end())
762      return It->second;
763    auto *Visited = SCEVVisitor<SC, const SCEV *>::visit(S);
764    auto Result = RewriteResults.try_emplace(S, Visited);
765    assert(Result.second && "Should insert a new entry");
766    return Result.first->second;
767  }
768
769  const SCEV *visitConstant(const SCEVConstant *Constant) { return Constant; }
770
771  const SCEV *visitVScale(const SCEVVScale *VScale) { return VScale; }
772
773  const SCEV *visitPtrToIntExpr(const SCEVPtrToIntExpr *Expr) {
774    const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
775    return Operand == Expr->getOperand()
776               ? Expr
777               : SE.getPtrToIntExpr(Operand, Expr->getType());
778  }
779
780  const SCEV *visitTruncateExpr(const SCEVTruncateExpr *Expr) {
781    const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
782    return Operand == Expr->getOperand()
783               ? Expr
784               : SE.getTruncateExpr(Operand, Expr->getType());
785  }
786
787  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
788    const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
789    return Operand == Expr->getOperand()
790               ? Expr
791               : SE.getZeroExtendExpr(Operand, Expr->getType());
792  }
793
794  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
795    const SCEV *Operand = ((SC *)this)->visit(Expr->getOperand());
796    return Operand == Expr->getOperand()
797               ? Expr
798               : SE.getSignExtendExpr(Operand, Expr->getType());
799  }
800
801  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
802    SmallVector<const SCEV *, 2> Operands;
803    bool Changed = false;
804    for (const auto *Op : Expr->operands()) {
805      Operands.push_back(((SC *)this)->visit(Op));
806      Changed |= Op != Operands.back();
807    }
808    return !Changed ? Expr : SE.getAddExpr(Operands);
809  }
810
811  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
812    SmallVector<const SCEV *, 2> Operands;
813    bool Changed = false;
814    for (const auto *Op : Expr->operands()) {
815      Operands.push_back(((SC *)this)->visit(Op));
816      Changed |= Op != Operands.back();
817    }
818    return !Changed ? Expr : SE.getMulExpr(Operands);
819  }
820
821  const SCEV *visitUDivExpr(const SCEVUDivExpr *Expr) {
822    auto *LHS = ((SC *)this)->visit(Expr->getLHS());
823    auto *RHS = ((SC *)this)->visit(Expr->getRHS());
824    bool Changed = LHS != Expr->getLHS() || RHS != Expr->getRHS();
825    return !Changed ? Expr : SE.getUDivExpr(LHS, RHS);
826  }
827
828  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
829    SmallVector<const SCEV *, 2> Operands;
830    bool Changed = false;
831    for (const auto *Op : Expr->operands()) {
832      Operands.push_back(((SC *)this)->visit(Op));
833      Changed |= Op != Operands.back();
834    }
835    return !Changed ? Expr
836                    : SE.getAddRecExpr(Operands, Expr->getLoop(),
837                                       Expr->getNoWrapFlags());
838  }
839
840  const SCEV *visitSMaxExpr(const SCEVSMaxExpr *Expr) {
841    SmallVector<const SCEV *, 2> Operands;
842    bool Changed = false;
843    for (const auto *Op : Expr->operands()) {
844      Operands.push_back(((SC *)this)->visit(Op));
845      Changed |= Op != Operands.back();
846    }
847    return !Changed ? Expr : SE.getSMaxExpr(Operands);
848  }
849
850  const SCEV *visitUMaxExpr(const SCEVUMaxExpr *Expr) {
851    SmallVector<const SCEV *, 2> Operands;
852    bool Changed = false;
853    for (const auto *Op : Expr->operands()) {
854      Operands.push_back(((SC *)this)->visit(Op));
855      Changed |= Op != Operands.back();
856    }
857    return !Changed ? Expr : SE.getUMaxExpr(Operands);
858  }
859
860  const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
861    SmallVector<const SCEV *, 2> Operands;
862    bool Changed = false;
863    for (const auto *Op : Expr->operands()) {
864      Operands.push_back(((SC *)this)->visit(Op));
865      Changed |= Op != Operands.back();
866    }
867    return !Changed ? Expr : SE.getSMinExpr(Operands);
868  }
869
870  const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
871    SmallVector<const SCEV *, 2> Operands;
872    bool Changed = false;
873    for (const auto *Op : Expr->operands()) {
874      Operands.push_back(((SC *)this)->visit(Op));
875      Changed |= Op != Operands.back();
876    }
877    return !Changed ? Expr : SE.getUMinExpr(Operands);
878  }
879
880  const SCEV *visitSequentialUMinExpr(const SCEVSequentialUMinExpr *Expr) {
881    SmallVector<const SCEV *, 2> Operands;
882    bool Changed = false;
883    for (const auto *Op : Expr->operands()) {
884      Operands.push_back(((SC *)this)->visit(Op));
885      Changed |= Op != Operands.back();
886    }
887    return !Changed ? Expr : SE.getUMinExpr(Operands, /*Sequential=*/true);
888  }
889
890  const SCEV *visitUnknown(const SCEVUnknown *Expr) { return Expr; }
891
892  const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
893    return Expr;
894  }
895};
896
897using ValueToValueMap = DenseMap<const Value *, Value *>;
898using ValueToSCEVMapTy = DenseMap<const Value *, const SCEV *>;
899
900/// The SCEVParameterRewriter takes a scalar evolution expression and updates
901/// the SCEVUnknown components following the Map (Value -> SCEV).
902class SCEVParameterRewriter : public SCEVRewriteVisitor<SCEVParameterRewriter> {
903public:
904  static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE,
905                             ValueToSCEVMapTy &Map) {
906    SCEVParameterRewriter Rewriter(SE, Map);
907    return Rewriter.visit(Scev);
908  }
909
910  SCEVParameterRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
911      : SCEVRewriteVisitor(SE), Map(M) {}
912
913  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
914    auto I = Map.find(Expr->getValue());
915    if (I == Map.end())
916      return Expr;
917    return I->second;
918  }
919
920private:
921  ValueToSCEVMapTy &Map;
922};
923
924using LoopToScevMapT = DenseMap<const Loop *, const SCEV *>;
925
926/// The SCEVLoopAddRecRewriter takes a scalar evolution expression and applies
927/// the Map (Loop -> SCEV) to all AddRecExprs.
928class SCEVLoopAddRecRewriter
929    : public SCEVRewriteVisitor<SCEVLoopAddRecRewriter> {
930public:
931  SCEVLoopAddRecRewriter(ScalarEvolution &SE, LoopToScevMapT &M)
932      : SCEVRewriteVisitor(SE), Map(M) {}
933
934  static const SCEV *rewrite(const SCEV *Scev, LoopToScevMapT &Map,
935                             ScalarEvolution &SE) {
936    SCEVLoopAddRecRewriter Rewriter(SE, Map);
937    return Rewriter.visit(Scev);
938  }
939
940  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
941    SmallVector<const SCEV *, 2> Operands;
942    for (const SCEV *Op : Expr->operands())
943      Operands.push_back(visit(Op));
944
945    const Loop *L = Expr->getLoop();
946    if (0 == Map.count(L))
947      return SE.getAddRecExpr(Operands, L, Expr->getNoWrapFlags());
948
949    return SCEVAddRecExpr::evaluateAtIteration(Operands, Map[L], SE);
950  }
951
952private:
953  LoopToScevMapT &Map;
954};
955
956} // end namespace llvm
957
958#endif // LLVM_ANALYSIS_SCALAREVOLUTIONEXPRESSIONS_H
959