ScalarEvolution.cpp revision 223017
//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library.  First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle. We only create one SCEV of a particular shape, so
// pointer-comparisons for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (ie, a PHI node).  If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression.  These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//
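//
// A quick illustration (added for exposition; the printed forms below are
// approximate): for a loop such as
//
//   for (int i = 0; i != n; ++i)
//     p[i] = 0;
//
// the induction variable i is represented by the recurrence {0,+,1}<%loop>
// (start 0, step 1 per iteration), the address &p[i] by {%p,+,4}<%loop> for a
// 4-byte element type, and the backedge-taken count comes out to (-1 + %n),
// which clients query via getBackedgeTakenCount().
//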

#define DEBUG_TYPE "scalar-evolution"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/GlobalVariable.h"
#include "llvm/GlobalAlias.h"
#include "llvm/Instructions.h"
#include "llvm/LLVMContext.h"
#include "llvm/Operator.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Assembly/Writer.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ConstantRange.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/GetElementPtrTypeIterator.h"
#include "llvm/Support/InstIterator.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <algorithm>
using namespace llvm;

STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                        cl::desc("Maximum number of iterations SCEV will "
                                 "symbolically execute a constant "
                                 "derived loop"),
                        cl::init(100));

INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution",
                "Scalar Evolution Analysis", false, true)
INITIALIZE_PASS_DEPENDENCY(LoopInfo)
INITIALIZE_PASS_DEPENDENCY(DominatorTree)
INITIALIZE_PASS_END(ScalarEvolution, "scalar-evolution",
                "Scalar Evolution Analysis", false, true)
char ScalarEvolution::ID = 0;

//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

void SCEV::dump() const {
  print(dbgs());
  dbgs() << '\n';
}

void SCEV::print(raw_ostream &OS) const {
  switch (getSCEVType()) {
  case scConstant:
    WriteAsOperand(OS, cast<SCEVConstant>(this)->getValue(), false);
    return;
  case scTruncate: {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
    const SCEV *Op = Trunc->getOperand();
    OS << "(trunc " << *Op->getType() << " " << *Op << " to "
       << *Trunc->getType() << ")";
    return;
  }
  case scZeroExtend: {
    const SCEVZeroExtendExpr *ZExt = cast<SCEVZeroExtendExpr>(this);
    const SCEV *Op = ZExt->getOperand();
    OS << "(zext " << *Op->getType() << " " << *Op << " to "
       << *ZExt->getType() << ")";
    return;
  }
  case scSignExtend: {
    const SCEVSignExtendExpr *SExt = cast<SCEVSignExtendExpr>(this);
    const SCEV *Op = SExt->getOperand();
    OS << "(sext " << *Op->getType() << " " << *Op << " to "
       << *SExt->getType() << ")";
    return;
  }
  case scAddRecExpr: {
    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(this);
    OS << "{" << *AR->getOperand(0);
    for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i)
      OS << ",+," << *AR->getOperand(i);
    OS << "}<";
    if (AR->getNoWrapFlags(FlagNUW))
      OS << "nuw><";
    if (AR->getNoWrapFlags(FlagNSW))
      OS << "nsw><";
    if (AR->getNoWrapFlags(FlagNW) &&
        !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW)))
      OS << "nw><";
    WriteAsOperand(OS, AR->getLoop()->getHeader(), /*PrintType=*/false);
    OS << ">";
    return;
  }
  case scAddExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr: {
    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(this);
    const char *OpStr = 0;
    switch (NAry->getSCEVType()) {
    case scAddExpr: OpStr = " + "; break;
    case scMulExpr: OpStr = " * "; break;
    case scUMaxExpr: OpStr = " umax "; break;
    case scSMaxExpr: OpStr = " smax "; break;
    }
    OS << "(";
    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
         I != E; ++I) {
      OS << **I;
      if (llvm::next(I) != E)
        OS << OpStr;
    }
    OS << ")";
    return;
  }
  case scUDivExpr: {
    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(this);
    OS << "(" << *UDiv->getLHS() << " /u " << *UDiv->getRHS() << ")";
    return;
  }
  case scUnknown: {
    const SCEVUnknown *U = cast<SCEVUnknown>(this);
    const Type *AllocTy;
    if (U->isSizeOf(AllocTy)) {
      OS << "sizeof(" << *AllocTy << ")";
      return;
    }
    if (U->isAlignOf(AllocTy)) {
      OS << "alignof(" << *AllocTy << ")";
      return;
    }

    const Type *CTy;
    Constant *FieldNo;
    if (U->isOffsetOf(CTy, FieldNo)) {
      OS << "offsetof(" << *CTy << ", ";
      WriteAsOperand(OS, FieldNo, false);
      OS << ")";
      return;
    }

    // Otherwise just print it normally.
    WriteAsOperand(OS, U->getValue(), false);
    return;
  }
  case scCouldNotCompute:
    OS << "***COULDNOTCOMPUTE***";
    return;
  default: break;
  }
  llvm_unreachable("Unknown SCEV kind!");
}

const Type *SCEV::getType() const {
  switch (getSCEVType()) {
  case scConstant:
    return cast<SCEVConstant>(this)->getType();
  case scTruncate:
  case scZeroExtend:
  case scSignExtend:
    return cast<SCEVCastExpr>(this)->getType();
  case scAddRecExpr:
  case scMulExpr:
  case scUMaxExpr:
  case scSMaxExpr:
    return cast<SCEVNAryExpr>(this)->getType();
  case scAddExpr:
    return cast<SCEVAddExpr>(this)->getType();
  case scUDivExpr:
    return cast<SCEVUDivExpr>(this)->getType();
  case scUnknown:
    return cast<SCEVUnknown>(this)->getType();
  case scCouldNotCompute:
    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
    return 0;
  default: break;
  }
  llvm_unreachable("Unknown SCEV kind!");
  return 0;
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isAllOnesValue();
  return false;
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(FoldingSetNodeIDRef(), scCouldNotCompute) {}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV *ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVConstant(ID.Intern(SCEVAllocator), V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getConstant(const APInt& Val) {
  return getConstant(ConstantInt::get(getContext(), Val));
}

const SCEV *
ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) {
  const IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
  return getConstant(ConstantInt::get(ITy, V, isSigned));
}

SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
                           unsigned SCEVTy, const SCEV *op, const Type *ty)
  : SCEV(ID, SCEVTy), Op(op), Ty(ty) {}

SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
                                   const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scTruncate, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot truncate non-integer value!");
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scZeroExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot zero extend non-integer value!");
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                       const SCEV *op, const Type *ty)
  : SCEVCastExpr(ID, scSignExtend, op, ty) {
  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
         "Cannot sign extend non-integer value!");
}

void SCEVUnknown::deleted() {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Release the value.
  setValPtr(0);
}

void SCEVUnknown::allUsesReplacedWith(Value *New) {
  // Clear this SCEVUnknown from various maps.
  SE->forgetMemoizedResults(this);

  // Remove this SCEVUnknown from the uniquing map.
  SE->UniqueSCEVs.RemoveNode(this);

  // Update this SCEVUnknown to point to the new value. This is needed
  // because there may still be outstanding SCEVs which still point to
  // this SCEVUnknown.
  setValPtr(New);
}

bool SCEVUnknown::isSizeOf(const Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue() &&
            CE->getNumOperands() == 2)
          if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(1)))
            if (CI->isOne()) {
              AllocTy = cast<PointerType>(CE->getOperand(0)->getType())
                                 ->getElementType();
              return true;
            }

  return false;
}

bool SCEVUnknown::isAlignOf(const Type *&AllocTy) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getOperand(0)->isNullValue()) {
          const Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          if (const StructType *STy = dyn_cast<StructType>(Ty))
            if (!STy->isPacked() &&
                CE->getNumOperands() == 3 &&
                CE->getOperand(1)->isNullValue()) {
              if (ConstantInt *CI = dyn_cast<ConstantInt>(CE->getOperand(2)))
                if (CI->isOne() &&
                    STy->getNumElements() == 2 &&
                    STy->getElementType(0)->isIntegerTy(1)) {
                  AllocTy = STy->getElementType(1);
                  return true;
                }
            }
        }

  return false;
}

bool SCEVUnknown::isOffsetOf(const Type *&CTy, Constant *&FieldNo) const {
  if (ConstantExpr *VCE = dyn_cast<ConstantExpr>(getValue()))
    if (VCE->getOpcode() == Instruction::PtrToInt)
      if (ConstantExpr *CE = dyn_cast<ConstantExpr>(VCE->getOperand(0)))
        if (CE->getOpcode() == Instruction::GetElementPtr &&
            CE->getNumOperands() == 3 &&
            CE->getOperand(0)->isNullValue() &&
            CE->getOperand(1)->isNullValue()) {
          const Type *Ty =
            cast<PointerType>(CE->getOperand(0)->getType())->getElementType();
          // Ignore vector types here so that ScalarEvolutionExpander doesn't
          // emit getelementptrs that index into vectors.
          if (Ty->isStructTy() || Ty->isArrayTy()) {
            CTy = Ty;
            FieldNo = CE->getOperand(2);
            return true;
          }
        }

  return false;
}

//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

namespace {
  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
  /// than the complexity of the RHS.  This comparator is used to canonicalize
  /// expressions.
  class SCEVComplexityCompare {
    const LoopInfo *const LI;
  public:
    explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {}

    // Return true if LHS is less than RHS, false otherwise.
    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
      return compare(LHS, RHS) < 0;
    }

    // Return negative, zero, or positive, if LHS is less than, equal to, or
    // greater than RHS, respectively. A three-way result allows recursive
    // comparisons to be more efficient.
    int compare(const SCEV *LHS, const SCEV *RHS) const {
      // Fast-path: SCEVs are uniqued so we can do a quick equality check.
      if (LHS == RHS)
        return 0;

      // Primarily, sort the SCEVs by their getSCEVType().
      unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
      if (LType != RType)
        return (int)LType - (int)RType;

      // Aside from the getSCEVType() ordering, the particular ordering
      // isn't very important except that it's beneficial to be consistent,
      // so that (a + b) and (b + a) don't end up as different expressions.
      switch (LType) {
      case scUnknown: {
        const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
        const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

        // Sort SCEVUnknown values with some loose heuristics. TODO: This is
        // not as complete as it could be.
        const Value *LV = LU->getValue(), *RV = RU->getValue();

        // Order pointer values after integer values. This helps SCEVExpander
        // form GEPs.
        bool LIsPointer = LV->getType()->isPointerTy(),
             RIsPointer = RV->getType()->isPointerTy();
        if (LIsPointer != RIsPointer)
          return (int)LIsPointer - (int)RIsPointer;

        // Compare getValueID values.
        unsigned LID = LV->getValueID(),
                 RID = RV->getValueID();
        if (LID != RID)
          return (int)LID - (int)RID;

        // Sort arguments by their position.
        if (const Argument *LA = dyn_cast<Argument>(LV)) {
          const Argument *RA = cast<Argument>(RV);
          unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo();
          return (int)LArgNo - (int)RArgNo;
        }

        // For instructions, compare their loop depth, and their operand
        // count.  This is pretty loose.
        if (const Instruction *LInst = dyn_cast<Instruction>(LV)) {
          const Instruction *RInst = cast<Instruction>(RV);

          // Compare loop depths.
          const BasicBlock *LParent = LInst->getParent(),
                           *RParent = RInst->getParent();
          if (LParent != RParent) {
            unsigned LDepth = LI->getLoopDepth(LParent),
                     RDepth = LI->getLoopDepth(RParent);
            if (LDepth != RDepth)
              return (int)LDepth - (int)RDepth;
          }

          // Compare the number of operands.
          unsigned LNumOps = LInst->getNumOperands(),
                   RNumOps = RInst->getNumOperands();
          return (int)LNumOps - (int)RNumOps;
        }

        return 0;
      }

      case scConstant: {
        const SCEVConstant *LC = cast<SCEVConstant>(LHS);
        const SCEVConstant *RC = cast<SCEVConstant>(RHS);

        // Compare constant values.
        const APInt &LA = LC->getValue()->getValue();
        const APInt &RA = RC->getValue()->getValue();
        unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth();
        if (LBitWidth != RBitWidth)
          return (int)LBitWidth - (int)RBitWidth;
        return LA.ult(RA) ? -1 : 1;
      }

      case scAddRecExpr: {
        const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS);
        const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);

        // Compare addrec loop depths.
        const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop();
        if (LLoop != RLoop) {
          unsigned LDepth = LLoop->getLoopDepth(),
                   RDepth = RLoop->getLoopDepth();
          if (LDepth != RDepth)
            return (int)LDepth - (int)RDepth;
        }

        // Addrec complexity grows with operand count.
        unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands();
        if (LNumOps != RNumOps)
          return (int)LNumOps - (int)RNumOps;

        // Lexicographically compare.
        for (unsigned i = 0; i != LNumOps; ++i) {
          long X = compare(LA->getOperand(i), RA->getOperand(i));
          if (X != 0)
            return X;
        }

        return 0;
      }

      case scAddExpr:
      case scMulExpr:
      case scSMaxExpr:
      case scUMaxExpr: {
        const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS);
        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);

        // Lexicographically compare n-ary expressions.
        unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands();
        for (unsigned i = 0; i != LNumOps; ++i) {
          if (i >= RNumOps)
            return 1;
          long X = compare(LC->getOperand(i), RC->getOperand(i));
          if (X != 0)
            return X;
        }
        return (int)LNumOps - (int)RNumOps;
      }

      case scUDivExpr: {
        const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS);
        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);

        // Lexicographically compare udiv expressions.
        long X = compare(LC->getLHS(), RC->getLHS());
        if (X != 0)
          return X;
        return compare(LC->getRHS(), RC->getRHS());
      }

      case scTruncate:
      case scZeroExtend:
      case scSignExtend: {
        const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS);
        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);

        // Compare cast expressions by operand.
        return compare(LC->getOperand(), RC->getOperand());
      }

      default:
        break;
      }

      llvm_unreachable("Unknown SCEV kind!");
      return 0;
    }
  };
}

/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
/// When this routine is finished, we know that any duplicates in the vector are
/// consecutive and that complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
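///
/// For example (illustrative), given the operand list (%a, %b, %a), sorting
/// and grouping leaves the two %a operands adjacent, e.g. (%a, %a, %b), so
/// the callers below can find duplicates with a single linear scan.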
///
static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                              LoopInfo *LI) {
  if (Ops.size() < 2) return;  // Noop
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
    if (SCEVComplexityCompare(LI)(RHS, LHS))
      std::swap(LHS, RHS);
    return;
  }

  // Do the rough sort by complexity.
  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice.  Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}



//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
/// Assume, K > 0.
static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
                                       ScalarEvolution &SE,
                                       const Type* ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose, W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must assure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula requires less than W + K bits. Also, the first formula
  // requires a division step, whereas this formula only requires multiplies
  // and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
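  //
  // A concrete walk-through (illustrative) for K = 3 and W = 32: K! = 6 =
  // 2^1 * 3, so T = 1 and the odd factor is 3. The multiplicative inverse of
  // 3 modulo 2^32 is 0xAAAAAAAB (3 * 0xAAAAAAAB == 1 (mod 2^32)). The code
  // below therefore forms It*(It-1)*(It-2) at width W+T = 33 bits, divides by
  // 2^T = 2, truncates back to 32 bits, and multiplies by 0xAAAAAAAB to
  // perform the exact division by 3.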

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult = Mult.lshr(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt(CalculationBits, 1).shl(T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  const IntegerType *CalculationTy = IntegerType::get(SE.getContext(),
                                                      CalculationBits);
  const SCEV *Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV *S = SE.getMinusSCEV(It, SE.getConstant(It->getType(), i));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV *DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// evaluateAtIteration - Return the value of this chain of recurrences at
/// the specified iteration number.  We can evaluate this recurrence by
/// multiplying each element in the chain by the binomial coefficient
/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
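///
/// For instance (illustrative), the chrec {1,+,2,+,2} evaluated at iteration
/// n is 1*BC(n,0) + 2*BC(n,1) + 2*BC(n,2) = 1 + 2n + n(n-1) = n^2 + n + 1,
/// matching the sequence 1, 3, 7, 13, ... that the recurrence produces.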
///
const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                ScalarEvolution &SE) const {
  const SCEV *Result = getStart();
  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV *Coeff = BinomialCoefficient(It, i, SE, getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
  }
  return Result;
}

//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
                                             const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
      cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(),
                                               getEffectiveSCEVType(Ty))));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);

  // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
  // eliminate all the truncates.
  if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    bool hasTrunc = false;
    for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
      const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
      hasTrunc = isa<SCEVTruncateExpr>(S);
      Operands.push_back(S);
    }
    if (!hasTrunc)
      return getAddExpr(Operands);
    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
  }

  // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
  // eliminate all the truncates.
  if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    bool hasTrunc = false;
    for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
      const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
      hasTrunc = isa<SCEVTruncateExpr>(S);
      Operands.push_back(S);
    }
    if (!hasTrunc)
      return getMulExpr(Operands);
    UniqueSCEVs.FindNodeOrInsertPos(ID, IP);  // Mutates IP, returns NULL.
  }

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV *, 4> Operands;
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
      Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
    return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
  }

  // As a special case, fold trunc(undef) to undef. We don't want to
  // know too much about SCEVUnknowns, but this special case is handy
  // and harmless.
  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
    if (isa<UndefValue>(U->getValue()))
      return getSCEV(UndefValue::get(Ty));

  // The cast wasn't folded; create an explicit cast node. We can reuse
  // the existing insert position since if we get here, we won't have
  // made any changes which would invalidate it.
  SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                 Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op,
                                               const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
      cast<ConstantInt>(ConstantExpr::getZExt(SC->getValue(),
                                              getEffectiveSCEVType(Ty))));

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty);

  // Before doing any expensive analysis, check to see if we've already
  // computed a SCEV for this Op and Ty.
  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;

  // zext(trunc(x)) --> zext(x) or x or trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
    // It's possible the bits taken off by the truncate were all zero bits. If
    // so, we should be able to simplify this further.
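    // For example (illustrative), zext i32 (trunc i64 %x to i32) to i64
    // folds back to %x when the unsigned range of %x already fits in 32 bits.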
    const SCEV *X = ST->getOperand();
    ConstantRange CR = getUnsignedRange(X);
    unsigned TruncBits = getTypeSizeInBits(ST->getType());
    unsigned NewBits = getTypeSizeInBits(Ty);
    if (CR.truncate(TruncBits).zeroExtend(NewBits).contains(
            CR.zextOrTrunc(NewBits)))
      return getTruncateOrZeroExtend(X, Ty);
  }

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      const SCEV *Start = AR->getStart();
      const SCEV *Step = AR->getStepRecurrence(*this);
      unsigned BitWidth = getTypeSizeInBits(AR->getType());
      const Loop *L = AR->getLoop();

      // If we have special knowledge that this addrec won't overflow,
      // we don't need to do any further analysis.
      if (AR->getNoWrapFlags(SCEV::FlagNUW))
        return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                             getZeroExtendExpr(Step, Ty),
                             L, AR->getNoWrapFlags());

      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.

        // Check whether the backedge-taken count can be losslessly casted to
        // the addrec's type. The count is always unsigned.
        const SCEV *CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV *RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          const SCEV *ZMul = getMulExpr(CastedMaxBECount, Step);
          const SCEV *Add = getAddExpr(Start, ZMul);
          const SCEV *OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getZeroExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) {
            // Cache knowledge of AR NUW, which is propagated to this AddRec.
            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 L, AR->getNoWrapFlags());
          }
          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down.
          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
          Add = getAddExpr(Start, SMul);
          OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd) {
            // Cache knowledge of AR NW, which is propagated to this AddRec.
            // Negative step causes unsigned wrap, but it still can't self-wrap.
            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L, AR->getNoWrapFlags());
          }
        }

        // If the backedge is guarded by a comparison with the pre-inc value
        // the addrec is safe. Also, if the entry is guarded by a comparison
        // with the start value and the backedge is guarded by a comparison
        // with the post-inc value, the addrec is safe.
        if (isKnownPositive(Step)) {
          const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
                                      getUnsignedRange(Step).getUnsignedMax());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
                                           AR->getPostIncExpr(*this), N))) {
            // Cache knowledge of AR NUW, which is propagated to this AddRec.
            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 L, AR->getNoWrapFlags());
          }
        } else if (isKnownNegative(Step)) {
          const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
                                      getSignedRange(Step).getSignedMin());
          if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
              (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
               isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
                                           AR->getPostIncExpr(*this), N))) {
            // Cache knowledge of AR NW, which is propagated to this AddRec.
            // Negative step causes unsigned wrap, but it still can't self-wrap.
            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 L, AR->getNoWrapFlags());
          }
        }
      }
    }

  // The cast wasn't folded; create an explicit cast node.
  // Recompute the insert position, as it may have been invalidated.
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                   Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

// Get the limit of a recurrence such that incrementing by Step cannot cause
// signed overflow as long as the value of the recurrence within the loop does
// not exceed this limit before incrementing.
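// For example (illustrative), for an i8 recurrence whose step is known to lie
// in [1, 3], the limit is SINT_MIN - 3, which wraps to 125, with predicate
// SLT: while the recurrence value stays below 125, adding at most 3 cannot
// pass 127, so no signed overflow occurs.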
static const SCEV *getOverflowLimitForStep(const SCEV *Step,
                                           ICmpInst::Predicate *Pred,
                                           ScalarEvolution *SE) {
  unsigned BitWidth = SE->getTypeSizeInBits(Step->getType());
  if (SE->isKnownPositive(Step)) {
    *Pred = ICmpInst::ICMP_SLT;
    return SE->getConstant(APInt::getSignedMinValue(BitWidth) -
                           SE->getSignedRange(Step).getSignedMax());
  }
  if (SE->isKnownNegative(Step)) {
    *Pred = ICmpInst::ICMP_SGT;
    return SE->getConstant(APInt::getSignedMaxValue(BitWidth) -
                           SE->getSignedRange(Step).getSignedMin());
  }
  return 0;
}

// The recurrence AR has been shown to have no signed wrap. Typically, if we can
// prove NSW for AR, then we can just as easily prove NSW for its preincrement
// or postincrement sibling. This allows normalizing a sign extended AddRec as
1061223017Sdim// such: {sext(Step + Start),+,Step} => {(Step + sext(Start),+,Step} As a
1062223017Sdim// result, the expression "Step + sext(PreIncAR)" is congruent with
1063223017Sdim// "sext(PostIncAR)"
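// Concretely (an added illustration in the chrec notation above): if
// AR = {X + Step,+,Step}<L> and the pre-increment recurrence {X,+,Step}<L>
// can be shown not to sign-wrap, then the sign-extended form of AR may use
// "sext(Step) + sext(X)" as its start in place of "sext(X + Step)", keeping
// the widened pre- and post-increment recurrences congruent.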
1064223017Sdimstatic const SCEV *getPreStartForSignExtend(const SCEVAddRecExpr *AR,
1065223017Sdim                                            const Type *Ty,
1066223017Sdim                                            ScalarEvolution *SE) {
1067223017Sdim  const Loop *L = AR->getLoop();
1068223017Sdim  const SCEV *Start = AR->getStart();
1069223017Sdim  const SCEV *Step = AR->getStepRecurrence(*SE);
1070223017Sdim
1071223017Sdim  // Check for a simple looking step prior to loop entry.
1072223017Sdim  const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Start);
1073223017Sdim  if (!SA || SA->getNumOperands() != 2 || SA->getOperand(0) != Step)
1074223017Sdim    return 0;
1075223017Sdim
1076223017Sdim  // This is a postinc AR. Check for overflow on the preinc recurrence using the
1077223017Sdim  // same three conditions that getSignExtendExpr checks.
1078223017Sdim
1079223017Sdim  // 1. NSW flags on the step increment.
1080223017Sdim  const SCEV *PreStart = SA->getOperand(1);
1081223017Sdim  const SCEVAddRecExpr *PreAR = dyn_cast<SCEVAddRecExpr>(
1082223017Sdim    SE->getAddRecExpr(PreStart, Step, L, SCEV::FlagAnyWrap));
1083223017Sdim
1084223017Sdim  if (PreAR && PreAR->getNoWrapFlags(SCEV::FlagNSW))
1085223017Sdim    return PreStart;
1086223017Sdim
1087223017Sdim  // 2. Direct overflow check on the step operation's expression.
1088223017Sdim  unsigned BitWidth = SE->getTypeSizeInBits(AR->getType());
1089223017Sdim  const Type *WideTy = IntegerType::get(SE->getContext(), BitWidth * 2);
1090223017Sdim  const SCEV *OperandExtendedStart =
1091223017Sdim    SE->getAddExpr(SE->getSignExtendExpr(PreStart, WideTy),
1092223017Sdim                   SE->getSignExtendExpr(Step, WideTy));
1093223017Sdim  if (SE->getSignExtendExpr(Start, WideTy) == OperandExtendedStart) {
1094223017Sdim    // Cache knowledge of PreAR NSW.
1095223017Sdim    if (PreAR)
1096223017Sdim      const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(SCEV::FlagNSW);
1097223017Sdim    // FIXME: this optimization needs a unit test
1098223017Sdim    DEBUG(dbgs() << "SCEV: untested prestart overflow check\n");
1099223017Sdim    return PreStart;
1100223017Sdim  }
1101223017Sdim
1102223017Sdim  // 3. Loop precondition.
1103223017Sdim  ICmpInst::Predicate Pred;
1104223017Sdim  const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, SE);
1105223017Sdim
1106223017Sdim  if (OverflowLimit &&
1107223017Sdim      SE->isLoopEntryGuardedByCond(L, Pred, PreStart, OverflowLimit)) {
1108223017Sdim    return PreStart;
1109223017Sdim  }
1110223017Sdim  return 0;
1111223017Sdim}
1112223017Sdim
1113223017Sdim// Get the normalized sign-extended expression for this AddRec's Start.
1114223017Sdimstatic const SCEV *getSignExtendAddRecStart(const SCEVAddRecExpr *AR,
1115223017Sdim                                            const Type *Ty,
1116223017Sdim                                            ScalarEvolution *SE) {
1117223017Sdim  const SCEV *PreStart = getPreStartForSignExtend(AR, Ty, SE);
1118223017Sdim  if (!PreStart)
1119223017Sdim    return SE->getSignExtendExpr(AR->getStart(), Ty);
1120223017Sdim
1121223017Sdim  return SE->getAddExpr(SE->getSignExtendExpr(AR->getStepRecurrence(*SE), Ty),
1122223017Sdim                        SE->getSignExtendExpr(PreStart, Ty));
1123223017Sdim}
1124223017Sdim
1125198090Srdivackyconst SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op,
1126198090Srdivacky                                               const Type *Ty) {
1127193323Sed  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1128193323Sed         "This is not an extending conversion!");
1129193323Sed  assert(isSCEVable(Ty) &&
1130193323Sed         "This is not a conversion to a SCEVable type!");
1131193323Sed  Ty = getEffectiveSCEVType(Ty);
1132193323Sed
1133195340Sed  // Fold if the operand is constant.
1134210299Sed  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1135210299Sed    return getConstant(
1136210299Sed      cast<ConstantInt>(ConstantExpr::getSExt(SC->getValue(),
1137210299Sed                                              getEffectiveSCEVType(Ty))));
1138193323Sed
1139193323Sed  // sext(sext(x)) --> sext(x)
1140193323Sed  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
1141193323Sed    return getSignExtendExpr(SS->getOperand(), Ty);
1142193323Sed
1143218893Sdim  // sext(zext(x)) --> zext(x)
1144218893Sdim  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
1145218893Sdim    return getZeroExtendExpr(SZ->getOperand(), Ty);
1146218893Sdim
1147198090Srdivacky  // Before doing any expensive analysis, check to see if we've already
1148198090Srdivacky  // computed a SCEV for this Op and Ty.
1149198090Srdivacky  FoldingSetNodeID ID;
1150198090Srdivacky  ID.AddInteger(scSignExtend);
1151198090Srdivacky  ID.AddPointer(Op);
1152198090Srdivacky  ID.AddPointer(Ty);
1153198090Srdivacky  void *IP = 0;
1154198090Srdivacky  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1155198090Srdivacky
1156218893Sdim  // If the input value is provably non-negative, build a zext instead.
1157218893Sdim  if (isKnownNonNegative(Op))
1158218893Sdim    return getZeroExtendExpr(Op, Ty);
1159218893Sdim
1160218893Sdim  // sext(trunc(x)) --> sext(x) or x or trunc(x)
1161218893Sdim  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) {
1162218893Sdim    // It's possible the bits taken off by the truncate were all sign bits. If
1163218893Sdim    // so, we should be able to simplify this further.
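    // For instance (an added illustration): if x is an i32 whose signed range
    // is known to be [-100, 100], then sext_i64(trunc_i16(x)) equals
    // sext_i64(x), because the truncation only discards sign bits; the range
    // check below detects exactly this situation.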
1164218893Sdim    const SCEV *X = ST->getOperand();
1165218893Sdim    ConstantRange CR = getSignedRange(X);
1166218893Sdim    unsigned TruncBits = getTypeSizeInBits(ST->getType());
1167218893Sdim    unsigned NewBits = getTypeSizeInBits(Ty);
1168218893Sdim    if (CR.truncate(TruncBits).signExtend(NewBits).contains(
1169218893Sdim            CR.sextOrTrunc(NewBits)))
1170218893Sdim      return getTruncateOrSignExtend(X, Ty);
1171218893Sdim  }
1172218893Sdim
1173193323Sed  // If the input value is a chrec scev, and we can prove that the value
1174193323Sed  // did not overflow the old, smaller, value, we can sign extend all of the
1175193323Sed  // operands (often constants).  This allows analysis of something like
1176193323Sed  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
1177193323Sed  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
1178193323Sed    if (AR->isAffine()) {
1179198090Srdivacky      const SCEV *Start = AR->getStart();
1180198090Srdivacky      const SCEV *Step = AR->getStepRecurrence(*this);
1181198090Srdivacky      unsigned BitWidth = getTypeSizeInBits(AR->getType());
1182198090Srdivacky      const Loop *L = AR->getLoop();
1183198090Srdivacky
1184198090Srdivacky      // If we have special knowledge that this addrec won't overflow,
1185198090Srdivacky      // we don't need to do any further analysis.
1186221345Sdim      if (AR->getNoWrapFlags(SCEV::FlagNSW))
1187223017Sdim        return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1188198090Srdivacky                             getSignExtendExpr(Step, Ty),
1189221345Sdim                             L, SCEV::FlagNSW);
1190198090Srdivacky
1191193323Sed      // Check whether the backedge-taken count is SCEVCouldNotCompute.
1192193323Sed      // Note that this serves two purposes: It filters out loops that are
1193193323Sed      // simply not analyzable, and it covers the case where this code is
1194193323Sed      // being called from within backedge-taken count analysis, such that
1195193323Sed      // attempting to ask for the backedge-taken count would likely result
1196193323Sed      // in infinite recursion. In the latter case, the analysis code will
1197193323Sed      // cope with a conservative value, and it will take care to purge
1198193323Sed      // that value once it has finished.
1199198090Srdivacky      const SCEV *MaxBECount = getMaxBackedgeTakenCount(L);
1200193323Sed      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
1201193323Sed        // Manually compute the final value for AR, checking for
1202193323Sed        // overflow.
1203193323Sed
1204193323Sed        // Check whether the backedge-taken count can be losslessly casted to
1205193323Sed        // the addrec's type. The count is always unsigned.
1206198090Srdivacky        const SCEV *CastedMaxBECount =
1207193323Sed          getTruncateOrZeroExtend(MaxBECount, Start->getType());
1208198090Srdivacky        const SCEV *RecastedMaxBECount =
1209193323Sed          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
1210193323Sed        if (MaxBECount == RecastedMaxBECount) {
1211198090Srdivacky          const Type *WideTy = IntegerType::get(getContext(), BitWidth * 2);
1212193323Sed          // Check whether Start+Step*MaxBECount has no signed overflow.
1213204642Srdivacky          const SCEV *SMul = getMulExpr(CastedMaxBECount, Step);
1214198090Srdivacky          const SCEV *Add = getAddExpr(Start, SMul);
1215198090Srdivacky          const SCEV *OperandExtendedAdd =
1216193323Sed            getAddExpr(getSignExtendExpr(Start, WideTy),
1217193323Sed                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1218193323Sed                                  getSignExtendExpr(Step, WideTy)));
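          // Added commentary: if evaluating Start + MaxBECount*Step in the
          // narrow type could wrap, its sign-extension would not be provably
          // equal to the same sum computed in twice the bit width with the
          // operands extended first; proving equality therefore shows the
          // narrow evaluation cannot sign-overflow.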
1219221345Sdim          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1220221345Sdim            // Cache knowledge of AR NSW, which is propagated to this AddRec.
1221221345Sdim            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1222193323Sed            // Return the expression with the addrec on the outside.
1223223017Sdim            return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1224193323Sed                                 getSignExtendExpr(Step, Ty),
1225221345Sdim                                 L, AR->getNoWrapFlags());
1226221345Sdim          }
1227198090Srdivacky          // Similar to above, only this time treat the step value as unsigned.
1228198090Srdivacky          // This covers loops that count up with an unsigned step.
1229204642Srdivacky          const SCEV *UMul = getMulExpr(CastedMaxBECount, Step);
1230198090Srdivacky          Add = getAddExpr(Start, UMul);
1231198090Srdivacky          OperandExtendedAdd =
1232198090Srdivacky            getAddExpr(getSignExtendExpr(Start, WideTy),
1233198090Srdivacky                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
1234198090Srdivacky                                  getZeroExtendExpr(Step, WideTy)));
1235221345Sdim          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd) {
1236221345Sdim            // Cache knowledge of AR NSW, which is propagated to this AddRec.
1237221345Sdim            const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1238198090Srdivacky            // Return the expression with the addrec on the outside.
1239223017Sdim            return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1240198090Srdivacky                                 getZeroExtendExpr(Step, Ty),
1241221345Sdim                                 L, AR->getNoWrapFlags());
1242221345Sdim          }
1243193323Sed        }
1244198090Srdivacky
1245198090Srdivacky        // If the backedge is guarded by a comparison with the pre-inc value,
1246198090Srdivacky        // the addrec is safe. Also, if the entry is guarded by a comparison
1247198090Srdivacky        // with the start value and the backedge is guarded by a comparison
1248198090Srdivacky        // with the post-inc value, the addrec is safe.
1249223017Sdim        ICmpInst::Predicate Pred;
1250223017Sdim        const SCEV *OverflowLimit = getOverflowLimitForStep(Step, &Pred, this);
1251223017Sdim        if (OverflowLimit &&
1252223017Sdim            (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
1253223017Sdim             (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
1254223017Sdim              isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
1255223017Sdim                                          OverflowLimit)))) {
1256223017Sdim          // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
1257223017Sdim          const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
1258223017Sdim          return getAddRecExpr(getSignExtendAddRecStart(AR, Ty, this),
1259223017Sdim                               getSignExtendExpr(Step, Ty),
1260223017Sdim                               L, AR->getNoWrapFlags());
1261198090Srdivacky        }
1262193323Sed      }
1263193323Sed    }
1264193323Sed
1265198090Srdivacky  // The cast wasn't folded; create an explicit cast node.
1266198090Srdivacky  // Recompute the insert position, as it may have been invalidated.
1267195340Sed  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1268205407Srdivacky  SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
1269205407Srdivacky                                                   Op, Ty);
1270195340Sed  UniqueSCEVs.InsertNode(S, IP);
1271195340Sed  return S;
1272193323Sed}
1273193323Sed
1274194178Sed/// getAnyExtendExpr - Return a SCEV for the given operand extended with
1275194178Sed/// unspecified bits out to the given type.
1276194178Sed///
1277198090Srdivackyconst SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
1278198090Srdivacky                                              const Type *Ty) {
1279194178Sed  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
1280194178Sed         "This is not an extending conversion!");
1281194178Sed  assert(isSCEVable(Ty) &&
1282194178Sed         "This is not a conversion to a SCEVable type!");
1283194178Sed  Ty = getEffectiveSCEVType(Ty);
1284194178Sed
1285194178Sed  // Sign-extend negative constants.
1286194178Sed  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
1287194178Sed    if (SC->getValue()->getValue().isNegative())
1288194178Sed      return getSignExtendExpr(Op, Ty);
1289194178Sed
1290194178Sed  // Peel off a truncate cast.
1291194178Sed  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
1292198090Srdivacky    const SCEV *NewOp = T->getOperand();
1293194178Sed    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
1294194178Sed      return getAnyExtendExpr(NewOp, Ty);
1295194178Sed    return getTruncateOrNoop(NewOp, Ty);
1296194178Sed  }
1297194178Sed
1298194178Sed  // Next try a zext cast. If the cast is folded, use it.
1299198090Srdivacky  const SCEV *ZExt = getZeroExtendExpr(Op, Ty);
1300194178Sed  if (!isa<SCEVZeroExtendExpr>(ZExt))
1301194178Sed    return ZExt;
1302194178Sed
1303194178Sed  // Next try a sext cast. If the cast is folded, use it.
1304198090Srdivacky  const SCEV *SExt = getSignExtendExpr(Op, Ty);
1305194178Sed  if (!isa<SCEVSignExtendExpr>(SExt))
1306194178Sed    return SExt;
1307194178Sed
1308202878Srdivacky  // Force the cast to be folded into the operands of an addrec.
1309202878Srdivacky  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
1310202878Srdivacky    SmallVector<const SCEV *, 4> Ops;
1311202878Srdivacky    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
1312202878Srdivacky         I != E; ++I)
1313202878Srdivacky      Ops.push_back(getAnyExtendExpr(*I, Ty));
1314221345Sdim    return getAddRecExpr(Ops, AR->getLoop(), SCEV::FlagNW);
1315202878Srdivacky  }
1316202878Srdivacky
1317212904Sdim  // As a special case, fold anyext(undef) to undef. We don't want to
1318212904Sdim  // know too much about SCEVUnknowns, but this special case is handy
1319212904Sdim  // and harmless.
1320212904Sdim  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Op))
1321212904Sdim    if (isa<UndefValue>(U->getValue()))
1322212904Sdim      return getSCEV(UndefValue::get(Ty));
1323212904Sdim
1324194178Sed  // If the expression is obviously signed, use the sext cast value.
1325194178Sed  if (isa<SCEVSMaxExpr>(Op))
1326194178Sed    return SExt;
1327194178Sed
1328194178Sed  // Absent any other information, use the zext cast value.
1329194178Sed  return ZExt;
1330194178Sed}
1331194178Sed
1332194612Sed/// CollectAddOperandsWithScales - Process the given Ops list, which is
1333194612Sed/// a list of operands to be added under the given scale, and update the given
1334194612Sed/// map. This is a helper function for getAddExpr. As an example of
1335194612Sed/// what it does, given a sequence of operands that would form an add
1336194612Sed/// expression like this:
1337194612Sed///
1338194612Sed///    m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r)
1339194612Sed///
1340194612Sed/// where A and B are constants, update the map with these values:
1341194612Sed///
1342194612Sed///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
1343194612Sed///
1344194612Sed/// and add 13 + A*B*29 to AccumulatedConstant.
1345194612Sed/// This will allow getAddExpr to produce this:
1346194612Sed///
1347194612Sed///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
1348194612Sed///
1349194612Sed/// This form often exposes folding opportunities that are hidden in
1350194612Sed/// the original operand list.
1351194612Sed///
1352194612Sed/// Return true iff it appears that any interesting folding opportunities
1353194612Sed/// may be exposed. This helps getAddExpr short-circuit extra work in
1354194612Sed/// the common case where no interesting opportunities are present, and
1355194612Sed/// is also used as a check to avoid infinite recursion.
1356194612Sed///
1357194612Sedstatic bool
1358198090SrdivackyCollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
1359198090Srdivacky                             SmallVector<const SCEV *, 8> &NewOps,
1360194612Sed                             APInt &AccumulatedConstant,
1361205407Srdivacky                             const SCEV *const *Ops, size_t NumOperands,
1362194612Sed                             const APInt &Scale,
1363194612Sed                             ScalarEvolution &SE) {
1364194612Sed  bool Interesting = false;
1365194612Sed
1366210299Sed  // Iterate over the add operands. They are sorted, with constants first.
1367210299Sed  unsigned i = 0;
1368210299Sed  while (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1369210299Sed    ++i;
1370210299Sed    // Pull a buried constant out to the outside.
1371210299Sed    if (Scale != 1 || AccumulatedConstant != 0 || C->getValue()->isZero())
1372210299Sed      Interesting = true;
1373210299Sed    AccumulatedConstant += Scale * C->getValue()->getValue();
1374210299Sed  }
1375210299Sed
1376210299Sed  // Next comes everything else. We're especially interested in multiplies
1377210299Sed  // here, but they're in the middle, so just visit the rest with one loop.
1378210299Sed  for (; i != NumOperands; ++i) {
1379194612Sed    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
1380194612Sed    if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
1381194612Sed      APInt NewScale =
1382194612Sed        Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
1383194612Sed      if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
1384194612Sed        // A multiplication of a constant with another add; recurse.
1385205407Srdivacky        const SCEVAddExpr *Add = cast<SCEVAddExpr>(Mul->getOperand(1));
1386194612Sed        Interesting |=
1387194612Sed          CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1388205407Srdivacky                                       Add->op_begin(), Add->getNumOperands(),
1389194612Sed                                       NewScale, SE);
1390194612Sed      } else {
1391194612Sed        // A multiplication of a constant with some other value. Update
1392194612Sed        // the map.
1393198090Srdivacky        SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
1394198090Srdivacky        const SCEV *Key = SE.getMulExpr(MulOps);
1395198090Srdivacky        std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1396195340Sed          M.insert(std::make_pair(Key, NewScale));
1397194612Sed        if (Pair.second) {
1398194612Sed          NewOps.push_back(Pair.first->first);
1399194612Sed        } else {
1400194612Sed          Pair.first->second += NewScale;
1401194612Sed          // The map already had an entry for this value, which may indicate
1402194612Sed          // a folding opportunity.
1403194612Sed          Interesting = true;
1404194612Sed        }
1405194612Sed      }
1406194612Sed    } else {
1407194612Sed      // An ordinary operand. Update the map.
1408198090Srdivacky      std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair =
1409195340Sed        M.insert(std::make_pair(Ops[i], Scale));
1410194612Sed      if (Pair.second) {
1411194612Sed        NewOps.push_back(Pair.first->first);
1412194612Sed      } else {
1413194612Sed        Pair.first->second += Scale;
1414194612Sed        // The map already had an entry for this value, which may indicate
1415194612Sed        // a folding opportunity.
1416194612Sed        Interesting = true;
1417194612Sed      }
1418194612Sed    }
1419194612Sed  }
1420194612Sed
1421194612Sed  return Interesting;
1422194612Sed}
1423194612Sed
1424194612Sednamespace {
1425194612Sed  struct APIntCompare {
1426194612Sed    bool operator()(const APInt &LHS, const APInt &RHS) const {
1427194612Sed      return LHS.ult(RHS);
1428194612Sed    }
1429194612Sed  };
1430194612Sed}
1431194612Sed
1432193323Sed/// getAddExpr - Get a canonical add expression, or something simpler if
1433193323Sed/// possible.
1434198090Srdivackyconst SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
1435221345Sdim                                        SCEV::NoWrapFlags Flags) {
1436221345Sdim  assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
1437221345Sdim         "only nuw or nsw allowed");
1438193323Sed  assert(!Ops.empty() && "Cannot get empty add!");
1439193323Sed  if (Ops.size() == 1) return Ops[0];
1440193323Sed#ifndef NDEBUG
1441210299Sed  const Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
1442193323Sed  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1443210299Sed    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
1444193323Sed           "SCEVAddExpr operand types don't match!");
1445193323Sed#endif
1446193323Sed
1447221345Sdim  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
1448221345Sdim  // And vice-versa.
1449221345Sdim  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
1450221345Sdim  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
1451221345Sdim  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
1452202878Srdivacky    bool All = true;
1453212904Sdim    for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
1454212904Sdim         E = Ops.end(); I != E; ++I)
1455212904Sdim      if (!isKnownNonNegative(*I)) {
1456202878Srdivacky        All = false;
1457202878Srdivacky        break;
1458202878Srdivacky      }
1459221345Sdim    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
1460202878Srdivacky  }
1461202878Srdivacky
1462193323Sed  // Sort by complexity, this groups all similar expression types together.
1463193323Sed  GroupByComplexity(Ops, LI);
1464193323Sed
1465193323Sed  // If there are any constants, fold them together.
1466193323Sed  unsigned Idx = 0;
1467193323Sed  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1468193323Sed    ++Idx;
1469193323Sed    assert(Idx < Ops.size());
1470193323Sed    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1471193323Sed      // We found two constants, fold them together!
1472194612Sed      Ops[0] = getConstant(LHSC->getValue()->getValue() +
1473194612Sed                           RHSC->getValue()->getValue());
1474194612Sed      if (Ops.size() == 2) return Ops[0];
1475193323Sed      Ops.erase(Ops.begin()+1);  // Erase the folded element
1476193323Sed      LHSC = cast<SCEVConstant>(Ops[0]);
1477193323Sed    }
1478193323Sed
1479193323Sed    // If we are left with a constant zero being added, strip it off.
1480207618Srdivacky    if (LHSC->getValue()->isZero()) {
1481193323Sed      Ops.erase(Ops.begin());
1482193323Sed      --Idx;
1483193323Sed    }
1484207618Srdivacky
1485207618Srdivacky    if (Ops.size() == 1) return Ops[0];
1486193323Sed  }
1487193323Sed
1488212904Sdim  // Okay, check to see if the same value occurs in the operand list more than
1489212904Sdim  // once.  If so, merge them together into a multiply expression.  Since we
1490212904Sdim  // sorted the list, these values are required to be adjacent.
1491193323Sed  const Type *Ty = Ops[0]->getType();
1492212904Sdim  bool FoundMatch = false;
1493212904Sdim  for (unsigned i = 0, e = Ops.size(); i != e-1; ++i)
1494193323Sed    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
1495212904Sdim      // Scan ahead to count how many equal operands there are.
1496212904Sdim      unsigned Count = 2;
1497212904Sdim      while (i+Count != e && Ops[i+Count] == Ops[i])
1498212904Sdim        ++Count;
1499212904Sdim      // Merge the values into a multiply.
1500212904Sdim      const SCEV *Scale = getConstant(Ty, Count);
1501212904Sdim      const SCEV *Mul = getMulExpr(Scale, Ops[i]);
1502212904Sdim      if (Ops.size() == Count)
1503193323Sed        return Mul;
1504212904Sdim      Ops[i] = Mul;
1505212904Sdim      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+Count);
1506212904Sdim      --i; e -= Count - 1;
1507212904Sdim      FoundMatch = true;
1508193323Sed    }
1509212904Sdim  if (FoundMatch)
1510221345Sdim    return getAddExpr(Ops, Flags);
1511193323Sed
1512193323Sed  // Check for truncates. If all the operands are truncated from the same
1513193323Sed  // type, see if factoring out the truncate would permit the result to be
1514193323Sed  // folded, e.g., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
1515193323Sed  // if the contents of the resulting outer trunc fold to something simple.
1516193323Sed  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
1517193323Sed    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
1518193323Sed    const Type *DstType = Trunc->getType();
1519193323Sed    const Type *SrcType = Trunc->getOperand()->getType();
1520198090Srdivacky    SmallVector<const SCEV *, 8> LargeOps;
1521193323Sed    bool Ok = true;
1522193323Sed    // Check all the operands to see if they can be represented in the
1523193323Sed    // source type of the truncate.
1524193323Sed    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
1525193323Sed      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
1526193323Sed        if (T->getOperand()->getType() != SrcType) {
1527193323Sed          Ok = false;
1528193323Sed          break;
1529193323Sed        }
1530193323Sed        LargeOps.push_back(T->getOperand());
1531193323Sed      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
1532207618Srdivacky        LargeOps.push_back(getAnyExtendExpr(C, SrcType));
1533193323Sed      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
1534198090Srdivacky        SmallVector<const SCEV *, 8> LargeMulOps;
1535193323Sed        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
1536193323Sed          if (const SCEVTruncateExpr *T =
1537193323Sed                dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
1538193323Sed            if (T->getOperand()->getType() != SrcType) {
1539193323Sed              Ok = false;
1540193323Sed              break;
1541193323Sed            }
1542193323Sed            LargeMulOps.push_back(T->getOperand());
1543193323Sed          } else if (const SCEVConstant *C =
1544193323Sed                       dyn_cast<SCEVConstant>(M->getOperand(j))) {
1545207618Srdivacky            LargeMulOps.push_back(getAnyExtendExpr(C, SrcType));
1546193323Sed          } else {
1547193323Sed            Ok = false;
1548193323Sed            break;
1549193323Sed          }
1550193323Sed        }
1551193323Sed        if (Ok)
1552193323Sed          LargeOps.push_back(getMulExpr(LargeMulOps));
1553193323Sed      } else {
1554193323Sed        Ok = false;
1555193323Sed        break;
1556193323Sed      }
1557193323Sed    }
1558193323Sed    if (Ok) {
1559193323Sed      // Evaluate the expression in the larger type.
1560221345Sdim      const SCEV *Fold = getAddExpr(LargeOps, Flags);
1561193323Sed      // If it folds to something simple, use it. Otherwise, don't.
1562193323Sed      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
1563193323Sed        return getTruncateExpr(Fold, DstType);
1564193323Sed    }
1565193323Sed  }
1566193323Sed
1567193323Sed  // Skip past any other cast SCEVs.
1568193323Sed  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
1569193323Sed    ++Idx;
1570193323Sed
1571193323Sed  // If there are add operands, they would be next.
1572193323Sed  if (Idx < Ops.size()) {
1573193323Sed    bool DeletedAdd = false;
1574193323Sed    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
1575193323Sed      // If we have an add, expand the add operands onto the end of the operands
1576193323Sed      // list.
1577193323Sed      Ops.erase(Ops.begin()+Idx);
1578210299Sed      Ops.append(Add->op_begin(), Add->op_end());
1579193323Sed      DeletedAdd = true;
1580193323Sed    }
1581193323Sed
1582193323Sed    // If we deleted at least one add, we added operands to the end of the list,
1583193323Sed    // and they are not necessarily sorted.  Recurse to resort and resimplify
1584204642Srdivacky    // any operands we just acquired.
1585193323Sed    if (DeletedAdd)
1586193323Sed      return getAddExpr(Ops);
1587193323Sed  }
1588193323Sed
1589193323Sed  // Skip over the add expression until we get to a multiply.
1590193323Sed  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1591193323Sed    ++Idx;
1592193323Sed
1593194612Sed  // Check to see if there are any folding opportunities present with
1594194612Sed  // operands multiplied by constant values.
1595194612Sed  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
1596194612Sed    uint64_t BitWidth = getTypeSizeInBits(Ty);
1597198090Srdivacky    DenseMap<const SCEV *, APInt> M;
1598198090Srdivacky    SmallVector<const SCEV *, 8> NewOps;
1599194612Sed    APInt AccumulatedConstant(BitWidth, 0);
1600194612Sed    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
1601205407Srdivacky                                     Ops.data(), Ops.size(),
1602205407Srdivacky                                     APInt(BitWidth, 1), *this)) {
1603194612Sed      // Some interesting folding opportunity is present, so it's worthwhile to
1604194612Sed      // re-generate the operands list. Group the operands by constant scale,
1605194612Sed      // to avoid multiplying by the same constant scale multiple times.
1606198090Srdivacky      std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare> MulOpLists;
1607212904Sdim      for (SmallVector<const SCEV *, 8>::const_iterator I = NewOps.begin(),
1608194612Sed           E = NewOps.end(); I != E; ++I)
1609194612Sed        MulOpLists[M.find(*I)->second].push_back(*I);
1610194612Sed      // Re-generate the operands list.
1611194612Sed      Ops.clear();
1612194612Sed      if (AccumulatedConstant != 0)
1613194612Sed        Ops.push_back(getConstant(AccumulatedConstant));
1614195098Sed      for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
1615195098Sed           I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
1616194612Sed        if (I->first != 0)
1617195098Sed          Ops.push_back(getMulExpr(getConstant(I->first),
1618195098Sed                                   getAddExpr(I->second)));
1619194612Sed      if (Ops.empty())
1620207618Srdivacky        return getConstant(Ty, 0);
1621194612Sed      if (Ops.size() == 1)
1622194612Sed        return Ops[0];
1623194612Sed      return getAddExpr(Ops);
1624194612Sed    }
1625194612Sed  }
1626194612Sed
1627193323Sed  // If we are adding something to a multiply expression, make sure the
1628193323Sed  // something is not already an operand of the multiply.  If so, merge it into
1629193323Sed  // the multiply.
1630193323Sed  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
1631193323Sed    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
1632193323Sed    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
1633193323Sed      const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
1634212904Sdim      if (isa<SCEVConstant>(MulOpSCEV))
1635212904Sdim        continue;
1636193323Sed      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
1637212904Sdim        if (MulOpSCEV == Ops[AddOp]) {
1638193323Sed          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
1639198090Srdivacky          const SCEV *InnerMul = Mul->getOperand(MulOp == 0);
1640193323Sed          if (Mul->getNumOperands() != 2) {
1641193323Sed            // If the multiply has more than two operands, we must get the
1642193323Sed            // Y*Z term.
1643212904Sdim            SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1644212904Sdim                                                Mul->op_begin()+MulOp);
1645212904Sdim            MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
1646193323Sed            InnerMul = getMulExpr(MulOps);
1647193323Sed          }
1648207618Srdivacky          const SCEV *One = getConstant(Ty, 1);
1649212904Sdim          const SCEV *AddOne = getAddExpr(One, InnerMul);
1650212904Sdim          const SCEV *OuterMul = getMulExpr(AddOne, MulOpSCEV);
1651193323Sed          if (Ops.size() == 2) return OuterMul;
1652193323Sed          if (AddOp < Idx) {
1653193323Sed            Ops.erase(Ops.begin()+AddOp);
1654193323Sed            Ops.erase(Ops.begin()+Idx-1);
1655193323Sed          } else {
1656193323Sed            Ops.erase(Ops.begin()+Idx);
1657193323Sed            Ops.erase(Ops.begin()+AddOp-1);
1658193323Sed          }
1659193323Sed          Ops.push_back(OuterMul);
1660193323Sed          return getAddExpr(Ops);
1661193323Sed        }
1662193323Sed
1663193323Sed      // Check this multiply against other multiplies being added together.
1664193323Sed      for (unsigned OtherMulIdx = Idx+1;
1665193323Sed           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
1666193323Sed           ++OtherMulIdx) {
1667193323Sed        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
1668193323Sed        // If MulOp occurs in OtherMul, we can fold the two multiplies
1669193323Sed        // together.
1670193323Sed        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
1671193323Sed             OMulOp != e; ++OMulOp)
1672193323Sed          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
1673193323Sed            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
1674198090Srdivacky            const SCEV *InnerMul1 = Mul->getOperand(MulOp == 0);
1675193323Sed            if (Mul->getNumOperands() != 2) {
1676195098Sed              SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1677212904Sdim                                                  Mul->op_begin()+MulOp);
1678212904Sdim              MulOps.append(Mul->op_begin()+MulOp+1, Mul->op_end());
1679193323Sed              InnerMul1 = getMulExpr(MulOps);
1680193323Sed            }
1681198090Srdivacky            const SCEV *InnerMul2 = OtherMul->getOperand(OMulOp == 0);
1682193323Sed            if (OtherMul->getNumOperands() != 2) {
1683195098Sed              SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
1684212904Sdim                                                  OtherMul->op_begin()+OMulOp);
1685212904Sdim              MulOps.append(OtherMul->op_begin()+OMulOp+1, OtherMul->op_end());
1686193323Sed              InnerMul2 = getMulExpr(MulOps);
1687193323Sed            }
1688198090Srdivacky            const SCEV *InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
1689198090Srdivacky            const SCEV *OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
1690193323Sed            if (Ops.size() == 2) return OuterMul;
1691193323Sed            Ops.erase(Ops.begin()+Idx);
1692193323Sed            Ops.erase(Ops.begin()+OtherMulIdx-1);
1693193323Sed            Ops.push_back(OuterMul);
1694193323Sed            return getAddExpr(Ops);
1695193323Sed          }
1696193323Sed      }
1697193323Sed    }
1698193323Sed  }
1699193323Sed
1700193323Sed  // If there are any add recurrences in the operands list, see if any other
1701193323Sed  // added values are loop invariant.  If so, we can fold them into the
1702193323Sed  // recurrence.
1703193323Sed  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1704193323Sed    ++Idx;
1705193323Sed
1706193323Sed  // Scan over all recurrences, trying to fold loop invariants into them.
1707193323Sed  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1708193323Sed    // Scan all of the other operands to this add and add them to the vector if
1709193323Sed    // they are loop invariant w.r.t. the recurrence.
1710198090Srdivacky    SmallVector<const SCEV *, 8> LIOps;
1711193323Sed    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1712207618Srdivacky    const Loop *AddRecLoop = AddRec->getLoop();
1713193323Sed    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1714218893Sdim      if (isLoopInvariant(Ops[i], AddRecLoop)) {
1715193323Sed        LIOps.push_back(Ops[i]);
1716193323Sed        Ops.erase(Ops.begin()+i);
1717193323Sed        --i; --e;
1718193323Sed      }
1719193323Sed
1720193323Sed    // If we found some loop invariants, fold them into the recurrence.
1721193323Sed    if (!LIOps.empty()) {
1722193323Sed      //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
1723193323Sed      LIOps.push_back(AddRec->getStart());
1724193323Sed
1725198090Srdivacky      SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1726201360Srdivacky                                             AddRec->op_end());
1727193323Sed      AddRecOps[0] = getAddExpr(LIOps);
1728193323Sed
1729210299Sed      // Build the new addrec. Propagate the NUW and NSW flags if both the
1730210299Sed      // outer add and the inner addrec are guaranteed to have no overflow.
1731221345Sdim      // Always propagate NW.
1732221345Sdim      Flags = AddRec->getNoWrapFlags(setFlags(Flags, SCEV::FlagNW));
1733221345Sdim      const SCEV *NewRec = getAddRecExpr(AddRecOps, AddRecLoop, Flags);
1734201360Srdivacky
1735193323Sed      // If all of the other operands were loop invariant, we are done.
1736193323Sed      if (Ops.size() == 1) return NewRec;
1737193323Sed
1738193323Sed      // Otherwise, add the folded AddRec to the non-liv parts.
1739193323Sed      for (unsigned i = 0;; ++i)
1740193323Sed        if (Ops[i] == AddRec) {
1741193323Sed          Ops[i] = NewRec;
1742193323Sed          break;
1743193323Sed        }
1744193323Sed      return getAddExpr(Ops);
1745193323Sed    }
1746193323Sed
1747193323Sed    // Okay, if there weren't any loop invariants to be folded, check to see if
1748193323Sed    // there are multiple AddRec's with the same loop induction variable being
1749193323Sed    // added together.  If so, we can fold them.
1750193323Sed    for (unsigned OtherIdx = Idx+1;
1751212904Sdim         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1752212904Sdim         ++OtherIdx)
1753212904Sdim      if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
1754212904Sdim        // Other + {A,+,B}<L> + {C,+,D}<L>  -->  Other + {A+C,+,B+D}<L>
1755212904Sdim        SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
1756212904Sdim                                               AddRec->op_end());
1757212904Sdim        for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1758212904Sdim             ++OtherIdx)
1759212904Sdim          if (const SCEVAddRecExpr *OtherAddRec =
1760212904Sdim                dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
1761212904Sdim            if (OtherAddRec->getLoop() == AddRecLoop) {
1762212904Sdim              for (unsigned i = 0, e = OtherAddRec->getNumOperands();
1763212904Sdim                   i != e; ++i) {
1764212904Sdim                if (i >= AddRecOps.size()) {
1765212904Sdim                  AddRecOps.append(OtherAddRec->op_begin()+i,
1766212904Sdim                                   OtherAddRec->op_end());
1767212904Sdim                  break;
1768212904Sdim                }
1769212904Sdim                AddRecOps[i] = getAddExpr(AddRecOps[i],
1770212904Sdim                                          OtherAddRec->getOperand(i));
1771212904Sdim              }
1772212904Sdim              Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
1773193323Sed            }
1774221345Sdim        // Step size has changed, so we cannot guarantee no self-wraparound.
1775221345Sdim        Ops[Idx] = getAddRecExpr(AddRecOps, AddRecLoop, SCEV::FlagAnyWrap);
1776212904Sdim        return getAddExpr(Ops);
1777193323Sed      }
1778193323Sed
1779193323Sed    // Otherwise couldn't fold anything into this recurrence.  Move onto the
1780193323Sed    // next one.
1781193323Sed  }
1782193323Sed
1783193323Sed  // Okay, it looks like we really DO need an add expr.  Check to see if we
1784193323Sed  // already have one, otherwise create a new one.
1785195340Sed  FoldingSetNodeID ID;
1786195340Sed  ID.AddInteger(scAddExpr);
1787195340Sed  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1788195340Sed    ID.AddPointer(Ops[i]);
1789195340Sed  void *IP = 0;
1790202878Srdivacky  SCEVAddExpr *S =
1791202878Srdivacky    static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
1792202878Srdivacky  if (!S) {
1793205407Srdivacky    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
1794205407Srdivacky    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
1795205407Srdivacky    S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator),
1796205407Srdivacky                                        O, Ops.size());
1797202878Srdivacky    UniqueSCEVs.InsertNode(S, IP);
1798202878Srdivacky  }
1799221345Sdim  S->setNoWrapFlags(Flags);
1800195340Sed  return S;
1801193323Sed}
1802193323Sed
1803193323Sed/// getMulExpr - Get a canonical multiply expression, or something simpler if
1804193323Sed/// possible.
1805198090Srdivackyconst SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
1806221345Sdim                                        SCEV::NoWrapFlags Flags) {
1807221345Sdim  assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
1808221345Sdim         "only nuw or nsw allowed");
1809193323Sed  assert(!Ops.empty() && "Cannot get empty mul!");
1810202878Srdivacky  if (Ops.size() == 1) return Ops[0];
1811193323Sed#ifndef NDEBUG
1812212904Sdim  const Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
1813193323Sed  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1814212904Sdim    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
1815193323Sed           "SCEVMulExpr operand types don't match!");
1816193323Sed#endif
1817193323Sed
1818221345Sdim  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
1819221345Sdim  // And vice-versa.
1820221345Sdim  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
1821221345Sdim  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
1822221345Sdim  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
1823202878Srdivacky    bool All = true;
1824212904Sdim    for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(),
1825212904Sdim         E = Ops.end(); I != E; ++I)
1826212904Sdim      if (!isKnownNonNegative(*I)) {
1827202878Srdivacky        All = false;
1828202878Srdivacky        break;
1829202878Srdivacky      }
1830221345Sdim    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
1831202878Srdivacky  }
1832202878Srdivacky
1833193323Sed  // Sort by complexity, this groups all similar expression types together.
1834193323Sed  GroupByComplexity(Ops, LI);
1835193323Sed
1836193323Sed  // If there are any constants, fold them together.
1837193323Sed  unsigned Idx = 0;
1838193323Sed  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1839193323Sed
1840193323Sed    // C1*(C2+V) -> C1*C2 + C1*V
1841193323Sed    if (Ops.size() == 2)
1842193323Sed      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
1843193323Sed        if (Add->getNumOperands() == 2 &&
1844193323Sed            isa<SCEVConstant>(Add->getOperand(0)))
1845193323Sed          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
1846193323Sed                            getMulExpr(LHSC, Add->getOperand(1)));
1847193323Sed
1848193323Sed    ++Idx;
1849193323Sed    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1850193323Sed      // We found two constants, fold them together!
1851198090Srdivacky      ConstantInt *Fold = ConstantInt::get(getContext(),
1852198090Srdivacky                                           LHSC->getValue()->getValue() *
1853193323Sed                                           RHSC->getValue()->getValue());
1854193323Sed      Ops[0] = getConstant(Fold);
1855193323Sed      Ops.erase(Ops.begin()+1);  // Erase the folded element
1856193323Sed      if (Ops.size() == 1) return Ops[0];
1857193323Sed      LHSC = cast<SCEVConstant>(Ops[0]);
1858193323Sed    }
1859193323Sed
1860193323Sed    // If we are left with a constant one being multiplied, strip it off.
1861193323Sed    if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1862193323Sed      Ops.erase(Ops.begin());
1863193323Sed      --Idx;
1864193323Sed    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1865193323Sed      // If we have a multiply of zero, it will always be zero.
1866193323Sed      return Ops[0];
1867202878Srdivacky    } else if (Ops[0]->isAllOnesValue()) {
1868202878Srdivacky      // If we have a mul by -1 of an add, try distributing the -1 among the
1869202878Srdivacky      // add operands.
1870221345Sdim      if (Ops.size() == 2) {
1871202878Srdivacky        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
1872202878Srdivacky          SmallVector<const SCEV *, 4> NewOps;
1873202878Srdivacky          bool AnyFolded = false;
1874221345Sdim          for (SCEVAddRecExpr::op_iterator I = Add->op_begin(),
1875221345Sdim                 E = Add->op_end(); I != E; ++I) {
1876202878Srdivacky            const SCEV *Mul = getMulExpr(Ops[0], *I);
1877202878Srdivacky            if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
1878202878Srdivacky            NewOps.push_back(Mul);
1879202878Srdivacky          }
1880202878Srdivacky          if (AnyFolded)
1881202878Srdivacky            return getAddExpr(NewOps);
1882202878Srdivacky        }
1883221345Sdim        else if (const SCEVAddRecExpr *
1884221345Sdim                 AddRec = dyn_cast<SCEVAddRecExpr>(Ops[1])) {
1885221345Sdim          // Negation preserves a recurrence's no self-wrap property.
1886221345Sdim          SmallVector<const SCEV *, 4> Operands;
1887221345Sdim          for (SCEVAddRecExpr::op_iterator I = AddRec->op_begin(),
1888221345Sdim                 E = AddRec->op_end(); I != E; ++I) {
1889221345Sdim            Operands.push_back(getMulExpr(Ops[0], *I));
1890221345Sdim          }
1891221345Sdim          return getAddRecExpr(Operands, AddRec->getLoop(),
1892221345Sdim                               AddRec->getNoWrapFlags(SCEV::FlagNW));
1893221345Sdim        }
1894221345Sdim      }
1895193323Sed    }
1896207618Srdivacky
1897207618Srdivacky    if (Ops.size() == 1)
1898207618Srdivacky      return Ops[0];
1899193323Sed  }
1900193323Sed
1901193323Sed  // Skip over the add expression until we get to a multiply.
1902193323Sed  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1903193323Sed    ++Idx;
1904193323Sed
1905193323Sed  // If there are mul operands, inline them all into this expression.
1906193323Sed  if (Idx < Ops.size()) {
1907193323Sed    bool DeletedMul = false;
1908193323Sed    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1909193323Sed      // If we have a mul, expand the mul operands onto the end of the operands
1910193323Sed      // list.
1911193323Sed      Ops.erase(Ops.begin()+Idx);
1912210299Sed      Ops.append(Mul->op_begin(), Mul->op_end());
1913193323Sed      DeletedMul = true;
1914193323Sed    }
1915193323Sed
1916193323Sed    // If we deleted at least one mul, we added operands to the end of the list,
1917193323Sed    // and they are not necessarily sorted.  Recurse to resort and resimplify
1918204642Srdivacky    // any operands we just acquired.
1919193323Sed    if (DeletedMul)
1920193323Sed      return getMulExpr(Ops);
1921193323Sed  }
1922193323Sed
1923193323Sed  // If there are any add recurrences in the operands list, see if any other
1924193323Sed  // multiplied values are loop invariant.  If so, we can fold them into the
1925193323Sed  // recurrence.
1926193323Sed  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1927193323Sed    ++Idx;
1928193323Sed
1929193323Sed  // Scan over all recurrences, trying to fold loop invariants into them.
1930193323Sed  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1931193323Sed    // Scan all of the other operands to this mul and add them to the vector if
1932193323Sed    // they are loop invariant w.r.t. the recurrence.
1933198090Srdivacky    SmallVector<const SCEV *, 8> LIOps;
1934193323Sed    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1935212904Sdim    const Loop *AddRecLoop = AddRec->getLoop();
1936193323Sed    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1937218893Sdim      if (isLoopInvariant(Ops[i], AddRecLoop)) {
1938193323Sed        LIOps.push_back(Ops[i]);
1939193323Sed        Ops.erase(Ops.begin()+i);
1940193323Sed        --i; --e;
1941193323Sed      }
1942193323Sed
1943193323Sed    // If we found some loop invariants, fold them into the recurrence.
1944193323Sed    if (!LIOps.empty()) {
1945193323Sed      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
1946198090Srdivacky      SmallVector<const SCEV *, 4> NewOps;
1947193323Sed      NewOps.reserve(AddRec->getNumOperands());
1948210299Sed      const SCEV *Scale = getMulExpr(LIOps);
1949210299Sed      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
1950210299Sed        NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
1951193323Sed
1952210299Sed      // Build the new addrec. Propagate the NUW and NSW flags if both the
1953210299Sed      // outer mul and the inner addrec are guaranteed to have no overflow.
1954221345Sdim      //
1955221345Sdim      // The no-self-wrap property (NW) cannot be guaranteed after changing the
1956221345Sdim      // step size, but it will be inferred if either NUW or NSW is true.
1957221345Sdim      Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
1958221345Sdim      const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
1959193323Sed
1960193323Sed      // If all of the other operands were loop invariant, we are done.
1961193323Sed      if (Ops.size() == 1) return NewRec;
1962193323Sed
1963193323Sed      // Otherwise, multiply the folded AddRec by the non-liv parts.
1964193323Sed      for (unsigned i = 0;; ++i)
1965193323Sed        if (Ops[i] == AddRec) {
1966193323Sed          Ops[i] = NewRec;
1967193323Sed          break;
1968193323Sed        }
1969193323Sed      return getMulExpr(Ops);
1970193323Sed    }
1971193323Sed
1972193323Sed    // Okay, if there weren't any loop invariants to be folded, check to see if
1973193323Sed    // there are multiple AddRec's with the same loop induction variable being
1974193323Sed    // multiplied together.  If so, we can fold them.
1975193323Sed    for (unsigned OtherIdx = Idx+1;
1976212904Sdim         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1977212904Sdim         ++OtherIdx)
1978212904Sdim      if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
1979212904Sdim        // F * G, where F = {A,+,B}<L> and G = {C,+,D}<L>  -->
1980212904Sdim        // {A*C,+,F*D + G*B + B*D}<L>
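        // (Added derivation, for clarity.) One loop step advances F to F+B and
        // G to G+D, so the product advances by (F+B)*(G+D) - F*G =
        // F*D + G*B + B*D, which is exactly the step built below.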
1981212904Sdim        for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
1982212904Sdim             ++OtherIdx)
1983212904Sdim          if (const SCEVAddRecExpr *OtherAddRec =
1984212904Sdim                dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]))
1985212904Sdim            if (OtherAddRec->getLoop() == AddRecLoop) {
1986212904Sdim              const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
1987212904Sdim              const SCEV *NewStart = getMulExpr(F->getStart(), G->getStart());
1988212904Sdim              const SCEV *B = F->getStepRecurrence(*this);
1989212904Sdim              const SCEV *D = G->getStepRecurrence(*this);
1990212904Sdim              const SCEV *NewStep = getAddExpr(getMulExpr(F, D),
1991212904Sdim                                               getMulExpr(G, B),
1992212904Sdim                                               getMulExpr(B, D));
1993212904Sdim              const SCEV *NewAddRec = getAddRecExpr(NewStart, NewStep,
1994221345Sdim                                                    F->getLoop(),
1995221345Sdim                                                    SCEV::FlagAnyWrap);
1996212904Sdim              if (Ops.size() == 2) return NewAddRec;
1997212904Sdim              Ops[Idx] = AddRec = cast<SCEVAddRecExpr>(NewAddRec);
1998212904Sdim              Ops.erase(Ops.begin() + OtherIdx); --OtherIdx;
1999212904Sdim            }
2000212904Sdim        return getMulExpr(Ops);
2001193323Sed      }
2002193323Sed
2003193323Sed    // Otherwise couldn't fold anything into this recurrence.  Move onto the
2004193323Sed    // next one.
2005193323Sed  }
2006193323Sed
2007193323Sed  // Okay, it looks like we really DO need a mul expr.  Check to see if we
2008193323Sed  // already have one, otherwise create a new one.
2009195340Sed  FoldingSetNodeID ID;
2010195340Sed  ID.AddInteger(scMulExpr);
2011195340Sed  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2012195340Sed    ID.AddPointer(Ops[i]);
2013195340Sed  void *IP = 0;
2014202878Srdivacky  SCEVMulExpr *S =
2015202878Srdivacky    static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2016202878Srdivacky  if (!S) {
2017205407Srdivacky    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2018205407Srdivacky    std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2019205407Srdivacky    S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
2020205407Srdivacky                                        O, Ops.size());
2021202878Srdivacky    UniqueSCEVs.InsertNode(S, IP);
2022202878Srdivacky  }
2023221345Sdim  S->setNoWrapFlags(Flags);
2024195340Sed  return S;
2025193323Sed}
2026193323Sed
2027198090Srdivacky/// getUDivExpr - Get a canonical unsigned division expression, or something
2028198090Srdivacky/// simpler if possible.
2029195098Sedconst SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
2030195098Sed                                         const SCEV *RHS) {
2031193323Sed  assert(getEffectiveSCEVType(LHS->getType()) ==
2032193323Sed         getEffectiveSCEVType(RHS->getType()) &&
2033193323Sed         "SCEVUDivExpr operand types don't match!");
2034193323Sed
2035193323Sed  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
2036193323Sed    if (RHSC->getValue()->equalsInt(1))
2037198090Srdivacky      return LHS;                               // X udiv 1 --> x
2038207618Srdivacky    // If the denominator is zero, the result of the udiv is undefined. Don't
2039207618Srdivacky    // try to analyze it, because the resolution chosen here may differ from
2040207618Srdivacky    // the resolution chosen in other parts of the compiler.
2041207618Srdivacky    if (!RHSC->getValue()->isZero()) {
2042207618Srdivacky      // Determine if the division can be folded into the operands of
2043207618Srdivacky      // its left-hand side.
2044207618Srdivacky      // TODO: Generalize this to non-constants by using known-bits information.
2045207618Srdivacky      const Type *Ty = LHS->getType();
2046207618Srdivacky      unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
2047212904Sdim      unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ - 1;
2048207618Srdivacky      // For non-power-of-two values, effectively round the value up to the
2049207618Srdivacky      // nearest power of two.
2050207618Srdivacky      if (!RHSC->getValue()->getValue().isPowerOf2())
2051207618Srdivacky        ++MaxShiftAmt;
2052207618Srdivacky      const IntegerType *ExtTy =
2053207618Srdivacky        IntegerType::get(getContext(), getTypeSizeInBits(Ty) + MaxShiftAmt);
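      // Worked example (added for illustration): for an i32 udiv by 10, LZ is
      // 28, MaxShiftAmt starts at 3 and is bumped to 4 because 10 is not a
      // power of two, so the checks below are carried out in i36.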
2054207618Srdivacky      // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
2055207618Srdivacky      if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
2056207618Srdivacky        if (const SCEVConstant *Step =
2057207618Srdivacky              dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
2058207618Srdivacky          if (!Step->getValue()->getValue()
2059207618Srdivacky                .urem(RHSC->getValue()->getValue()) &&
2060207618Srdivacky              getZeroExtendExpr(AR, ExtTy) ==
2061207618Srdivacky              getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
2062207618Srdivacky                            getZeroExtendExpr(Step, ExtTy),
2063221345Sdim                            AR->getLoop(), SCEV::FlagAnyWrap)) {
2064207618Srdivacky            SmallVector<const SCEV *, 4> Operands;
2065207618Srdivacky            for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
2066207618Srdivacky              Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
2067221345Sdim            return getAddRecExpr(Operands, AR->getLoop(),
2068221345Sdim                                 SCEV::FlagNW);
2069193323Sed          }
2070207618Srdivacky      // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
2071207618Srdivacky      if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
2072207618Srdivacky        SmallVector<const SCEV *, 4> Operands;
2073207618Srdivacky        for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
2074207618Srdivacky          Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
2075207618Srdivacky        if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
2076207618Srdivacky          // Find an operand that's safely divisible.
2077207618Srdivacky          for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
2078207618Srdivacky            const SCEV *Op = M->getOperand(i);
2079207618Srdivacky            const SCEV *Div = getUDivExpr(Op, RHSC);
2080207618Srdivacky            if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
2081207618Srdivacky              Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
2082207618Srdivacky                                                      M->op_end());
2083207618Srdivacky              Operands[i] = Div;
2084207618Srdivacky              return getMulExpr(Operands);
2085207618Srdivacky            }
2086207618Srdivacky          }
2087207618Srdivacky      }
2088207618Srdivacky      // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
2089221345Sdim      if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
2090207618Srdivacky        SmallVector<const SCEV *, 4> Operands;
2091207618Srdivacky        for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
2092207618Srdivacky          Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
2093207618Srdivacky        if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
2094207618Srdivacky          Operands.clear();
2095207618Srdivacky          for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
2096207618Srdivacky            const SCEV *Op = getUDivExpr(A->getOperand(i), RHS);
2097207618Srdivacky            if (isa<SCEVUDivExpr>(Op) ||
2098207618Srdivacky                getMulExpr(Op, RHS) != A->getOperand(i))
2099207618Srdivacky              break;
2100207618Srdivacky            Operands.push_back(Op);
2101207618Srdivacky          }
2102207618Srdivacky          if (Operands.size() == A->getNumOperands())
2103207618Srdivacky            return getAddExpr(Operands);
2104193323Sed        }
2105193323Sed      }
2106193323Sed
2107207618Srdivacky      // Fold if both operands are constant.
2108207618Srdivacky      if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
2109207618Srdivacky        Constant *LHSCV = LHSC->getValue();
2110207618Srdivacky        Constant *RHSCV = RHSC->getValue();
2111207618Srdivacky        return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
2112207618Srdivacky                                                                   RHSCV)));
2113207618Srdivacky      }
2114193323Sed    }
2115193323Sed  }
2116193323Sed
2117195340Sed  FoldingSetNodeID ID;
2118195340Sed  ID.AddInteger(scUDivExpr);
2119195340Sed  ID.AddPointer(LHS);
2120195340Sed  ID.AddPointer(RHS);
2121195340Sed  void *IP = 0;
2122195340Sed  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2123205407Srdivacky  SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
2124205407Srdivacky                                             LHS, RHS);
2125195340Sed  UniqueSCEVs.InsertNode(S, IP);
2126195340Sed  return S;
2127193323Sed}
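// Illustrative sketch (not part of the original source): two of the cases
// handled above, with SE, L and Ty as hypothetical placeholders. A constant
// udiv is folded outright, and an affine recurrence whose step is divisible
// by the divisor divides through when the zero-extension check proves it safe.
//
//   // 10 udiv 2 folds to the constant 5 via ConstantExpr::getUDiv.
//   const SCEV *Five = SE.getUDivExpr(SE.getConstant(Ty, 10),
//                                     SE.getConstant(Ty, 2));
//
//   // {0,+,8}<L> udiv 4 folds to {0,+,2}<L> when the recurrence is known not
//   // to wrap; otherwise a plain SCEVUDivExpr is created.
//   const SCEV *AR  = SE.getAddRecExpr(SE.getConstant(Ty, 0),
//                                      SE.getConstant(Ty, 8), L,
//                                      SCEV::FlagAnyWrap);
//   const SCEV *Div = SE.getUDivExpr(AR, SE.getConstant(Ty, 4));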
2128193323Sed
2129193323Sed
2130193323Sed/// getAddRecExpr - Get an add recurrence expression for the specified loop.
2131193323Sed/// Simplify the expression as much as possible.
2132221345Sdimconst SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step,
2133221345Sdim                                           const Loop *L,
2134221345Sdim                                           SCEV::NoWrapFlags Flags) {
2135198090Srdivacky  SmallVector<const SCEV *, 4> Operands;
2136193323Sed  Operands.push_back(Start);
2137193323Sed  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
2138193323Sed    if (StepChrec->getLoop() == L) {
2139210299Sed      Operands.append(StepChrec->op_begin(), StepChrec->op_end());
2140221345Sdim      return getAddRecExpr(Operands, L, maskFlags(Flags, SCEV::FlagNW));
2141193323Sed    }
2142193323Sed
2143193323Sed  Operands.push_back(Step);
2144221345Sdim  return getAddRecExpr(Operands, L, Flags);
2145193323Sed}
2146193323Sed
2147193323Sed/// getAddRecExpr - Get an add recurrence expression for the specified loop.
2148193323Sed/// Simplify the expression as much as possible.
2149195098Sedconst SCEV *
2150198090SrdivackyScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
2151221345Sdim                               const Loop *L, SCEV::NoWrapFlags Flags) {
2152193323Sed  if (Operands.size() == 1) return Operands[0];
2153193323Sed#ifndef NDEBUG
2154212904Sdim  const Type *ETy = getEffectiveSCEVType(Operands[0]->getType());
2155193323Sed  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
2156212904Sdim    assert(getEffectiveSCEVType(Operands[i]->getType()) == ETy &&
2157193323Sed           "SCEVAddRecExpr operand types don't match!");
2158218893Sdim  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2159218893Sdim    assert(isLoopInvariant(Operands[i], L) &&
2160218893Sdim           "SCEVAddRecExpr operand is not loop-invariant!");
2161193323Sed#endif
2162193323Sed
2163193323Sed  if (Operands.back()->isZero()) {
2164193323Sed    Operands.pop_back();
2165221345Sdim    return getAddRecExpr(Operands, L, SCEV::FlagAnyWrap); // {X,+,0}  -->  X
2166193323Sed  }
2167193323Sed
2168204642Srdivacky  // It's tempting to call getMaxBackedgeTakenCount here and
2169204642Srdivacky  // use that information to infer NUW and NSW flags. However, computing a
2170204642Srdivacky  // BE count requires calling getAddRecExpr, so we may not yet have a
2171204642Srdivacky  // meaningful BE count at this point (and if we don't, we'd be stuck
2172204642Srdivacky  // with a SCEVCouldNotCompute as the cached BE count).
2173204642Srdivacky
2174221345Sdim  // If FlagNSW is true and all the operands are non-negative, infer FlagNUW.
2175221345Sdim  // And vice-versa.
2176221345Sdim  int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW;
2177221345Sdim  SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask);
2178221345Sdim  if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) {
2179202878Srdivacky    bool All = true;
2180212904Sdim    for (SmallVectorImpl<const SCEV *>::const_iterator I = Operands.begin(),
2181212904Sdim         E = Operands.end(); I != E; ++I)
2182212904Sdim      if (!isKnownNonNegative(*I)) {
2183202878Srdivacky        All = false;
2184202878Srdivacky        break;
2185202878Srdivacky      }
2186221345Sdim    if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask);
2187202878Srdivacky  }
2188202878Srdivacky
2189193323Sed  // Canonicalize nested AddRecs by nesting them in order of loop depth.
2190193323Sed  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
2191201360Srdivacky    const Loop *NestedLoop = NestedAR->getLoop();
2192212904Sdim    if (L->contains(NestedLoop) ?
2193202878Srdivacky        (L->getLoopDepth() < NestedLoop->getLoopDepth()) :
2194212904Sdim        (!NestedLoop->contains(L) &&
2195202878Srdivacky         DT->dominates(L->getHeader(), NestedLoop->getHeader()))) {
2196198090Srdivacky      SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
2197201360Srdivacky                                                  NestedAR->op_end());
2198193323Sed      Operands[0] = NestedAR->getStart();
2199195098Sed      // AddRecs require their operands be loop-invariant with respect to their
2200195098Sed      // loops. Don't perform this transformation if it would break this
2201195098Sed      // requirement.
2202195098Sed      bool AllInvariant = true;
2203195098Sed      for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2204218893Sdim        if (!isLoopInvariant(Operands[i], L)) {
2205195098Sed          AllInvariant = false;
2206195098Sed          break;
2207195098Sed        }
2208195098Sed      if (AllInvariant) {
2209221345Sdim        // Create a recurrence for the outer loop with the same step size.
2210221345Sdim        //
2211221345Sdim        // The outer recurrence keeps its NW flag but only keeps NUW/NSW if the
2212221345Sdim        // inner recurrence has the same property.
2213221345Sdim        SCEV::NoWrapFlags OuterFlags =
2214221345Sdim          maskFlags(Flags, SCEV::FlagNW | NestedAR->getNoWrapFlags());
2215221345Sdim
2216221345Sdim        NestedOperands[0] = getAddRecExpr(Operands, L, OuterFlags);
2217195098Sed        AllInvariant = true;
2218195098Sed        for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
2219218893Sdim          if (!isLoopInvariant(NestedOperands[i], NestedLoop)) {
2220195098Sed            AllInvariant = false;
2221195098Sed            break;
2222195098Sed          }
2223221345Sdim        if (AllInvariant) {
2224195098Sed          // Ok, both add recurrences are valid after the transformation.
2225221345Sdim          //
2226221345Sdim          // The inner recurrence keeps its NW flag but only keeps NUW/NSW if
2227221345Sdim          // the outer recurrence has the same property.
2228221345Sdim          SCEV::NoWrapFlags InnerFlags =
2229221345Sdim            maskFlags(NestedAR->getNoWrapFlags(), SCEV::FlagNW | Flags);
2230221345Sdim          return getAddRecExpr(NestedOperands, NestedLoop, InnerFlags);
2231221345Sdim        }
2232195098Sed      }
2233195098Sed      // Reset Operands to its original state.
2234195098Sed      Operands[0] = NestedAR;
2235193323Sed    }
2236193323Sed  }
2237193323Sed
2238202878Srdivacky  // Okay, it looks like we really DO need an addrec expr.  Check to see if we
2239202878Srdivacky  // already have one, otherwise create a new one.
2240195340Sed  FoldingSetNodeID ID;
2241195340Sed  ID.AddInteger(scAddRecExpr);
2242195340Sed  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
2243195340Sed    ID.AddPointer(Operands[i]);
2244195340Sed  ID.AddPointer(L);
2245195340Sed  void *IP = 0;
2246202878Srdivacky  SCEVAddRecExpr *S =
2247202878Srdivacky    static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
2248202878Srdivacky  if (!S) {
2249205407Srdivacky    const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
2250205407Srdivacky    std::uninitialized_copy(Operands.begin(), Operands.end(), O);
2251205407Srdivacky    S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
2252205407Srdivacky                                           O, Operands.size(), L);
2253202878Srdivacky    UniqueSCEVs.InsertNode(S, IP);
2254202878Srdivacky  }
2255221345Sdim  S->setNoWrapFlags(Flags);
2256195340Sed  return S;
2257193323Sed}
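// Illustrative sketch (not part of the original source): a recurrence whose
// step is zero degenerates to its start value via the {X,+,0} --> X rule
// above. SE, L and X are hypothetical placeholders.
//
//   SmallVector<const SCEV *, 2> Ops;
//   Ops.push_back(X);                                  // start
//   Ops.push_back(SE.getConstant(X->getType(), 0));    // step of zero
//   const SCEV *R = SE.getAddRecExpr(Ops, L, SCEV::FlagAnyWrap);
//   // R == X: the trailing zero operand is dropped before uniquing.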
2258193323Sed
2259195098Sedconst SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
2260195098Sed                                         const SCEV *RHS) {
2261198090Srdivacky  SmallVector<const SCEV *, 2> Ops;
2262193323Sed  Ops.push_back(LHS);
2263193323Sed  Ops.push_back(RHS);
2264193323Sed  return getSMaxExpr(Ops);
2265193323Sed}
2266193323Sed
2267198090Srdivackyconst SCEV *
2268198090SrdivackyScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2269193323Sed  assert(!Ops.empty() && "Cannot get empty smax!");
2270193323Sed  if (Ops.size() == 1) return Ops[0];
2271193323Sed#ifndef NDEBUG
2272212904Sdim  const Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2273193323Sed  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2274212904Sdim    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2275193323Sed           "SCEVSMaxExpr operand types don't match!");
2276193323Sed#endif
2277193323Sed
2278193323Sed  // Sort by complexity; this groups all similar expression types together.
2279193323Sed  GroupByComplexity(Ops, LI);
2280193323Sed
2281193323Sed  // If there are any constants, fold them together.
2282193323Sed  unsigned Idx = 0;
2283193323Sed  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2284193323Sed    ++Idx;
2285193323Sed    assert(Idx < Ops.size());
2286193323Sed    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2287193323Sed      // We found two constants, fold them together!
2288198090Srdivacky      ConstantInt *Fold = ConstantInt::get(getContext(),
2289193323Sed                              APIntOps::smax(LHSC->getValue()->getValue(),
2290193323Sed                                             RHSC->getValue()->getValue()));
2291193323Sed      Ops[0] = getConstant(Fold);
2292193323Sed      Ops.erase(Ops.begin()+1);  // Erase the folded element
2293193323Sed      if (Ops.size() == 1) return Ops[0];
2294193323Sed      LHSC = cast<SCEVConstant>(Ops[0]);
2295193323Sed    }
2296193323Sed
2297195098Sed    // If we are left with a constant minimum-int, strip it off.
2298193323Sed    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
2299193323Sed      Ops.erase(Ops.begin());
2300193323Sed      --Idx;
2301195098Sed    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
2302195098Sed      // If we have an smax with a constant maximum-int, it will always be
2303195098Sed      // maximum-int.
2304195098Sed      return Ops[0];
2305193323Sed    }
2306207618Srdivacky
2307207618Srdivacky    if (Ops.size() == 1) return Ops[0];
2308193323Sed  }
2309193323Sed
2310193323Sed  // Find the first SMax
2311193323Sed  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
2312193323Sed    ++Idx;
2313193323Sed
2314193323Sed  // Check to see if one of the operands is an SMax. If so, expand its operands
2315193323Sed  // onto our operand list, and recurse to simplify.
2316193323Sed  if (Idx < Ops.size()) {
2317193323Sed    bool DeletedSMax = false;
2318193323Sed    while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
2319193323Sed      Ops.erase(Ops.begin()+Idx);
2320210299Sed      Ops.append(SMax->op_begin(), SMax->op_end());
2321193323Sed      DeletedSMax = true;
2322193323Sed    }
2323193323Sed
2324193323Sed    if (DeletedSMax)
2325193323Sed      return getSMaxExpr(Ops);
2326193323Sed  }
2327193323Sed
2328193323Sed  // Okay, check to see if the same value occurs in the operand list twice.  If
2329193323Sed  // so, delete one.  Since we sorted the list, these values are required to
2330193323Sed  // be adjacent.
2331193323Sed  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2332207618Srdivacky    //  X smax Y smax Y  -->  X smax Y
2333207618Srdivacky    //  X smax Y         -->  X, if X is always greater than or equal to Y
2334207618Srdivacky    if (Ops[i] == Ops[i+1] ||
2335207618Srdivacky        isKnownPredicate(ICmpInst::ICMP_SGE, Ops[i], Ops[i+1])) {
2336207618Srdivacky      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
2337207618Srdivacky      --i; --e;
2338207618Srdivacky    } else if (isKnownPredicate(ICmpInst::ICMP_SLE, Ops[i], Ops[i+1])) {
2339193323Sed      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2340193323Sed      --i; --e;
2341193323Sed    }
2342193323Sed
2343193323Sed  if (Ops.size() == 1) return Ops[0];
2344193323Sed
2345193323Sed  assert(!Ops.empty() && "Reduced smax down to nothing!");
2346193323Sed
2347193323Sed  // Okay, it looks like we really DO need an smax expr.  Check to see if we
2348193323Sed  // already have one, otherwise create a new one.
2349195340Sed  FoldingSetNodeID ID;
2350195340Sed  ID.AddInteger(scSMaxExpr);
2351195340Sed  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2352195340Sed    ID.AddPointer(Ops[i]);
2353195340Sed  void *IP = 0;
2354195340Sed  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2355205407Srdivacky  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2356205407Srdivacky  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2357205407Srdivacky  SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
2358205407Srdivacky                                             O, Ops.size());
2359195340Sed  UniqueSCEVs.InsertNode(S, IP);
2360195340Sed  return S;
2361193323Sed}
2362193323Sed
2363195098Sedconst SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
2364195098Sed                                         const SCEV *RHS) {
2365198090Srdivacky  SmallVector<const SCEV *, 2> Ops;
2366193323Sed  Ops.push_back(LHS);
2367193323Sed  Ops.push_back(RHS);
2368193323Sed  return getUMaxExpr(Ops);
2369193323Sed}
2370193323Sed
2371198090Srdivackyconst SCEV *
2372198090SrdivackyScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
2373193323Sed  assert(!Ops.empty() && "Cannot get empty umax!");
2374193323Sed  if (Ops.size() == 1) return Ops[0];
2375193323Sed#ifndef NDEBUG
2376212904Sdim  const Type *ETy = getEffectiveSCEVType(Ops[0]->getType());
2377193323Sed  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
2378212904Sdim    assert(getEffectiveSCEVType(Ops[i]->getType()) == ETy &&
2379193323Sed           "SCEVUMaxExpr operand types don't match!");
2380193323Sed#endif
2381193323Sed
2382193323Sed  // Sort by complexity; this groups all similar expression types together.
2383193323Sed  GroupByComplexity(Ops, LI);
2384193323Sed
2385193323Sed  // If there are any constants, fold them together.
2386193323Sed  unsigned Idx = 0;
2387193323Sed  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
2388193323Sed    ++Idx;
2389193323Sed    assert(Idx < Ops.size());
2390193323Sed    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
2391193323Sed      // We found two constants, fold them together!
2392198090Srdivacky      ConstantInt *Fold = ConstantInt::get(getContext(),
2393193323Sed                              APIntOps::umax(LHSC->getValue()->getValue(),
2394193323Sed                                             RHSC->getValue()->getValue()));
2395193323Sed      Ops[0] = getConstant(Fold);
2396193323Sed      Ops.erase(Ops.begin()+1);  // Erase the folded element
2397193323Sed      if (Ops.size() == 1) return Ops[0];
2398193323Sed      LHSC = cast<SCEVConstant>(Ops[0]);
2399193323Sed    }
2400193323Sed
2401195098Sed    // If we are left with a constant minimum-int, strip it off.
2402193323Sed    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
2403193323Sed      Ops.erase(Ops.begin());
2404193323Sed      --Idx;
2405195098Sed    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
2406195098Sed      // If we have an umax with a constant maximum-int, it will always be
2407195098Sed      // maximum-int.
2408195098Sed      return Ops[0];
2409193323Sed    }
2410207618Srdivacky
2411207618Srdivacky    if (Ops.size() == 1) return Ops[0];
2412193323Sed  }
2413193323Sed
2414193323Sed  // Find the first UMax
2415193323Sed  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
2416193323Sed    ++Idx;
2417193323Sed
2418193323Sed  // Check to see if one of the operands is a UMax. If so, expand its operands
2419193323Sed  // onto our operand list, and recurse to simplify.
2420193323Sed  if (Idx < Ops.size()) {
2421193323Sed    bool DeletedUMax = false;
2422193323Sed    while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
2423193323Sed      Ops.erase(Ops.begin()+Idx);
2424210299Sed      Ops.append(UMax->op_begin(), UMax->op_end());
2425193323Sed      DeletedUMax = true;
2426193323Sed    }
2427193323Sed
2428193323Sed    if (DeletedUMax)
2429193323Sed      return getUMaxExpr(Ops);
2430193323Sed  }
2431193323Sed
2432193323Sed  // Okay, check to see if the same value occurs in the operand list twice.  If
2433193323Sed  // so, delete one.  Since we sorted the list, these values are required to
2434193323Sed  // be adjacent.
2435193323Sed  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
2436207618Srdivacky    //  X umax Y umax Y  -->  X umax Y
2437207618Srdivacky    //  X umax Y         -->  X, if X is always greater than or equal to Y
2438207618Srdivacky    if (Ops[i] == Ops[i+1] ||
2439207618Srdivacky        isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
2440207618Srdivacky      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
2441207618Srdivacky      --i; --e;
2442207618Srdivacky    } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
2443193323Sed      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
2444193323Sed      --i; --e;
2445193323Sed    }
2446193323Sed
2447193323Sed  if (Ops.size() == 1) return Ops[0];
2448193323Sed
2449193323Sed  assert(!Ops.empty() && "Reduced umax down to nothing!");
2450193323Sed
2451193323Sed  // Okay, it looks like we really DO need a umax expr.  Check to see if we
2452193323Sed  // already have one, otherwise create a new one.
2453195340Sed  FoldingSetNodeID ID;
2454195340Sed  ID.AddInteger(scUMaxExpr);
2455195340Sed  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
2456195340Sed    ID.AddPointer(Ops[i]);
2457195340Sed  void *IP = 0;
2458195340Sed  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
2459205407Srdivacky  const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
2460205407Srdivacky  std::uninitialized_copy(Ops.begin(), Ops.end(), O);
2461205407Srdivacky  SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
2462205407Srdivacky                                             O, Ops.size());
2463195340Sed  UniqueSCEVs.InsertNode(S, IP);
2464195340Sed  return S;
2465193323Sed}
2466193323Sed
2467195098Sedconst SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
2468195098Sed                                         const SCEV *RHS) {
2469194612Sed  // ~smax(~x, ~y) == smin(x, y).
2470194612Sed  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2471194612Sed}
2472194612Sed
2473195098Sedconst SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
2474195098Sed                                         const SCEV *RHS) {
2475194612Sed  // ~umax(~x, ~y) == umin(x, y)
2476194612Sed  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
2477194612Sed}
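// Illustrative sketch (not part of the original source): the min operations
// above rest on the identity ~smax(~x, ~y) == smin(x, y), where ~a is -1 - a.
// For example, with x = 2 and y = 5: ~2 = -3, ~5 = -6, smax(-3, -6) = -3,
// and ~(-3) = 2 == smin(2, 5). The unsigned case works the same way with
// umax and umin.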
2478194612Sed
2479203954Srdivackyconst SCEV *ScalarEvolution::getSizeOfExpr(const Type *AllocTy) {
2480207618Srdivacky  // If we have TargetData, we can bypass creating a target-independent
2481207618Srdivacky  // constant expression and then folding it back into a ConstantInt.
2482207618Srdivacky  // This is just a compile-time optimization.
2483207618Srdivacky  if (TD)
2484207618Srdivacky    return getConstant(TD->getIntPtrType(getContext()),
2485207618Srdivacky                       TD->getTypeAllocSize(AllocTy));
2486207618Srdivacky
2487203954Srdivacky  Constant *C = ConstantExpr::getSizeOf(AllocTy);
2488203954Srdivacky  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2489210299Sed    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD))
2490210299Sed      C = Folded;
2491203954Srdivacky  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2492203954Srdivacky  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2493203954Srdivacky}
2494198090Srdivacky
2495203954Srdivackyconst SCEV *ScalarEvolution::getAlignOfExpr(const Type *AllocTy) {
2496203954Srdivacky  Constant *C = ConstantExpr::getAlignOf(AllocTy);
2497203954Srdivacky  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2498210299Sed    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD))
2499210299Sed      C = Folded;
2500203954Srdivacky  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(AllocTy));
2501203954Srdivacky  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2502203954Srdivacky}
2503198090Srdivacky
2504203954Srdivackyconst SCEV *ScalarEvolution::getOffsetOfExpr(const StructType *STy,
2505203954Srdivacky                                             unsigned FieldNo) {
2506207618Srdivacky  // If we have TargetData, we can bypass creating a target-independent
2507207618Srdivacky  // constant expression and then folding it back into a ConstantInt.
2508207618Srdivacky  // This is just a compile-time optimization.
2509207618Srdivacky  if (TD)
2510207618Srdivacky    return getConstant(TD->getIntPtrType(getContext()),
2511207618Srdivacky                       TD->getStructLayout(STy)->getElementOffset(FieldNo));
2512207618Srdivacky
2513203954Srdivacky  Constant *C = ConstantExpr::getOffsetOf(STy, FieldNo);
2514203954Srdivacky  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2515210299Sed    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD))
2516210299Sed      C = Folded;
2517198090Srdivacky  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(STy));
2518203954Srdivacky  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2519198090Srdivacky}
2520198090Srdivacky
2521203954Srdivackyconst SCEV *ScalarEvolution::getOffsetOfExpr(const Type *CTy,
2522203954Srdivacky                                             Constant *FieldNo) {
2523203954Srdivacky  Constant *C = ConstantExpr::getOffsetOf(CTy, FieldNo);
2524203954Srdivacky  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C))
2525210299Sed    if (Constant *Folded = ConstantFoldConstantExpression(CE, TD))
2526210299Sed      C = Folded;
2527203954Srdivacky  const Type *Ty = getEffectiveSCEVType(PointerType::getUnqual(CTy));
2528203954Srdivacky  return getTruncateOrZeroExtend(getSCEV(C), Ty);
2529198090Srdivacky}
2530198090Srdivacky
2531198090Srdivackyconst SCEV *ScalarEvolution::getUnknown(Value *V) {
2532195098Sed  // Don't attempt to do anything other than create a SCEVUnknown object
2533195098Sed  // here.  createSCEV only calls getUnknown after checking for all other
2534195098Sed  // interesting possibilities, and any other code that calls getUnknown
2535195098Sed  // is doing so in order to hide a value from SCEV canonicalization.
2536195098Sed
2537195340Sed  FoldingSetNodeID ID;
2538195340Sed  ID.AddInteger(scUnknown);
2539195340Sed  ID.AddPointer(V);
2540195340Sed  void *IP = 0;
2541212904Sdim  if (SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) {
2542212904Sdim    assert(cast<SCEVUnknown>(S)->getValue() == V &&
2543212904Sdim           "Stale SCEVUnknown in uniquing map!");
2544212904Sdim    return S;
2545212904Sdim  }
2546212904Sdim  SCEV *S = new (SCEVAllocator) SCEVUnknown(ID.Intern(SCEVAllocator), V, this,
2547212904Sdim                                            FirstUnknown);
2548212904Sdim  FirstUnknown = cast<SCEVUnknown>(S);
2549195340Sed  UniqueSCEVs.InsertNode(S, IP);
2550195340Sed  return S;
2551193323Sed}
2552193323Sed
2553193323Sed//===----------------------------------------------------------------------===//
2554193323Sed//            Basic SCEV Analysis and PHI Idiom Recognition Code
2555193323Sed//
2556193323Sed
2557193323Sed/// isSCEVable - Test if values of the given type are analyzable within
2558193323Sed/// the SCEV framework. This includes integer types and pointer types;
2559193323Sed/// pointer widths are taken from TargetData when it is available and are
2560193323Sed/// conservatively assumed to be 64 bits otherwise.
2561193323Sedbool ScalarEvolution::isSCEVable(const Type *Ty) const {
2562198090Srdivacky  // Integers and pointers are always SCEVable.
2563204642Srdivacky  return Ty->isIntegerTy() || Ty->isPointerTy();
2564193323Sed}
2565193323Sed
2566193323Sed/// getTypeSizeInBits - Return the size in bits of the specified type,
2567193323Sed/// for which isSCEVable must return true.
2568193323Seduint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
2569193323Sed  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2570193323Sed
2571193323Sed  // If we have a TargetData, use it!
2572193323Sed  if (TD)
2573193323Sed    return TD->getTypeSizeInBits(Ty);
2574193323Sed
2575198090Srdivacky  // Integer types have fixed sizes.
2576203954Srdivacky  if (Ty->isIntegerTy())
2577198090Srdivacky    return Ty->getPrimitiveSizeInBits();
2578198090Srdivacky
2579198090Srdivacky  // The only other supported type is pointer. Without TargetData, conservatively
2580198090Srdivacky  // assume pointers are 64-bit.
2581204642Srdivacky  assert(Ty->isPointerTy() && "isSCEVable permitted a non-SCEVable type!");
2582198090Srdivacky  return 64;
2583193323Sed}
2584193323Sed
2585193323Sed/// getEffectiveSCEVType - Return a type with the same bitwidth as
2586193323Sed/// the given type and which represents how SCEV will treat the given
2587193323Sed/// type, for which isSCEVable must return true. For pointer types,
2588193323Sed/// this is the pointer-sized integer type.
2589193323Sedconst Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const {
2590193323Sed  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2591193323Sed
2592203954Srdivacky  if (Ty->isIntegerTy())
2593193323Sed    return Ty;
2594193323Sed
2595198090Srdivacky  // The only other supported type is pointer.
2596204642Srdivacky  assert(Ty->isPointerTy() && "Unexpected non-pointer non-integer type!");
2597198090Srdivacky  if (TD) return TD->getIntPtrType(getContext());
2598198090Srdivacky
2599198090Srdivacky  // Without TargetData, conservatively assume pointers are 64-bit.
2600198090Srdivacky  return Type::getInt64Ty(getContext());
2601193323Sed}
2602193323Sed
2603198090Srdivackyconst SCEV *ScalarEvolution::getCouldNotCompute() {
2604195340Sed  return &CouldNotCompute;
2605193323Sed}
2606193323Sed
2607193323Sed/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
2608193323Sed/// expression and create a new one.
2609198090Srdivackyconst SCEV *ScalarEvolution::getSCEV(Value *V) {
2610193323Sed  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
2611193323Sed
2612212904Sdim  ValueExprMapType::const_iterator I = ValueExprMap.find(V);
2613212904Sdim  if (I != ValueExprMap.end()) return I->second;
2614198090Srdivacky  const SCEV *S = createSCEV(V);
2615212904Sdim
2616212904Sdim  // The process of creating a SCEV for V may have caused other SCEVs
2617212904Sdim  // to have been created, so it's necessary to insert the new entry
2618212904Sdim  // from scratch, rather than trying to remember the insert position
2619212904Sdim  // above.
2620212904Sdim  ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S));
2621193323Sed  return S;
2622193323Sed}
2623193323Sed
2624193323Sed/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
2625193323Sed///
2626198090Srdivackyconst SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V) {
2627193323Sed  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2628198090Srdivacky    return getConstant(
2629198090Srdivacky               cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
2630193323Sed
2631193323Sed  const Type *Ty = V->getType();
2632193323Sed  Ty = getEffectiveSCEVType(Ty);
2633198090Srdivacky  return getMulExpr(V,
2634198090Srdivacky                  getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))));
2635193323Sed}
2636193323Sed
2637193323Sed/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
2638198090Srdivackyconst SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
2639193323Sed  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2640198090Srdivacky    return getConstant(
2641198090Srdivacky                cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
2642193323Sed
2643193323Sed  const Type *Ty = V->getType();
2644193323Sed  Ty = getEffectiveSCEVType(Ty);
2645198090Srdivacky  const SCEV *AllOnes =
2646198090Srdivacky                   getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
2647193323Sed  return getMinusSCEV(AllOnes, V);
2648193323Sed}
2649193323Sed
2650221345Sdim/// getMinusSCEV - Return LHS-RHS.  Minus is represented in SCEV as A+B*-1.
2651218893Sdimconst SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
2652221345Sdim                                          SCEV::NoWrapFlags Flags) {
2653221345Sdim  assert(!maskFlags(Flags, SCEV::FlagNUW) && "subtraction does not have NUW");
2654221345Sdim
2655212904Sdim  // Fast path: X - X --> 0.
2656212904Sdim  if (LHS == RHS)
2657212904Sdim    return getConstant(LHS->getType(), 0);
2658212904Sdim
2659193323Sed  // X - Y --> X + -Y
2660221345Sdim  return getAddExpr(LHS, getNegativeSCEV(RHS), Flags);
2661193323Sed}
2662193323Sed
2663193323Sed/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
2664193323Sed/// input value to the specified type.  If the type must be extended, it is zero
2665193323Sed/// extended.
2666198090Srdivackyconst SCEV *
2667218893SdimScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, const Type *Ty) {
2668193323Sed  const Type *SrcTy = V->getType();
2669204642Srdivacky  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2670204642Srdivacky         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2671193323Sed         "Cannot truncate or zero extend with non-integer arguments!");
2672193323Sed  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2673193323Sed    return V;  // No conversion
2674193323Sed  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2675193323Sed    return getTruncateExpr(V, Ty);
2676193323Sed  return getZeroExtendExpr(V, Ty);
2677193323Sed}
2678193323Sed
2679193323Sed/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
2680193323Sed/// input value to the specified type.  If the type must be extended, it is sign
2681193323Sed/// extended.
2682198090Srdivackyconst SCEV *
2683198090SrdivackyScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
2684193323Sed                                         const Type *Ty) {
2685193323Sed  const Type *SrcTy = V->getType();
2686204642Srdivacky  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2687204642Srdivacky         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2688193323Sed         "Cannot truncate or zero extend with non-integer arguments!");
2689193323Sed  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2690193323Sed    return V;  // No conversion
2691193323Sed  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2692193323Sed    return getTruncateExpr(V, Ty);
2693193323Sed  return getSignExtendExpr(V, Ty);
2694193323Sed}
2695193323Sed
2696193323Sed/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
2697193323Sed/// input value to the specified type.  If the type must be extended, it is zero
2698193323Sed/// extended.  The conversion must not be narrowing.
2699198090Srdivackyconst SCEV *
2700198090SrdivackyScalarEvolution::getNoopOrZeroExtend(const SCEV *V, const Type *Ty) {
2701193323Sed  const Type *SrcTy = V->getType();
2702204642Srdivacky  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2703204642Srdivacky         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2704193323Sed         "Cannot noop or zero extend with non-integer arguments!");
2705193323Sed  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2706193323Sed         "getNoopOrZeroExtend cannot truncate!");
2707193323Sed  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2708193323Sed    return V;  // No conversion
2709193323Sed  return getZeroExtendExpr(V, Ty);
2710193323Sed}
2711193323Sed
2712193323Sed/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
2713193323Sed/// input value to the specified type.  If the type must be extended, it is sign
2714193323Sed/// extended.  The conversion must not be narrowing.
2715198090Srdivackyconst SCEV *
2716198090SrdivackyScalarEvolution::getNoopOrSignExtend(const SCEV *V, const Type *Ty) {
2717193323Sed  const Type *SrcTy = V->getType();
2718204642Srdivacky  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2719204642Srdivacky         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2720193323Sed         "Cannot noop or sign extend with non-integer arguments!");
2721193323Sed  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2722193323Sed         "getNoopOrSignExtend cannot truncate!");
2723193323Sed  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2724193323Sed    return V;  // No conversion
2725193323Sed  return getSignExtendExpr(V, Ty);
2726193323Sed}
2727193323Sed
2728194178Sed/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
2729194178Sed/// the input value to the specified type. If the type must be extended,
2730194178Sed/// it is extended with unspecified bits. The conversion must not be
2731194178Sed/// narrowing.
2732198090Srdivackyconst SCEV *
2733198090SrdivackyScalarEvolution::getNoopOrAnyExtend(const SCEV *V, const Type *Ty) {
2734194178Sed  const Type *SrcTy = V->getType();
2735204642Srdivacky  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2736204642Srdivacky         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2737194178Sed         "Cannot noop or any extend with non-integer arguments!");
2738194178Sed  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2739194178Sed         "getNoopOrAnyExtend cannot truncate!");
2740194178Sed  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2741194178Sed    return V;  // No conversion
2742194178Sed  return getAnyExtendExpr(V, Ty);
2743194178Sed}
2744194178Sed
2745193323Sed/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
2746193323Sed/// input value to the specified type.  The conversion must not be widening.
2747198090Srdivackyconst SCEV *
2748198090SrdivackyScalarEvolution::getTruncateOrNoop(const SCEV *V, const Type *Ty) {
2749193323Sed  const Type *SrcTy = V->getType();
2750204642Srdivacky  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
2751204642Srdivacky         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
2752193323Sed         "Cannot truncate or noop with non-integer arguments!");
2753193323Sed  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
2754193323Sed         "getTruncateOrNoop cannot extend!");
2755193323Sed  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2756193323Sed    return V;  // No conversion
2757193323Sed  return getTruncateExpr(V, Ty);
2758193323Sed}
2759193323Sed
2760194612Sed/// getUMaxFromMismatchedTypes - Promote the operands to the wider of
2761194612Sed/// the types using zero-extension, and then perform a umax operation
2762194612Sed/// with them.
2763195098Sedconst SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
2764195098Sed                                                        const SCEV *RHS) {
2765198090Srdivacky  const SCEV *PromotedLHS = LHS;
2766198090Srdivacky  const SCEV *PromotedRHS = RHS;
2767194612Sed
2768194612Sed  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2769194612Sed    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2770194612Sed  else
2771194612Sed    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2772194612Sed
2773194612Sed  return getUMaxExpr(PromotedLHS, PromotedRHS);
2774194612Sed}
2775194612Sed
2776194710Sed/// getUMinFromMismatchedTypes - Promote the operands to the wider of
2777194710Sed/// the types using zero-extension, and then perform a umin operation
2778194710Sed/// with them.
2779195098Sedconst SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
2780195098Sed                                                        const SCEV *RHS) {
2781198090Srdivacky  const SCEV *PromotedLHS = LHS;
2782198090Srdivacky  const SCEV *PromotedRHS = RHS;
2783194710Sed
2784194710Sed  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2785194710Sed    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2786194710Sed  else
2787194710Sed    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2788194710Sed
2789194710Sed  return getUMinExpr(PromotedLHS, PromotedRHS);
2790194710Sed}
2791194710Sed
2792221345Sdim/// getPointerBase - Transitively follow the chain of pointer-type operands
2793221345Sdim/// until reaching a SCEV that does not have a single pointer operand. This
2794221345Sdim/// returns a SCEVUnknown pointer for well-formed pointer-type expressions,
2795221345Sdim/// but corner cases do exist.
2796221345Sdimconst SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
2797221345Sdim  // A pointer operand may evaluate to a nonpointer expression, such as null.
2798221345Sdim  if (!V->getType()->isPointerTy())
2799221345Sdim    return V;
2800221345Sdim
2801221345Sdim  if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
2802221345Sdim    return getPointerBase(Cast->getOperand());
2803221345Sdim  } else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
2805221345Sdim    const SCEV *PtrOp = 0;
2806221345Sdim    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
2807221345Sdim         I != E; ++I) {
2808221345Sdim      if ((*I)->getType()->isPointerTy()) {
2809221345Sdim        // Cannot find the base of an expression with multiple pointer operands.
2810221345Sdim        if (PtrOp)
2811221345Sdim          return V;
2812221345Sdim        PtrOp = *I;
2813221345Sdim      }
2814221345Sdim    }
2815221345Sdim    if (!PtrOp)
2816221345Sdim      return V;
2817221345Sdim    return getPointerBase(PtrOp);
2818221345Sdim  }
2819221345Sdim  return V;
2820221345Sdim}
2821221345Sdim
2822198090Srdivacky/// PushDefUseChildren - Push users of the given Instruction
2823198090Srdivacky/// onto the given Worklist.
2824198090Srdivackystatic void
2825198090SrdivackyPushDefUseChildren(Instruction *I,
2826198090Srdivacky                   SmallVectorImpl<Instruction *> &Worklist) {
2827198090Srdivacky  // Push the def-use children onto the Worklist stack.
2828198090Srdivacky  for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2829198090Srdivacky       UI != UE; ++UI)
2830212904Sdim    Worklist.push_back(cast<Instruction>(*UI));
2831198090Srdivacky}
2832198090Srdivacky
2833198090Srdivacky/// ForgetSymbolicName - This looks up computed SCEV values for all
2834198090Srdivacky/// instructions that depend on the given instruction and removes them from
2835212904Sdim/// ValueExprMap if they reference SymName. This is used during PHI
2836198090Srdivacky/// resolution.
2837195098Sedvoid
2838204642SrdivackyScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) {
2839198090Srdivacky  SmallVector<Instruction *, 16> Worklist;
2840204642Srdivacky  PushDefUseChildren(PN, Worklist);
2841193323Sed
2842198090Srdivacky  SmallPtrSet<Instruction *, 8> Visited;
2843204642Srdivacky  Visited.insert(PN);
2844198090Srdivacky  while (!Worklist.empty()) {
2845198090Srdivacky    Instruction *I = Worklist.pop_back_val();
2846198090Srdivacky    if (!Visited.insert(I)) continue;
2847193323Sed
2848212904Sdim    ValueExprMapType::iterator It =
2849212904Sdim      ValueExprMap.find(static_cast<Value *>(I));
2850212904Sdim    if (It != ValueExprMap.end()) {
2851218893Sdim      const SCEV *Old = It->second;
2852218893Sdim
2853198090Srdivacky      // Short-circuit the def-use traversal if the symbolic name
2854198090Srdivacky      // ceases to appear in expressions.
2855218893Sdim      if (Old != SymName && !hasOperand(Old, SymName))
2856198090Srdivacky        continue;
2857193323Sed
2858198090Srdivacky      // SCEVUnknown for a PHI either means that it has an unrecognized
2859204642Srdivacky      // structure, it's a PHI that's in the progress of being computed
2860204642Srdivacky      // by createNodeForPHI, or it's a single-value PHI. In the first case,
2861204642Srdivacky      // additional loop trip count information isn't going to change anything.
2862204642Srdivacky      // In the second case, createNodeForPHI will perform the necessary
2863204642Srdivacky      // updates on its own when it gets to that point. In the third, we do
2864204642Srdivacky      // want to forget the SCEVUnknown.
2865204642Srdivacky      if (!isa<PHINode>(I) ||
2866218893Sdim          !isa<SCEVUnknown>(Old) ||
2867218893Sdim          (I != PN && Old == SymName)) {
2868218893Sdim        forgetMemoizedResults(Old);
2869212904Sdim        ValueExprMap.erase(It);
2870198090Srdivacky      }
2871198090Srdivacky    }
2872198090Srdivacky
2873198090Srdivacky    PushDefUseChildren(I, Worklist);
2874198090Srdivacky  }
2875193323Sed}
2876193323Sed
2877193323Sed/// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
2878193323Sed/// a loop header, making it a potential recurrence, or it doesn't.
2879193323Sed///
2880198090Srdivackyconst SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) {
2881207618Srdivacky  if (const Loop *L = LI->getLoopFor(PN->getParent()))
2882207618Srdivacky    if (L->getHeader() == PN->getParent()) {
2883207618Srdivacky      // The loop may have multiple entrances or multiple exits; we can analyze
2884207618Srdivacky      // this phi as an addrec if it has a unique entry value and a unique
2885207618Srdivacky      // backedge value.
2886207618Srdivacky      Value *BEValueV = 0, *StartValueV = 0;
2887207618Srdivacky      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
2888207618Srdivacky        Value *V = PN->getIncomingValue(i);
2889207618Srdivacky        if (L->contains(PN->getIncomingBlock(i))) {
2890207618Srdivacky          if (!BEValueV) {
2891207618Srdivacky            BEValueV = V;
2892207618Srdivacky          } else if (BEValueV != V) {
2893207618Srdivacky            BEValueV = 0;
2894207618Srdivacky            break;
2895207618Srdivacky          }
2896207618Srdivacky        } else if (!StartValueV) {
2897207618Srdivacky          StartValueV = V;
2898207618Srdivacky        } else if (StartValueV != V) {
2899207618Srdivacky          StartValueV = 0;
2900207618Srdivacky          break;
2901207618Srdivacky        }
2902207618Srdivacky      }
2903207618Srdivacky      if (BEValueV && StartValueV) {
2904193323Sed        // While we are analyzing this PHI node, handle its value symbolically.
2905198090Srdivacky        const SCEV *SymbolicName = getUnknown(PN);
2906212904Sdim        assert(ValueExprMap.find(PN) == ValueExprMap.end() &&
2907193323Sed               "PHI node already processed?");
2908212904Sdim        ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
2909193323Sed
2910193323Sed        // Using this symbolic name for the PHI, analyze the value coming around
2911193323Sed        // the back-edge.
2912198090Srdivacky        const SCEV *BEValue = getSCEV(BEValueV);
2913193323Sed
2914193323Sed        // NOTE: If BEValue is loop invariant, we know that the PHI node just
2915193323Sed        // has a special value for the first iteration of the loop.
2916193323Sed
2917193323Sed        // If the value coming around the backedge is an add with the symbolic
2918193323Sed        // value we just inserted, then we found a simple induction variable!
2919193323Sed        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
2920193323Sed          // If there is a single occurrence of the symbolic value, replace it
2921193323Sed          // with a recurrence.
2922193323Sed          unsigned FoundIndex = Add->getNumOperands();
2923193323Sed          for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2924193323Sed            if (Add->getOperand(i) == SymbolicName)
2925193323Sed              if (FoundIndex == e) {
2926193323Sed                FoundIndex = i;
2927193323Sed                break;
2928193323Sed              }
2929193323Sed
2930193323Sed          if (FoundIndex != Add->getNumOperands()) {
2931193323Sed            // Create an add with everything but the specified operand.
2932198090Srdivacky            SmallVector<const SCEV *, 8> Ops;
2933193323Sed            for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2934193323Sed              if (i != FoundIndex)
2935193323Sed                Ops.push_back(Add->getOperand(i));
2936198090Srdivacky            const SCEV *Accum = getAddExpr(Ops);
2937193323Sed
2938193323Sed            // This is not a valid addrec if the step amount is varying each
2939193323Sed            // loop iteration, but is not itself an addrec in this loop.
2940218893Sdim            if (isLoopInvariant(Accum, L) ||
2941193323Sed                (isa<SCEVAddRecExpr>(Accum) &&
2942193323Sed                 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
2943221345Sdim              SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
2944202878Srdivacky
2945202878Srdivacky              // If the increment doesn't overflow, then neither the addrec nor
2946202878Srdivacky              // the post-increment will overflow.
2947202878Srdivacky              if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
2948202878Srdivacky                if (OBO->hasNoUnsignedWrap())
2949221345Sdim                  Flags = setFlags(Flags, SCEV::FlagNUW);
2950202878Srdivacky                if (OBO->hasNoSignedWrap())
2951221345Sdim                  Flags = setFlags(Flags, SCEV::FlagNSW);
2952221345Sdim              } else if (const GEPOperator *GEP =
2953221345Sdim                         dyn_cast<GEPOperator>(BEValueV)) {
2954221345Sdim                // If the increment is an inbounds GEP, then we know the address
2955221345Sdim                // space cannot be wrapped around. We cannot make any guarantee
2956221345Sdim                // about signed or unsigned overflow because pointers are
2957221345Sdim                // unsigned but we may have a negative index from the base
2958221345Sdim                // pointer.
2959221345Sdim                if (GEP->isInBounds())
2960221345Sdim                  Flags = setFlags(Flags, SCEV::FlagNW);
2961202878Srdivacky              }
2962202878Srdivacky
2963207618Srdivacky              const SCEV *StartVal = getSCEV(StartValueV);
2964221345Sdim              const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags);
2965193323Sed
2966202878Srdivacky              // Since the no-wrap flags are on the increment, they apply to the
2967202878Srdivacky              // post-incremented value as well.
2968218893Sdim              if (isLoopInvariant(Accum, L))
2969202878Srdivacky                (void)getAddRecExpr(getAddExpr(StartVal, Accum),
2970221345Sdim                                    Accum, L, Flags);
2971198090Srdivacky
2972193323Sed              // Okay, for the entire analysis of this edge we assumed the PHI
2973198090Srdivacky              // to be symbolic.  We now need to go back and purge all of the
2974198090Srdivacky              // entries for the scalars that use the symbolic expression.
2975198090Srdivacky              ForgetSymbolicName(PN, SymbolicName);
2976212904Sdim              ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
2977193323Sed              return PHISCEV;
2978193323Sed            }
2979193323Sed          }
2980193323Sed        } else if (const SCEVAddRecExpr *AddRec =
2981193323Sed                     dyn_cast<SCEVAddRecExpr>(BEValue)) {
2982193323Sed          // Otherwise, this could be a loop like this:
2983193323Sed          //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
2984193323Sed          // In this case, j = {1,+,1}  and BEValue is j.
2985193323Sed          // Because the other in-value of i (0) fits the evolution of BEValue
2986193323Sed          // i really is an addrec evolution.
2987193323Sed          if (AddRec->getLoop() == L && AddRec->isAffine()) {
2988207618Srdivacky            const SCEV *StartVal = getSCEV(StartValueV);
2989193323Sed
2990193323Sed            // If StartVal = j.start - j.stride, we can use StartVal as the
2991193323Sed            // initial step of the addrec evolution.
2992193323Sed            if (StartVal == getMinusSCEV(AddRec->getOperand(0),
2993207618Srdivacky                                         AddRec->getOperand(1))) {
2994221345Sdim              // FIXME: For constant StartVal, we should be able to infer
2995221345Sdim              // no-wrap flags.
2996198090Srdivacky              const SCEV *PHISCEV =
2997221345Sdim                getAddRecExpr(StartVal, AddRec->getOperand(1), L,
2998221345Sdim                              SCEV::FlagAnyWrap);
2999193323Sed
3000193323Sed              // Okay, for the entire analysis of this edge we assumed the PHI
3001198090Srdivacky              // to be symbolic.  We now need to go back and purge all of the
3002198090Srdivacky              // entries for the scalars that use the symbolic expression.
3003198090Srdivacky              ForgetSymbolicName(PN, SymbolicName);
3004212904Sdim              ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV;
3005193323Sed              return PHISCEV;
3006193323Sed            }
3007193323Sed          }
3008193323Sed        }
3009193323Sed      }
3010207618Srdivacky    }
3011193323Sed
3012204642Srdivacky  // If the PHI has a single incoming value, follow that value, unless the
3013204642Srdivacky  // PHI's incoming blocks are in a different loop, in which case doing so
3014204642Srdivacky  // risks breaking LCSSA form. Instcombine would normally zap these, but
3015204642Srdivacky  // it doesn't have DominatorTree information, so it may miss cases.
3016218893Sdim  if (Value *V = SimplifyInstruction(PN, TD, DT))
3017218893Sdim    if (LI->replacementPreservesLCSSAForm(PN, V))
3018204642Srdivacky      return getSCEV(V);
3019198090Srdivacky
3020193323Sed  // If it's not a loop phi, we can't handle it yet.
3021193323Sed  return getUnknown(PN);
3022193323Sed}
3023193323Sed
3024193323Sed/// createNodeForGEP - Expand GEP instructions into add and multiply
3025193323Sed/// operations. This allows them to be analyzed by regular SCEV code.
3026193323Sed///
3027201360Srdivackyconst SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {
3028193323Sed
3029210299Sed  // Don't blindly transfer the inbounds flag from the GEP instruction to the
3030210299Sed  // Add expression, because the Instruction may be guarded by control flow
3031210299Sed  // and the no-overflow bits may not be valid for the expression in any
3032210299Sed  // context.
3033218893Sdim  bool isInBounds = GEP->isInBounds();
3034210299Sed
3035198090Srdivacky  const Type *IntPtrTy = getEffectiveSCEVType(GEP->getType());
3036193323Sed  Value *Base = GEP->getOperand(0);
3037193323Sed  // Don't attempt to analyze GEPs over unsized objects.
3038193323Sed  if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
3039193323Sed    return getUnknown(GEP);
3040207618Srdivacky  const SCEV *TotalOffset = getConstant(IntPtrTy, 0);
3041193323Sed  gep_type_iterator GTI = gep_type_begin(GEP);
3042212904Sdim  for (GetElementPtrInst::op_iterator I = llvm::next(GEP->op_begin()),
3043193323Sed                                      E = GEP->op_end();
3044193323Sed       I != E; ++I) {
3045193323Sed    Value *Index = *I;
3046193323Sed    // Compute the (potentially symbolic) offset in bytes for this index.
3047193323Sed    if (const StructType *STy = dyn_cast<StructType>(*GTI++)) {
3048193323Sed      // For a struct, add the member offset.
3049193323Sed      unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
3050210299Sed      const SCEV *FieldOffset = getOffsetOfExpr(STy, FieldNo);
3051210299Sed
3052210299Sed      // Add the field offset to the running total offset.
3053210299Sed      TotalOffset = getAddExpr(TotalOffset, FieldOffset);
3054193323Sed    } else {
3055193323Sed      // For an array, add the element offset, explicitly scaled.
3056210299Sed      const SCEV *ElementSize = getSizeOfExpr(*GTI);
3057210299Sed      const SCEV *IndexS = getSCEV(Index);
3058204642Srdivacky      // Getelementptr indices are signed.
3059210299Sed      IndexS = getTruncateOrSignExtend(IndexS, IntPtrTy);
3060210299Sed
3061210299Sed      // Multiply the index by the element size to compute the element offset.
3062221345Sdim      const SCEV *LocalOffset = getMulExpr(IndexS, ElementSize,
3063221345Sdim                                           isInBounds ? SCEV::FlagNSW :
3064221345Sdim                                           SCEV::FlagAnyWrap);
3065210299Sed
3066210299Sed      // Add the element offset to the running total offset.
3067210299Sed      TotalOffset = getAddExpr(TotalOffset, LocalOffset);
3068193323Sed    }
3069193323Sed  }
3070210299Sed
3071210299Sed  // Get the SCEV for the GEP base.
3072210299Sed  const SCEV *BaseS = getSCEV(Base);
3073210299Sed
3074210299Sed  // Add the total offset from all the GEP indices to the base.
3075221345Sdim  return getAddExpr(BaseS, TotalOffset,
3076221345Sdim                    isInBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap);
3077193323Sed}
3078193323Sed
3079193323Sed/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
3080193323Sed/// guaranteed to end in (at every loop iteration).  It is, at the same time,
3081193323Sed/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
3082193323Sed/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
3083194612Seduint32_t
3084198090SrdivackyScalarEvolution::GetMinTrailingZeros(const SCEV *S) {
3085193323Sed  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3086193323Sed    return C->getValue()->getValue().countTrailingZeros();
3087193323Sed
3088193323Sed  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
3089194612Sed    return std::min(GetMinTrailingZeros(T->getOperand()),
3090194612Sed                    (uint32_t)getTypeSizeInBits(T->getType()));
3091193323Sed
3092193323Sed  if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
3093194612Sed    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3094194612Sed    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3095194612Sed             getTypeSizeInBits(E->getType()) : OpRes;
3096193323Sed  }
3097193323Sed
3098193323Sed  if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
3099194612Sed    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
3100194612Sed    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
3101194612Sed             getTypeSizeInBits(E->getType()) : OpRes;
3102193323Sed  }
3103193323Sed
3104193323Sed  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
3105193323Sed    // The result is the min of all operands' results.
3106194612Sed    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3107193323Sed    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3108194612Sed      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3109193323Sed    return MinOpRes;
3110193323Sed  }
3111193323Sed
3112193323Sed  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
3113193323Sed    // The result is the sum of all operands' results.
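    // For example, (24 * %x) is divisible by 8 whatever %x is, so at least 3
    // is returned here (more if %x has known trailing zero bits), capped at
    // the bit width.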
3114194612Sed    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
3115194612Sed    uint32_t BitWidth = getTypeSizeInBits(M->getType());
3116193323Sed    for (unsigned i = 1, e = M->getNumOperands();
3117193323Sed         SumOpRes != BitWidth && i != e; ++i)
3118194612Sed      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
3119193323Sed                          BitWidth);
3120193323Sed    return SumOpRes;
3121193323Sed  }
3122193323Sed
3123193323Sed  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
3124193323Sed    // The result is the min of all operands' results.
3125194612Sed    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
3126193323Sed    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
3127194612Sed      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
3128193323Sed    return MinOpRes;
3129193323Sed  }
3130193323Sed
3131193323Sed  if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
3132193323Sed    // The result is the min of all operands' results.
3133194612Sed    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3134193323Sed    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3135194612Sed      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3136193323Sed    return MinOpRes;
3137193323Sed  }
3138193323Sed
3139193323Sed  if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
3140193323Sed    // The result is the min of all operands' results.
3141194612Sed    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
3142193323Sed    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
3143194612Sed      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
3144193323Sed    return MinOpRes;
3145193323Sed  }
3146193323Sed
3147194612Sed  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3148194612Sed    // For a SCEVUnknown, ask ValueTracking.
3149194612Sed    unsigned BitWidth = getTypeSizeInBits(U->getType());
3150194612Sed    APInt Mask = APInt::getAllOnesValue(BitWidth);
3151194612Sed    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3152194612Sed    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones);
3153194612Sed    return Zeros.countTrailingOnes();
3154194612Sed  }
3155194612Sed
3156194612Sed  // SCEVUDivExpr
3157193323Sed  return 0;
3158193323Sed}
3159193323Sed
3160198090Srdivacky/// getUnsignedRange - Determine the unsigned range for a particular SCEV.
3161198090Srdivacky///
3162198090SrdivackyConstantRange
3163198090SrdivackyScalarEvolution::getUnsignedRange(const SCEV *S) {
3164218893Sdim  // See if we've computed this range already.
3165218893Sdim  DenseMap<const SCEV *, ConstantRange>::iterator I = UnsignedRanges.find(S);
3166218893Sdim  if (I != UnsignedRanges.end())
3167218893Sdim    return I->second;
3168194612Sed
3169194612Sed  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3170218893Sdim    return setUnsignedRange(C, ConstantRange(C->getValue()->getValue()));
3171194612Sed
3172203954Srdivacky  unsigned BitWidth = getTypeSizeInBits(S->getType());
3173203954Srdivacky  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3174203954Srdivacky
3175203954Srdivacky  // If the value has known zeros, the maximum unsigned value will have those
3176203954Srdivacky  // known zeros as well.
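  // For example, with a bit width of 8 and 2 known trailing zero bits, the
  // largest possible value is 252, so the conservative range built below is
  // [0, 253).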
3177203954Srdivacky  uint32_t TZ = GetMinTrailingZeros(S);
3178203954Srdivacky  if (TZ != 0)
3179203954Srdivacky    ConservativeResult =
3180203954Srdivacky      ConstantRange(APInt::getMinValue(BitWidth),
3181203954Srdivacky                    APInt::getMaxValue(BitWidth).lshr(TZ).shl(TZ) + 1);
3182203954Srdivacky
3183198090Srdivacky  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3184198090Srdivacky    ConstantRange X = getUnsignedRange(Add->getOperand(0));
3185198090Srdivacky    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3186198090Srdivacky      X = X.add(getUnsignedRange(Add->getOperand(i)));
3187218893Sdim    return setUnsignedRange(Add, ConservativeResult.intersectWith(X));
3188194612Sed  }
3189194612Sed
3190198090Srdivacky  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3191198090Srdivacky    ConstantRange X = getUnsignedRange(Mul->getOperand(0));
3192198090Srdivacky    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3193198090Srdivacky      X = X.multiply(getUnsignedRange(Mul->getOperand(i)));
3194218893Sdim    return setUnsignedRange(Mul, ConservativeResult.intersectWith(X));
3195198090Srdivacky  }
3196198090Srdivacky
3197198090Srdivacky  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3198198090Srdivacky    ConstantRange X = getUnsignedRange(SMax->getOperand(0));
3199198090Srdivacky    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3200198090Srdivacky      X = X.smax(getUnsignedRange(SMax->getOperand(i)));
3201218893Sdim    return setUnsignedRange(SMax, ConservativeResult.intersectWith(X));
3202198090Srdivacky  }
3203198090Srdivacky
3204198090Srdivacky  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3205198090Srdivacky    ConstantRange X = getUnsignedRange(UMax->getOperand(0));
3206198090Srdivacky    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3207198090Srdivacky      X = X.umax(getUnsignedRange(UMax->getOperand(i)));
3208218893Sdim    return setUnsignedRange(UMax, ConservativeResult.intersectWith(X));
3209198090Srdivacky  }
3210198090Srdivacky
3211198090Srdivacky  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3212198090Srdivacky    ConstantRange X = getUnsignedRange(UDiv->getLHS());
3213198090Srdivacky    ConstantRange Y = getUnsignedRange(UDiv->getRHS());
3214218893Sdim    return setUnsignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
3215198090Srdivacky  }
3216198090Srdivacky
3217198090Srdivacky  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3218198090Srdivacky    ConstantRange X = getUnsignedRange(ZExt->getOperand());
3219218893Sdim    return setUnsignedRange(ZExt,
3220218893Sdim      ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
3221198090Srdivacky  }
3222198090Srdivacky
3223198090Srdivacky  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3224198090Srdivacky    ConstantRange X = getUnsignedRange(SExt->getOperand());
3225218893Sdim    return setUnsignedRange(SExt,
3226218893Sdim      ConservativeResult.intersectWith(X.signExtend(BitWidth)));
3227198090Srdivacky  }
3228198090Srdivacky
3229198090Srdivacky  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3230198090Srdivacky    ConstantRange X = getUnsignedRange(Trunc->getOperand());
3231218893Sdim    return setUnsignedRange(Trunc,
3232218893Sdim      ConservativeResult.intersectWith(X.truncate(BitWidth)));
3233198090Srdivacky  }
3234198090Srdivacky
3235198090Srdivacky  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3236202878Srdivacky    // If there's no unsigned wrap, the value will never be less than its
3237202878Srdivacky    // initial value.
3238221345Sdim    if (AddRec->getNoWrapFlags(SCEV::FlagNUW))
3239202878Srdivacky      if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
3240207618Srdivacky        if (!C->getValue()->isZero())
3241207618Srdivacky          ConservativeResult =
3242210299Sed            ConservativeResult.intersectWith(
3243210299Sed              ConstantRange(C->getValue()->getValue(), APInt(BitWidth, 0)));
3244202878Srdivacky
3245198090Srdivacky    // TODO: non-affine addrec
3246203954Srdivacky    if (AddRec->isAffine()) {
3247198090Srdivacky      const Type *Ty = AddRec->getType();
3248198090Srdivacky      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3249203954Srdivacky      if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3250203954Srdivacky          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3251198090Srdivacky        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3252198090Srdivacky
3253198090Srdivacky        const SCEV *Start = AddRec->getStart();
3254207618Srdivacky        const SCEV *Step = AddRec->getStepRecurrence(*this);
3255198090Srdivacky
3256207618Srdivacky        ConstantRange StartRange = getUnsignedRange(Start);
3257207618Srdivacky        ConstantRange StepRange = getSignedRange(Step);
3258207618Srdivacky        ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3259207618Srdivacky        ConstantRange EndRange =
3260207618Srdivacky          StartRange.add(MaxBECountRange.multiply(StepRange));
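        // As a rough illustration: for the i32 addrec {0,+,4} with a constant
        // maximum backedge-taken count of 10, StartRange is [0,1), StepRange
        // is [4,5), MaxBECountRange is [10,11), EndRange is [40,41), and the
        // final range computed below is [0,41).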
3261207618Srdivacky
3262207618Srdivacky        // Check for overflow. This must be done with ConstantRange arithmetic
3263207618Srdivacky        // because we could be called from within the ScalarEvolution overflow
3264207618Srdivacky        // checking code.
3265207618Srdivacky        ConstantRange ExtStartRange = StartRange.zextOrTrunc(BitWidth*2+1);
3266207618Srdivacky        ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3267207618Srdivacky        ConstantRange ExtMaxBECountRange =
3268207618Srdivacky          MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3269207618Srdivacky        ConstantRange ExtEndRange = EndRange.zextOrTrunc(BitWidth*2+1);
3270207618Srdivacky        if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3271207618Srdivacky            ExtEndRange)
3272218893Sdim          return setUnsignedRange(AddRec, ConservativeResult);
3273198090Srdivacky
3274198090Srdivacky        APInt Min = APIntOps::umin(StartRange.getUnsignedMin(),
3275198090Srdivacky                                   EndRange.getUnsignedMin());
3276198090Srdivacky        APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
3277198090Srdivacky                                   EndRange.getUnsignedMax());
3278198090Srdivacky        if (Min.isMinValue() && Max.isMaxValue())
3279218893Sdim          return setUnsignedRange(AddRec, ConservativeResult);
3280218893Sdim        return setUnsignedRange(AddRec,
3281218893Sdim          ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
3282198090Srdivacky      }
3283198090Srdivacky    }
3284202878Srdivacky
3285218893Sdim    return setUnsignedRange(AddRec, ConservativeResult);
3286198090Srdivacky  }
3287198090Srdivacky
3288194612Sed  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3289194612Sed    // For a SCEVUnknown, ask ValueTracking.
3290194612Sed    APInt Mask = APInt::getAllOnesValue(BitWidth);
3291194612Sed    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
3292194612Sed    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
3293198090Srdivacky    if (Ones == ~Zeros + 1)
3294218893Sdim      return setUnsignedRange(U, ConservativeResult);
3295218893Sdim    return setUnsignedRange(U,
3296218893Sdim      ConservativeResult.intersectWith(ConstantRange(Ones, ~Zeros + 1)));
3297194612Sed  }
3298194612Sed
3299218893Sdim  return setUnsignedRange(S, ConservativeResult);
3300194612Sed}
3301194612Sed
3302198090Srdivacky/// getSignedRange - Determine the signed range for a particular SCEV.
3303198090Srdivacky///
3304198090SrdivackyConstantRange
3305198090SrdivackyScalarEvolution::getSignedRange(const SCEV *S) {
3306218893Sdim  // See if we've computed this range already.
3307218893Sdim  DenseMap<const SCEV *, ConstantRange>::iterator I = SignedRanges.find(S);
3308218893Sdim  if (I != SignedRanges.end())
3309218893Sdim    return I->second;
3310194612Sed
3311198090Srdivacky  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
3312218893Sdim    return setSignedRange(C, ConstantRange(C->getValue()->getValue()));
3313198090Srdivacky
3314203954Srdivacky  unsigned BitWidth = getTypeSizeInBits(S->getType());
3315203954Srdivacky  ConstantRange ConservativeResult(BitWidth, /*isFullSet=*/true);
3316203954Srdivacky
3317203954Srdivacky  // If the value has known zeros, the maximum signed value will have those
3318203954Srdivacky  // known zeros as well.
3319203954Srdivacky  uint32_t TZ = GetMinTrailingZeros(S);
3320203954Srdivacky  if (TZ != 0)
3321203954Srdivacky    ConservativeResult =
3322203954Srdivacky      ConstantRange(APInt::getSignedMinValue(BitWidth),
3323203954Srdivacky                    APInt::getSignedMaxValue(BitWidth).ashr(TZ).shl(TZ) + 1);
3324203954Srdivacky
3325198090Srdivacky  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
3326198090Srdivacky    ConstantRange X = getSignedRange(Add->getOperand(0));
3327198090Srdivacky    for (unsigned i = 1, e = Add->getNumOperands(); i != e; ++i)
3328198090Srdivacky      X = X.add(getSignedRange(Add->getOperand(i)));
3329218893Sdim    return setSignedRange(Add, ConservativeResult.intersectWith(X));
3330194612Sed  }
3331194612Sed
3332198090Srdivacky  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
3333198090Srdivacky    ConstantRange X = getSignedRange(Mul->getOperand(0));
3334198090Srdivacky    for (unsigned i = 1, e = Mul->getNumOperands(); i != e; ++i)
3335198090Srdivacky      X = X.multiply(getSignedRange(Mul->getOperand(i)));
3336218893Sdim    return setSignedRange(Mul, ConservativeResult.intersectWith(X));
3337194612Sed  }
3338194612Sed
3339198090Srdivacky  if (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(S)) {
3340198090Srdivacky    ConstantRange X = getSignedRange(SMax->getOperand(0));
3341198090Srdivacky    for (unsigned i = 1, e = SMax->getNumOperands(); i != e; ++i)
3342198090Srdivacky      X = X.smax(getSignedRange(SMax->getOperand(i)));
3343218893Sdim    return setSignedRange(SMax, ConservativeResult.intersectWith(X));
3344198090Srdivacky  }
3345195098Sed
3346198090Srdivacky  if (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(S)) {
3347198090Srdivacky    ConstantRange X = getSignedRange(UMax->getOperand(0));
3348198090Srdivacky    for (unsigned i = 1, e = UMax->getNumOperands(); i != e; ++i)
3349198090Srdivacky      X = X.umax(getSignedRange(UMax->getOperand(i)));
3350218893Sdim    return setSignedRange(UMax, ConservativeResult.intersectWith(X));
3351198090Srdivacky  }
3352195098Sed
3353198090Srdivacky  if (const SCEVUDivExpr *UDiv = dyn_cast<SCEVUDivExpr>(S)) {
3354198090Srdivacky    ConstantRange X = getSignedRange(UDiv->getLHS());
3355198090Srdivacky    ConstantRange Y = getSignedRange(UDiv->getRHS());
3356218893Sdim    return setSignedRange(UDiv, ConservativeResult.intersectWith(X.udiv(Y)));
3357198090Srdivacky  }
3358195098Sed
3359198090Srdivacky  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S)) {
3360198090Srdivacky    ConstantRange X = getSignedRange(ZExt->getOperand());
3361218893Sdim    return setSignedRange(ZExt,
3362218893Sdim      ConservativeResult.intersectWith(X.zeroExtend(BitWidth)));
3363198090Srdivacky  }
3364198090Srdivacky
3365198090Srdivacky  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S)) {
3366198090Srdivacky    ConstantRange X = getSignedRange(SExt->getOperand());
3367218893Sdim    return setSignedRange(SExt,
3368218893Sdim      ConservativeResult.intersectWith(X.signExtend(BitWidth)));
3369198090Srdivacky  }
3370198090Srdivacky
3371198090Srdivacky  if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
3372198090Srdivacky    ConstantRange X = getSignedRange(Trunc->getOperand());
3373218893Sdim    return setSignedRange(Trunc,
3374218893Sdim      ConservativeResult.intersectWith(X.truncate(BitWidth)));
3375198090Srdivacky  }
3376198090Srdivacky
3377198090Srdivacky  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
3378202878Srdivacky    // If there's no signed wrap, and all the operands have the same sign or
3379202878Srdivacky    // zero, the value won't ever change sign.
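    // For example, an i8 addrec with the nsw flag whose operands are all
    // known non-negative is intersected below with the non-negative values,
    // i.e. the range [0, 128).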
3380221345Sdim    if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) {
3381202878Srdivacky      bool AllNonNeg = true;
3382202878Srdivacky      bool AllNonPos = true;
3383202878Srdivacky      for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
3384202878Srdivacky        if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
3385202878Srdivacky        if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
3386202878Srdivacky      }
3387202878Srdivacky      if (AllNonNeg)
3388203954Srdivacky        ConservativeResult = ConservativeResult.intersectWith(
3389203954Srdivacky          ConstantRange(APInt(BitWidth, 0),
3390203954Srdivacky                        APInt::getSignedMinValue(BitWidth)));
3391202878Srdivacky      else if (AllNonPos)
3392203954Srdivacky        ConservativeResult = ConservativeResult.intersectWith(
3393203954Srdivacky          ConstantRange(APInt::getSignedMinValue(BitWidth),
3394203954Srdivacky                        APInt(BitWidth, 1)));
3395202878Srdivacky    }
3396202878Srdivacky
3397198090Srdivacky    // TODO: non-affine addrec
3398203954Srdivacky    if (AddRec->isAffine()) {
3399198090Srdivacky      const Type *Ty = AddRec->getType();
3400198090Srdivacky      const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
3401203954Srdivacky      if (!isa<SCEVCouldNotCompute>(MaxBECount) &&
3402203954Srdivacky          getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) {
3403198090Srdivacky        MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
3404198090Srdivacky
3405198090Srdivacky        const SCEV *Start = AddRec->getStart();
3406207618Srdivacky        const SCEV *Step = AddRec->getStepRecurrence(*this);
3407198090Srdivacky
3408207618Srdivacky        ConstantRange StartRange = getSignedRange(Start);
3409207618Srdivacky        ConstantRange StepRange = getSignedRange(Step);
3410207618Srdivacky        ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount);
3411207618Srdivacky        ConstantRange EndRange =
3412207618Srdivacky          StartRange.add(MaxBECountRange.multiply(StepRange));
3413207618Srdivacky
3414207618Srdivacky        // Check for overflow. This must be done with ConstantRange arithmetic
3415207618Srdivacky        // because we could be called from within the ScalarEvolution overflow
3416207618Srdivacky        // checking code.
3417207618Srdivacky        ConstantRange ExtStartRange = StartRange.sextOrTrunc(BitWidth*2+1);
3418207618Srdivacky        ConstantRange ExtStepRange = StepRange.sextOrTrunc(BitWidth*2+1);
3419207618Srdivacky        ConstantRange ExtMaxBECountRange =
3420207618Srdivacky          MaxBECountRange.zextOrTrunc(BitWidth*2+1);
3421207618Srdivacky        ConstantRange ExtEndRange = EndRange.sextOrTrunc(BitWidth*2+1);
3422207618Srdivacky        if (ExtStartRange.add(ExtMaxBECountRange.multiply(ExtStepRange)) !=
3423207618Srdivacky            ExtEndRange)
3424218893Sdim          return setSignedRange(AddRec, ConservativeResult);
3425198090Srdivacky
3426198090Srdivacky        APInt Min = APIntOps::smin(StartRange.getSignedMin(),
3427198090Srdivacky                                   EndRange.getSignedMin());
3428198090Srdivacky        APInt Max = APIntOps::smax(StartRange.getSignedMax(),
3429198090Srdivacky                                   EndRange.getSignedMax());
3430198090Srdivacky        if (Min.isMinSignedValue() && Max.isMaxSignedValue())
3431218893Sdim          return setSignedRange(AddRec, ConservativeResult);
3432218893Sdim        return setSignedRange(AddRec,
3433218893Sdim          ConservativeResult.intersectWith(ConstantRange(Min, Max+1)));
3434195098Sed      }
3435195098Sed    }
3436202878Srdivacky
3437218893Sdim    return setSignedRange(AddRec, ConservativeResult);
3438195098Sed  }
3439195098Sed
3440194612Sed  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
3441194612Sed    // For a SCEVUnknown, ask ValueTracking.
3442203954Srdivacky    if (!U->getValue()->getType()->isIntegerTy() && !TD)
3443218893Sdim      return setSignedRange(U, ConservativeResult);
3444198090Srdivacky    unsigned NS = ComputeNumSignBits(U->getValue(), TD);
3445198090Srdivacky    if (NS == 1)
3446218893Sdim      return setSignedRange(U, ConservativeResult);
3447218893Sdim    return setSignedRange(U, ConservativeResult.intersectWith(
3448198090Srdivacky      ConstantRange(APInt::getSignedMinValue(BitWidth).ashr(NS - 1),
3449218893Sdim                    APInt::getSignedMaxValue(BitWidth).ashr(NS - 1)+1)));
3450194612Sed  }
3451194612Sed
3452218893Sdim  return setSignedRange(S, ConservativeResult);
3453194612Sed}
3454194612Sed
3455193323Sed/// createSCEV - We know that there is no SCEV for the specified value.
3456193323Sed/// Analyze the expression.
3457193323Sed///
3458198090Srdivackyconst SCEV *ScalarEvolution::createSCEV(Value *V) {
3459193323Sed  if (!isSCEVable(V->getType()))
3460193323Sed    return getUnknown(V);
3461193323Sed
3462193323Sed  unsigned Opcode = Instruction::UserOp1;
3463204961Srdivacky  if (Instruction *I = dyn_cast<Instruction>(V)) {
3464193323Sed    Opcode = I->getOpcode();
3465204961Srdivacky
3466204961Srdivacky    // Don't attempt to analyze instructions in blocks that aren't
3467204961Srdivacky    // reachable. Such instructions don't matter, and they aren't required
3468204961Srdivacky    // to obey basic rules for definitions dominating uses which this
3469204961Srdivacky    // analysis depends on.
3470204961Srdivacky    if (!DT->isReachableFromEntry(I->getParent()))
3471204961Srdivacky      return getUnknown(V);
3472204961Srdivacky  } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
3473193323Sed    Opcode = CE->getOpcode();
3474195098Sed  else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
3475195098Sed    return getConstant(CI);
3476195098Sed  else if (isa<ConstantPointerNull>(V))
3477207618Srdivacky    return getConstant(V->getType(), 0);
3478198090Srdivacky  else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
3479198090Srdivacky    return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee());
3480193323Sed  else
3481193323Sed    return getUnknown(V);
3482193323Sed
3483198090Srdivacky  Operator *U = cast<Operator>(V);
3484193323Sed  switch (Opcode) {
3485212904Sdim  case Instruction::Add: {
3486212904Sdim    // The simple thing to do would be to just call getSCEV on both operands
3487212904Sdim    // and call getAddExpr with the result. However, if we're looking at a
3488212904Sdim    // bunch of things all added together, this can be quite inefficient,
3489212904Sdim    // because it leads to N-1 getAddExpr calls for N ultimate operands.
3490212904Sdim    // Instead, gather up all the operands and make a single getAddExpr call.
3491212904Sdim    // LLVM IR canonical form means we need only traverse the left operands.
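    // For example, for ((a + b) - c) + d (a, b, c, d being arbitrary values),
    // this collects {d, -c, b, a} and issues a single getAddExpr call.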
3492212904Sdim    SmallVector<const SCEV *, 4> AddOps;
3493212904Sdim    AddOps.push_back(getSCEV(U->getOperand(1)));
3494212904Sdim    for (Value *Op = U->getOperand(0); ; Op = U->getOperand(0)) {
3495212904Sdim      unsigned Opcode = Op->getValueID() - Value::InstructionVal;
3496212904Sdim      if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
3497212904Sdim        break;
3498212904Sdim      U = cast<Operator>(Op);
3499212904Sdim      const SCEV *Op1 = getSCEV(U->getOperand(1));
3500212904Sdim      if (Opcode == Instruction::Sub)
3501212904Sdim        AddOps.push_back(getNegativeSCEV(Op1));
3502212904Sdim      else
3503212904Sdim        AddOps.push_back(Op1);
3504212904Sdim    }
3505212904Sdim    AddOps.push_back(getSCEV(U->getOperand(0)));
3506212904Sdim    return getAddExpr(AddOps);
3507212904Sdim  }
3508212904Sdim  case Instruction::Mul: {
3509212904Sdim    // See the Add code above.
3510212904Sdim    SmallVector<const SCEV *, 4> MulOps;
3511212904Sdim    MulOps.push_back(getSCEV(U->getOperand(1)));
3512212904Sdim    for (Value *Op = U->getOperand(0);
3513221345Sdim         Op->getValueID() == Instruction::Mul + Value::InstructionVal;
3514212904Sdim         Op = U->getOperand(0)) {
3515212904Sdim      U = cast<Operator>(Op);
3516212904Sdim      MulOps.push_back(getSCEV(U->getOperand(1)));
3517212904Sdim    }
3518212904Sdim    MulOps.push_back(getSCEV(U->getOperand(0)));
3519212904Sdim    return getMulExpr(MulOps);
3520212904Sdim  }
3521193323Sed  case Instruction::UDiv:
3522193323Sed    return getUDivExpr(getSCEV(U->getOperand(0)),
3523193323Sed                       getSCEV(U->getOperand(1)));
3524193323Sed  case Instruction::Sub:
3525193323Sed    return getMinusSCEV(getSCEV(U->getOperand(0)),
3526193323Sed                        getSCEV(U->getOperand(1)));
3527193323Sed  case Instruction::And:
3528193323Sed    // For an expression like x&255 that merely masks off the high bits,
3529193323Sed    // use zext(trunc(x)) as the SCEV expression.
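    // For example, '%m = and i32 %x, 255' (names purely illustrative) becomes
    // the zero extension to i32 of the truncation of %x to i8.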
3530193323Sed    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3531193323Sed      if (CI->isNullValue())
3532193323Sed        return getSCEV(U->getOperand(1));
3533193323Sed      if (CI->isAllOnesValue())
3534193323Sed        return getSCEV(U->getOperand(0));
3535193323Sed      const APInt &A = CI->getValue();
3536194612Sed
3537194612Sed      // Instcombine's ShrinkDemandedConstant may strip bits out of
3538194612Sed      // constants, obscuring what would otherwise be a low-bits mask.
3539194612Sed      // Use ComputeMaskedBits to compute what ShrinkDemandedConstant
3540194612Sed      // knew about to reconstruct a low-bits mask value.
3541194612Sed      unsigned LZ = A.countLeadingZeros();
3542194612Sed      unsigned BitWidth = A.getBitWidth();
3543194612Sed      APInt AllOnes = APInt::getAllOnesValue(BitWidth);
3544194612Sed      APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
3545194612Sed      ComputeMaskedBits(U->getOperand(0), AllOnes, KnownZero, KnownOne, TD);
3546194612Sed
3547194612Sed      APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
3548194612Sed
3549194612Sed      if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
3550193323Sed        return
3551193323Sed          getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
3552198090Srdivacky                                IntegerType::get(getContext(), BitWidth - LZ)),
3553193323Sed                            U->getType());
3554193323Sed    }
3555193323Sed    break;
3556194612Sed
3557193323Sed  case Instruction::Or:
3558193323Sed    // If the RHS of the Or is a constant, we may have something like:
3559193323Sed    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
3560193323Sed    // optimizations will transparently handle this case.
3561193323Sed    //
3562193323Sed    // In order for this transformation to be safe, the LHS must be of the
3563193323Sed    // form X*(2^n) and the Or constant must be less than 2^n.
3564193323Sed    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3565198090Srdivacky      const SCEV *LHS = getSCEV(U->getOperand(0));
3566193323Sed      const APInt &CIVal = CI->getValue();
3567194612Sed      if (GetMinTrailingZeros(LHS) >=
3568198090Srdivacky          (CIVal.getBitWidth() - CIVal.countLeadingZeros())) {
3569198090Srdivacky        // Build a plain add SCEV.
3570198090Srdivacky        const SCEV *S = getAddExpr(LHS, getSCEV(CI));
3571198090Srdivacky        // If the LHS of the add was an addrec and it has no-wrap flags,
3572198090Srdivacky        // transfer the no-wrap flags, since an or won't introduce a wrap.
3573198090Srdivacky        if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) {
3574198090Srdivacky          const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS);
3575221345Sdim          const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags(
3576221345Sdim            OldAR->getNoWrapFlags());
3577198090Srdivacky        }
3578198090Srdivacky        return S;
3579198090Srdivacky      }
3580193323Sed    }
3581193323Sed    break;
3582193323Sed  case Instruction::Xor:
3583193323Sed    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
3584193323Sed      // If the RHS of the xor is a signbit, then this is just an add.
3585193323Sed      // Instcombine turns add of signbit into xor as a strength reduction step.
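      // For example, on i32, x ^ 0x80000000 computes the same value as
      // x + 0x80000000 (the addition wraps modulo 2^32), so it is modeled
      // as an add.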
3586193323Sed      if (CI->getValue().isSignBit())
3587193323Sed        return getAddExpr(getSCEV(U->getOperand(0)),
3588193323Sed                          getSCEV(U->getOperand(1)));
3589193323Sed
3590193323Sed      // If the RHS of xor is -1, then this is a not operation.
3591193323Sed      if (CI->isAllOnesValue())
3592193323Sed        return getNotSCEV(getSCEV(U->getOperand(0)));
3593193323Sed
3594193323Sed      // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
3595193323Sed      // This is a variant of the check for xor with -1, and it handles
3596193323Sed      // the case where instcombine has trimmed non-demanded bits out
3597193323Sed      // of an xor with -1.
3598193323Sed      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
3599193323Sed        if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
3600193323Sed          if (BO->getOpcode() == Instruction::And &&
3601193323Sed              LCI->getValue() == CI->getValue())
3602193323Sed            if (const SCEVZeroExtendExpr *Z =
3603194612Sed                  dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
3604194612Sed              const Type *UTy = U->getType();
3605198090Srdivacky              const SCEV *Z0 = Z->getOperand();
3606194612Sed              const Type *Z0Ty = Z0->getType();
3607194612Sed              unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
3608194612Sed
3609204642Srdivacky              // If C is a low-bits mask, the zero extend is serving to
3610194612Sed              // mask off the high bits. Complement the operand and
3611194612Sed              // re-apply the zext.
3612194612Sed              if (APIntOps::isMask(Z0TySize, CI->getValue()))
3613194612Sed                return getZeroExtendExpr(getNotSCEV(Z0), UTy);
3614194612Sed
3615194612Sed              // If C is a single bit, it may be in the sign-bit position
3616194612Sed              // before the zero-extend. In this case, represent the xor
3617194612Sed              // using an add, which is equivalent, and re-apply the zext.
3618218893Sdim              APInt Trunc = CI->getValue().trunc(Z0TySize);
3619218893Sdim              if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
3620194612Sed                  Trunc.isSignBit())
3621194612Sed                return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
3622194612Sed                                         UTy);
3623194612Sed            }
3624193323Sed    }
3625193323Sed    break;
3626193323Sed
3627193323Sed  case Instruction::Shl:
3628193323Sed    // Turn a shift left by a constant amount into a multiply.
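    // For example, x << 3 is modeled as x * 8.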
3629193323Sed    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3630203954Srdivacky      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3631207618Srdivacky
3632207618Srdivacky      // If the shift count is not less than the bitwidth, the result of
3633207618Srdivacky      // the shift is undefined. Don't try to analyze it, because the
3634207618Srdivacky      // resolution chosen here may differ from the resolution chosen in
3635207618Srdivacky      // other parts of the compiler.
3636207618Srdivacky      if (SA->getValue().uge(BitWidth))
3637207618Srdivacky        break;
3638207618Srdivacky
3639198090Srdivacky      Constant *X = ConstantInt::get(getContext(),
3640207618Srdivacky        APInt(BitWidth, 1).shl(SA->getZExtValue()));
3641193323Sed      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3642193323Sed    }
3643193323Sed    break;
3644193323Sed
3645193323Sed  case Instruction::LShr:
3646193323Sed    // Turn logical shift right of a constant into a unsigned divide.
3647193323Sed    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
3648203954Srdivacky      uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth();
3649207618Srdivacky
3650207618Srdivacky      // If the shift count is not less than the bitwidth, the result of
3651207618Srdivacky      // the shift is undefined. Don't try to analyze it, because the
3652207618Srdivacky      // resolution chosen here may differ from the resolution chosen in
3653207618Srdivacky      // other parts of the compiler.
3654207618Srdivacky      if (SA->getValue().uge(BitWidth))
3655207618Srdivacky        break;
3656207618Srdivacky
3657198090Srdivacky      Constant *X = ConstantInt::get(getContext(),
3658207618Srdivacky        APInt(BitWidth, 1).shl(SA->getZExtValue()));
3659193323Sed      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
3660193323Sed    }
3661193323Sed    break;
3662193323Sed
3663193323Sed  case Instruction::AShr:
3664193323Sed    // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
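    // For example, on i32, (x << 24) >>s 24 sign-extends the low 8 bits of x
    // and is modeled as the sign extension to i32 of the truncation of x
    // to i8.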
3665193323Sed    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
3666207618Srdivacky      if (Operator *L = dyn_cast<Operator>(U->getOperand(0)))
3667193323Sed        if (L->getOpcode() == Instruction::Shl &&
3668193323Sed            L->getOperand(1) == U->getOperand(1)) {
3669207618Srdivacky          uint64_t BitWidth = getTypeSizeInBits(U->getType());
3670207618Srdivacky
3671207618Srdivacky          // If the shift count is not less than the bitwidth, the result of
3672207618Srdivacky          // the shift is undefined. Don't try to analyze it, because the
3673207618Srdivacky          // resolution chosen here may differ from the resolution chosen in
3674207618Srdivacky          // other parts of the compiler.
3675207618Srdivacky          if (CI->getValue().uge(BitWidth))
3676207618Srdivacky            break;
3677207618Srdivacky
3678193323Sed          uint64_t Amt = BitWidth - CI->getZExtValue();
3679193323Sed          if (Amt == BitWidth)
3680193323Sed            return getSCEV(L->getOperand(0));       // shift by zero --> noop
3681193323Sed          return
3682193323Sed            getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
3683207618Srdivacky                                              IntegerType::get(getContext(),
3684207618Srdivacky                                                               Amt)),
3685207618Srdivacky                              U->getType());
3686193323Sed        }
3687193323Sed    break;
3688193323Sed
3689193323Sed  case Instruction::Trunc:
3690193323Sed    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
3691193323Sed
3692193323Sed  case Instruction::ZExt:
3693193323Sed    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3694193323Sed
3695193323Sed  case Instruction::SExt:
3696193323Sed    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
3697193323Sed
3698193323Sed  case Instruction::BitCast:
3699193323Sed    // BitCasts are no-op casts so we just eliminate the cast.
3700193323Sed    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
3701193323Sed      return getSCEV(U->getOperand(0));
3702193323Sed    break;
3703193323Sed
3704203954Srdivacky  // It's tempting to handle inttoptr and ptrtoint as no-ops; however, this can
3705203954Srdivacky  // lead to pointer expressions which cannot safely be expanded to GEPs,
3706203954Srdivacky  // because ScalarEvolution doesn't respect the GEP aliasing rules when
3707203954Srdivacky  // simplifying integer expressions.
3708193323Sed
3709193323Sed  case Instruction::GetElementPtr:
3710201360Srdivacky    return createNodeForGEP(cast<GEPOperator>(U));
3711193323Sed
3712193323Sed  case Instruction::PHI:
3713193323Sed    return createNodeForPHI(cast<PHINode>(U));
3714193323Sed
3715193323Sed  case Instruction::Select:
3716193323Sed    // This could be a smax or umax that was lowered earlier.
3717193323Sed    // Try to recover it.
3718193323Sed    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
3719193323Sed      Value *LHS = ICI->getOperand(0);
3720193323Sed      Value *RHS = ICI->getOperand(1);
3721193323Sed      switch (ICI->getPredicate()) {
3722193323Sed      case ICmpInst::ICMP_SLT:
3723193323Sed      case ICmpInst::ICMP_SLE:
3724193323Sed        std::swap(LHS, RHS);
3725193323Sed        // fall through
3726193323Sed      case ICmpInst::ICMP_SGT:
3727193323Sed      case ICmpInst::ICMP_SGE:
3728207618Srdivacky        // a >s b ? a+x : b+x  ->  smax(a, b)+x
3729207618Srdivacky        // a >s b ? b+x : a+x  ->  smin(a, b)+x
3730207618Srdivacky        if (LHS->getType() == U->getType()) {
3731207618Srdivacky          const SCEV *LS = getSCEV(LHS);
3732207618Srdivacky          const SCEV *RS = getSCEV(RHS);
3733207618Srdivacky          const SCEV *LA = getSCEV(U->getOperand(1));
3734207618Srdivacky          const SCEV *RA = getSCEV(U->getOperand(2));
3735207618Srdivacky          const SCEV *LDiff = getMinusSCEV(LA, LS);
3736207618Srdivacky          const SCEV *RDiff = getMinusSCEV(RA, RS);
3737207618Srdivacky          if (LDiff == RDiff)
3738207618Srdivacky            return getAddExpr(getSMaxExpr(LS, RS), LDiff);
3739207618Srdivacky          LDiff = getMinusSCEV(LA, RS);
3740207618Srdivacky          RDiff = getMinusSCEV(RA, LS);
3741207618Srdivacky          if (LDiff == RDiff)
3742207618Srdivacky            return getAddExpr(getSMinExpr(LS, RS), LDiff);
3743207618Srdivacky        }
3744193323Sed        break;
3745193323Sed      case ICmpInst::ICMP_ULT:
3746193323Sed      case ICmpInst::ICMP_ULE:
3747193323Sed        std::swap(LHS, RHS);
3748193323Sed        // fall through
3749193323Sed      case ICmpInst::ICMP_UGT:
3750193323Sed      case ICmpInst::ICMP_UGE:
3751207618Srdivacky        // a >u b ? a+x : b+x  ->  umax(a, b)+x
3752207618Srdivacky        // a >u b ? b+x : a+x  ->  umin(a, b)+x
3753207618Srdivacky        if (LHS->getType() == U->getType()) {
3754207618Srdivacky          const SCEV *LS = getSCEV(LHS);
3755207618Srdivacky          const SCEV *RS = getSCEV(RHS);
3756207618Srdivacky          const SCEV *LA = getSCEV(U->getOperand(1));
3757207618Srdivacky          const SCEV *RA = getSCEV(U->getOperand(2));
3758207618Srdivacky          const SCEV *LDiff = getMinusSCEV(LA, LS);
3759207618Srdivacky          const SCEV *RDiff = getMinusSCEV(RA, RS);
3760207618Srdivacky          if (LDiff == RDiff)
3761207618Srdivacky            return getAddExpr(getUMaxExpr(LS, RS), LDiff);
3762207618Srdivacky          LDiff = getMinusSCEV(LA, RS);
3763207618Srdivacky          RDiff = getMinusSCEV(RA, LS);
3764207618Srdivacky          if (LDiff == RDiff)
3765207618Srdivacky            return getAddExpr(getUMinExpr(LS, RS), LDiff);
3766207618Srdivacky        }
3767193323Sed        break;
3768194612Sed      case ICmpInst::ICMP_NE:
3769207618Srdivacky        // n != 0 ? n+x : 1+x  ->  umax(n, 1)+x
3770207618Srdivacky        if (LHS->getType() == U->getType() &&
3771194612Sed            isa<ConstantInt>(RHS) &&
3772207618Srdivacky            cast<ConstantInt>(RHS)->isZero()) {
3773207618Srdivacky          const SCEV *One = getConstant(LHS->getType(), 1);
3774207618Srdivacky          const SCEV *LS = getSCEV(LHS);
3775207618Srdivacky          const SCEV *LA = getSCEV(U->getOperand(1));
3776207618Srdivacky          const SCEV *RA = getSCEV(U->getOperand(2));
3777207618Srdivacky          const SCEV *LDiff = getMinusSCEV(LA, LS);
3778207618Srdivacky          const SCEV *RDiff = getMinusSCEV(RA, One);
3779207618Srdivacky          if (LDiff == RDiff)
3780212904Sdim            return getAddExpr(getUMaxExpr(One, LS), LDiff);
3781207618Srdivacky        }
3782194612Sed        break;
3783194612Sed      case ICmpInst::ICMP_EQ:
3784207618Srdivacky        // n == 0 ? 1+x : n+x  ->  umax(n, 1)+x
3785207618Srdivacky        if (LHS->getType() == U->getType() &&
3786194612Sed            isa<ConstantInt>(RHS) &&
3787207618Srdivacky            cast<ConstantInt>(RHS)->isZero()) {
3788207618Srdivacky          const SCEV *One = getConstant(LHS->getType(), 1);
3789207618Srdivacky          const SCEV *LS = getSCEV(LHS);
3790207618Srdivacky          const SCEV *LA = getSCEV(U->getOperand(1));
3791207618Srdivacky          const SCEV *RA = getSCEV(U->getOperand(2));
3792207618Srdivacky          const SCEV *LDiff = getMinusSCEV(LA, One);
3793207618Srdivacky          const SCEV *RDiff = getMinusSCEV(RA, LS);
3794207618Srdivacky          if (LDiff == RDiff)
3795212904Sdim            return getAddExpr(getUMaxExpr(One, LS), LDiff);
3796207618Srdivacky        }
3797194612Sed        break;
3798193323Sed      default:
3799193323Sed        break;
3800193323Sed      }
3801193323Sed    }
3802193323Sed
3803193323Sed  default: // We cannot analyze this expression.
3804193323Sed    break;
3805193323Sed  }
3806193323Sed
3807193323Sed  return getUnknown(V);
3808193323Sed}
3809193323Sed
3810193323Sed
3811193323Sed
3812193323Sed//===----------------------------------------------------------------------===//
3813193323Sed//                   Iteration Count Computation Code
3814193323Sed//
3815193323Sed
3816193323Sed/// getBackedgeTakenCount - If the specified loop has a predictable
3817193323Sed/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
3818193323Sed/// object. The backedge-taken count is the number of times the loop header
3819193323Sed/// will be branched to from within the loop. This is one less than the
3820193323Sed/// trip count of the loop, since it doesn't count the first iteration,
3821193323Sed/// when the header is branched to from outside the loop.
3822193323Sed///
3823193323Sed/// Note that it is not valid to call this method on a loop without a
3824193323Sed/// loop-invariant backedge-taken count (see
3825193323Sed/// hasLoopInvariantBackedgeTakenCount).
3826193323Sed///
3827198090Srdivackyconst SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
3828193323Sed  return getBackedgeTakenInfo(L).Exact;
3829193323Sed}
3830193323Sed
3831193323Sed/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
3832193323Sed/// return the least SCEV value that is known never to be less than the
3833193323Sed/// actual backedge taken count.
3834198090Srdivackyconst SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
3835193323Sed  return getBackedgeTakenInfo(L).Max;
3836193323Sed}
3837193323Sed
3838198090Srdivacky/// PushLoopPHIs - Push PHI nodes in the header of the given loop
3839198090Srdivacky/// onto the given Worklist.
3840198090Srdivackystatic void
3841198090SrdivackyPushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
3842198090Srdivacky  BasicBlock *Header = L->getHeader();
3843198090Srdivacky
3844198090Srdivacky  // Push all Loop-header PHIs onto the Worklist stack.
3845198090Srdivacky  for (BasicBlock::iterator I = Header->begin();
3846198090Srdivacky       PHINode *PN = dyn_cast<PHINode>(I); ++I)
3847198090Srdivacky    Worklist.push_back(PN);
3848198090Srdivacky}
3849198090Srdivacky
3850193323Sedconst ScalarEvolution::BackedgeTakenInfo &
3851193323SedScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
3852193323Sed  // Initially insert a CouldNotCompute for this loop. If the insertion
3853204642Srdivacky  // succeeds, proceed to actually compute a backedge-taken count and
3854193323Sed  // update the value. The temporary CouldNotCompute value tells SCEV
3855193323Sed  // code elsewhere that it shouldn't attempt to request a new
3856193323Sed  // backedge-taken count, which could result in infinite recursion.
3857223017Sdim  std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair =
3858193323Sed    BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
3859218893Sdim  if (!Pair.second)
3860218893Sdim    return Pair.first->second;
3861193323Sed
3862221345Sdim  BackedgeTakenInfo Result = getCouldNotCompute();
3863221345Sdim  BackedgeTakenInfo Computed = ComputeBackedgeTakenCount(L);
3864221345Sdim  if (Computed.Exact != getCouldNotCompute()) {
3865221345Sdim    assert(isLoopInvariant(Computed.Exact, L) &&
3866221345Sdim           isLoopInvariant(Computed.Max, L) &&
3867218893Sdim           "Computed backedge-taken count isn't loop invariant for loop!");
3868218893Sdim    ++NumTripCountsComputed;
3869218893Sdim
3870218893Sdim    // Update the value in the map.
3871221345Sdim    Result = Computed;
3872218893Sdim  } else {
3873221345Sdim    if (Computed.Max != getCouldNotCompute())
3874193323Sed      // Update the value in the map.
3875221345Sdim      Result = Computed;
3876218893Sdim    if (isa<PHINode>(L->getHeader()->begin()))
3877218893Sdim      // Only count loops that have phi nodes as not being computable.
3878218893Sdim      ++NumTripCountsNotComputed;
3879218893Sdim  }
3880193323Sed
3881218893Sdim  // Now that we know more about the trip count for this loop, forget any
3882218893Sdim  // existing SCEV values for PHI nodes in this loop since they are only
3883218893Sdim  // conservative estimates made without the benefit of trip count
3884218893Sdim  // information. This is similar to the code in forgetLoop, except that
3885218893Sdim  // it handles SCEVUnknown PHI nodes specially.
3886221345Sdim  if (Computed.hasAnyInfo()) {
3887218893Sdim    SmallVector<Instruction *, 16> Worklist;
3888218893Sdim    PushLoopPHIs(L, Worklist);
3889198090Srdivacky
3890218893Sdim    SmallPtrSet<Instruction *, 8> Visited;
3891218893Sdim    while (!Worklist.empty()) {
3892218893Sdim      Instruction *I = Worklist.pop_back_val();
3893218893Sdim      if (!Visited.insert(I)) continue;
3894198090Srdivacky
3895218893Sdim      ValueExprMapType::iterator It =
3896218893Sdim        ValueExprMap.find(static_cast<Value *>(I));
3897218893Sdim      if (It != ValueExprMap.end()) {
3898218893Sdim        const SCEV *Old = It->second;
3899218893Sdim
3900218893Sdim        // SCEVUnknown for a PHI either means that it has an unrecognized
3901218893Sdim        // structure, or it's a PHI that's in the progress of being computed
3902218893Sdim        // by createNodeForPHI.  In the former case, additional loop trip
3903218893Sdim        // count information isn't going to change anything. In the later
3904218893Sdim        // case, createNodeForPHI will perform the necessary updates on its
3905218893Sdim        // own when it gets to that point.
3906218893Sdim        if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) {
3907218893Sdim          forgetMemoizedResults(Old);
3908218893Sdim          ValueExprMap.erase(It);
3909198090Srdivacky        }
3910218893Sdim        if (PHINode *PN = dyn_cast<PHINode>(I))
3911218893Sdim          ConstantEvolutionLoopExitValue.erase(PN);
3912218893Sdim      }
3913198090Srdivacky
3914218893Sdim      PushDefUseChildren(I, Worklist);
3915198090Srdivacky    }
3916193323Sed  }
3917221345Sdim
3918221345Sdim  // Re-lookup the insert position, since the call to
3919221345Sdim  // ComputeBackedgeTakenCount above could result in a
3920221345Sdim  // recursive call to getBackedgeTakenInfo (on a different
3921221345Sdim  // loop), which would invalidate the iterator computed
3922221345Sdim  // earlier.
3923221345Sdim  return BackedgeTakenCounts.find(L)->second = Result;
3924193323Sed}
3925193323Sed
3926198892Srdivacky/// forgetLoop - This method should be called by the client when it has
3927198892Srdivacky/// changed a loop in a way that may affect ScalarEvolution's ability to
3928198892Srdivacky/// compute a trip count, or if the loop is deleted.
3929198892Srdivackyvoid ScalarEvolution::forgetLoop(const Loop *L) {
3930198892Srdivacky  // Drop any stored trip count value.
3931193323Sed  BackedgeTakenCounts.erase(L);
3932193323Sed
3933198892Srdivacky  // Drop information about expressions based on loop-header PHIs.
3934193323Sed  SmallVector<Instruction *, 16> Worklist;
3935198090Srdivacky  PushLoopPHIs(L, Worklist);
3936193323Sed
3937198090Srdivacky  SmallPtrSet<Instruction *, 8> Visited;
3938193323Sed  while (!Worklist.empty()) {
3939193323Sed    Instruction *I = Worklist.pop_back_val();
3940198090Srdivacky    if (!Visited.insert(I)) continue;
3941198090Srdivacky
3942212904Sdim    ValueExprMapType::iterator It = ValueExprMap.find(static_cast<Value *>(I));
3943212904Sdim    if (It != ValueExprMap.end()) {
3944218893Sdim      forgetMemoizedResults(It->second);
3945212904Sdim      ValueExprMap.erase(It);
3946198090Srdivacky      if (PHINode *PN = dyn_cast<PHINode>(I))
3947198090Srdivacky        ConstantEvolutionLoopExitValue.erase(PN);
3948198090Srdivacky    }
3949198090Srdivacky
3950198090Srdivacky    PushDefUseChildren(I, Worklist);
3951193323Sed  }
3952218893Sdim
3953218893Sdim  // Forget all contained loops too, to avoid dangling entries in the
3954218893Sdim  // ValuesAtScopes map.
3955218893Sdim  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
3956218893Sdim    forgetLoop(*I);
3957193323Sed}
3958193323Sed
3959204642Srdivacky/// forgetValue - This method should be called by the client when it has
3960204642Srdivacky/// changed a value in a way that may affect its value, or which may
3961204642Srdivacky/// disconnect it from a def-use chain linking it to a loop.
3962204642Srdivackyvoid ScalarEvolution::forgetValue(Value *V) {
3963204642Srdivacky  Instruction *I = dyn_cast<Instruction>(V);
3964204642Srdivacky  if (!I) return;
3965204642Srdivacky
3966204642Srdivacky  // Drop information about expressions based on loop-header PHIs.
3967204642Srdivacky  SmallVector<Instruction *, 16> Worklist;
3968204642Srdivacky  Worklist.push_back(I);
3969204642Srdivacky
3970204642Srdivacky  SmallPtrSet<Instruction *, 8> Visited;
3971204642Srdivacky  while (!Worklist.empty()) {
3972204642Srdivacky    I = Worklist.pop_back_val();
3973204642Srdivacky    if (!Visited.insert(I)) continue;
3974204642Srdivacky
3975212904Sdim    ValueExprMapType::iterator It = ValueExprMap.find(static_cast<Value *>(I));
3976212904Sdim    if (It != ValueExprMap.end()) {
3977218893Sdim      forgetMemoizedResults(It->second);
3978212904Sdim      ValueExprMap.erase(It);
3979204642Srdivacky      if (PHINode *PN = dyn_cast<PHINode>(I))
3980204642Srdivacky        ConstantEvolutionLoopExitValue.erase(PN);
3981204642Srdivacky    }
3982204642Srdivacky
3983204642Srdivacky    PushDefUseChildren(I, Worklist);
3984204642Srdivacky  }
3985204642Srdivacky}
3986204642Srdivacky
3987193323Sed/// ComputeBackedgeTakenCount - Compute the number of times the backedge
3988193323Sed/// of the specified loop will execute.
3989193323SedScalarEvolution::BackedgeTakenInfo
3990193323SedScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
3991201360Srdivacky  SmallVector<BasicBlock *, 8> ExitingBlocks;
3992194612Sed  L->getExitingBlocks(ExitingBlocks);
3993193323Sed
3994194612Sed  // Examine all exits and pick the most conservative values.
3995198090Srdivacky  const SCEV *BECount = getCouldNotCompute();
3996198090Srdivacky  const SCEV *MaxBECount = getCouldNotCompute();
3997194612Sed  bool CouldNotComputeBECount = false;
3998194612Sed  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
3999194612Sed    BackedgeTakenInfo NewBTI =
4000194612Sed      ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]);
4001193323Sed
4002195340Sed    if (NewBTI.Exact == getCouldNotCompute()) {
4003194612Sed      // We couldn't compute an exact value for this exit, so
4004194710Sed      // we won't be able to compute an exact value for the loop.
4005194612Sed      CouldNotComputeBECount = true;
4006195340Sed      BECount = getCouldNotCompute();
4007194612Sed    } else if (!CouldNotComputeBECount) {
4008195340Sed      if (BECount == getCouldNotCompute())
4009194612Sed        BECount = NewBTI.Exact;
4010194612Sed      else
4011195098Sed        BECount = getUMinFromMismatchedTypes(BECount, NewBTI.Exact);
4012194612Sed    }
4013195340Sed    if (MaxBECount == getCouldNotCompute())
4014195098Sed      MaxBECount = NewBTI.Max;
4015195340Sed    else if (NewBTI.Max != getCouldNotCompute())
4016195098Sed      MaxBECount = getUMinFromMismatchedTypes(MaxBECount, NewBTI.Max);
4017194612Sed  }
4018194612Sed
4019194612Sed  return BackedgeTakenInfo(BECount, MaxBECount);
4020194612Sed}
4021194612Sed
4022194612Sed/// ComputeBackedgeTakenCountFromExit - Compute the number of times the backedge
4023194612Sed/// of the specified loop will execute if it exits via the specified block.
4024194612SedScalarEvolution::BackedgeTakenInfo
4025194612SedScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
4026194612Sed                                                   BasicBlock *ExitingBlock) {
4027194612Sed
4028194612Sed  // Okay, we've chosen an exiting block.  See what condition causes us to
4029194612Sed  // exit at this block.
4030193323Sed  //
4031193323Sed  // FIXME: we should be able to handle switch instructions (with a single exit)
4032193323Sed  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
4033195340Sed  if (ExitBr == 0) return getCouldNotCompute();
4034193323Sed  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
4035195098Sed
4036193323Sed  // At this point, we know we have a conditional branch that determines whether
4037193323Sed  // the loop is exited.  However, we don't know if the branch is executed each
4038193323Sed  // time through the loop.  If not, then the execution count of the branch will
4039193323Sed  // not be equal to the trip count of the loop.
4040193323Sed  //
4041193323Sed  // Currently we check for this by seeing if the Exit branch goes to
4042193323Sed  // the loop header.  If so, we know it will always execute the same number of
4043193323Sed  // times as the loop.  We also handle the case where the exit block *is* the
4044194612Sed  // loop header.  This is common for un-rotated loops.
4045194612Sed  //
4046194612Sed  // If both of those tests fail, walk up the unique predecessor chain to the
4047194612Sed  // header, stopping if there is an edge that doesn't exit the loop. If the
4048194612Sed  // header is reached, the execution count of the branch will be equal to the
4049194612Sed  // trip count of the loop.
4050194612Sed  //
4051194612Sed  //  More extensive analysis could be done to handle more cases here.
4052194612Sed  //
4053193323Sed  if (ExitBr->getSuccessor(0) != L->getHeader() &&
4054193323Sed      ExitBr->getSuccessor(1) != L->getHeader() &&
4055194612Sed      ExitBr->getParent() != L->getHeader()) {
4056194612Sed    // The simple checks failed, try climbing the unique predecessor chain
4057194612Sed    // up to the header.
4058194612Sed    bool Ok = false;
4059194612Sed    for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
4060194612Sed      BasicBlock *Pred = BB->getUniquePredecessor();
4061194612Sed      if (!Pred)
4062195340Sed        return getCouldNotCompute();
4063194612Sed      TerminatorInst *PredTerm = Pred->getTerminator();
4064194612Sed      for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
4065194612Sed        BasicBlock *PredSucc = PredTerm->getSuccessor(i);
4066194612Sed        if (PredSucc == BB)
4067194612Sed          continue;
4068194612Sed        // If the predecessor has a successor that isn't BB and isn't
4069194612Sed        // outside the loop, assume the worst.
4070194612Sed        if (L->contains(PredSucc))
4071195340Sed          return getCouldNotCompute();
4072194612Sed      }
4073194612Sed      if (Pred == L->getHeader()) {
4074194612Sed        Ok = true;
4075194612Sed        break;
4076194612Sed      }
4077194612Sed      BB = Pred;
4078194612Sed    }
4079194612Sed    if (!Ok)
4080195340Sed      return getCouldNotCompute();
4081194612Sed  }
4082193323Sed
4083204642Srdivacky  // Proceed to the next level to examine the exit condition expression.
4084194612Sed  return ComputeBackedgeTakenCountFromExitCond(L, ExitBr->getCondition(),
4085194612Sed                                               ExitBr->getSuccessor(0),
4086194612Sed                                               ExitBr->getSuccessor(1));
4087194612Sed}
4088194612Sed
4089194612Sed/// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the
4090194612Sed/// backedge of the specified loop will execute if its exit condition
4091194612Sed/// were a conditional branch of ExitCond, TBB, and FBB.
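///
/// For example, if the exit branch tests (A && B) and its true edge stays in
/// the loop, the loop only continues while both conditions hold, so an exact
/// count of 10 for A and 20 for B merges to umin(10, 20) = 10; the max counts
/// combine the same way below.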
4092194612SedScalarEvolution::BackedgeTakenInfo
4093194612SedScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
4094194612Sed                                                       Value *ExitCond,
4095194612Sed                                                       BasicBlock *TBB,
4096194612Sed                                                       BasicBlock *FBB) {
4097195098Sed  // Check if the controlling expression for this loop is an And or Or.
4098194612Sed  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
4099194612Sed    if (BO->getOpcode() == Instruction::And) {
4100194612Sed      // Recurse on the operands of the and.
4101194612Sed      BackedgeTakenInfo BTI0 =
4102194612Sed        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
4103194612Sed      BackedgeTakenInfo BTI1 =
4104194612Sed        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
4105198090Srdivacky      const SCEV *BECount = getCouldNotCompute();
4106198090Srdivacky      const SCEV *MaxBECount = getCouldNotCompute();
4107194612Sed      if (L->contains(TBB)) {
4108194612Sed        // Both conditions must be true for the loop to continue executing.
4109194612Sed        // Choose the less conservative count.
4110195340Sed        if (BTI0.Exact == getCouldNotCompute() ||
4111195340Sed            BTI1.Exact == getCouldNotCompute())
4112195340Sed          BECount = getCouldNotCompute();
4113194710Sed        else
4114194710Sed          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
4115195340Sed        if (BTI0.Max == getCouldNotCompute())
4116194612Sed          MaxBECount = BTI1.Max;
4117195340Sed        else if (BTI1.Max == getCouldNotCompute())
4118194612Sed          MaxBECount = BTI0.Max;
4119194710Sed        else
4120194710Sed          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
4121194612Sed      } else {
4122212904Sdim        // Both conditions must be true at the same time for the loop to exit.
4123212904Sdim        // For now, be conservative.
4124194612Sed        assert(L->contains(FBB) && "Loop block has no successor in loop!");
4125212904Sdim        if (BTI0.Max == BTI1.Max)
4126212904Sdim          MaxBECount = BTI0.Max;
4127212904Sdim        if (BTI0.Exact == BTI1.Exact)
4128212904Sdim          BECount = BTI0.Exact;
4129194612Sed      }
4130194612Sed
4131194612Sed      return BackedgeTakenInfo(BECount, MaxBECount);
4132194612Sed    }
4133194612Sed    if (BO->getOpcode() == Instruction::Or) {
4134194612Sed      // Recurse on the operands of the or.
4135194612Sed      BackedgeTakenInfo BTI0 =
4136194612Sed        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
4137194612Sed      BackedgeTakenInfo BTI1 =
4138194612Sed        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
4139198090Srdivacky      const SCEV *BECount = getCouldNotCompute();
4140198090Srdivacky      const SCEV *MaxBECount = getCouldNotCompute();
4141194612Sed      if (L->contains(FBB)) {
4142194612Sed        // Both conditions must be false for the loop to continue executing.
4143194612Sed        // Choose the less conservative count.
4144195340Sed        if (BTI0.Exact == getCouldNotCompute() ||
4145195340Sed            BTI1.Exact == getCouldNotCompute())
4146195340Sed          BECount = getCouldNotCompute();
4147194710Sed        else
4148194710Sed          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
4149195340Sed        if (BTI0.Max == getCouldNotCompute())
4150194612Sed          MaxBECount = BTI1.Max;
4151195340Sed        else if (BTI1.Max == getCouldNotCompute())
4152194612Sed          MaxBECount = BTI0.Max;
4153194710Sed        else
4154194710Sed          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
4155194612Sed      } else {
4156212904Sdim        // Both conditions must be false at the same time for the loop to exit.
4157212904Sdim        // For now, be conservative.
4158194612Sed        assert(L->contains(TBB) && "Loop block has no successor in loop!");
4159212904Sdim        if (BTI0.Max == BTI1.Max)
4160212904Sdim          MaxBECount = BTI0.Max;
4161212904Sdim        if (BTI0.Exact == BTI1.Exact)
4162212904Sdim          BECount = BTI0.Exact;
4163194612Sed      }
4164194612Sed
4165194612Sed      return BackedgeTakenInfo(BECount, MaxBECount);
4166194612Sed    }
4167194612Sed  }
4168194612Sed
4169194612Sed  // With an icmp, it may be feasible to compute an exact backedge-taken count.
4170204642Srdivacky  // Proceed to the next level to examine the icmp.
4171194612Sed  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
4172194612Sed    return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB);
4173194612Sed
4174204642Srdivacky  // Check for a constant condition. These are normally stripped out by
4175204642Srdivacky  // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to
4176204642Srdivacky  // preserve the CFG and is temporarily leaving constant conditions
4177204642Srdivacky  // in place.
4178204642Srdivacky  if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
4179204642Srdivacky    if (L->contains(FBB) == !CI->getZExtValue())
4180204642Srdivacky      // The backedge is always taken.
4181204642Srdivacky      return getCouldNotCompute();
4182204642Srdivacky    else
4183204642Srdivacky      // The backedge is never taken.
4184207618Srdivacky      return getConstant(CI->getType(), 0);
4185204642Srdivacky  }
4186204642Srdivacky
4187193323Sed  // If it's not an integer or pointer comparison then compute it the hard way.
4188194612Sed  return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
4189194612Sed}
4190193323Sed
4191194612Sed/// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the
4192194612Sed/// backedge of the specified loop will execute if its exit condition
4193194612Sed/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
4194194612SedScalarEvolution::BackedgeTakenInfo
4195194612SedScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
4196194612Sed                                                           ICmpInst *ExitCond,
4197194612Sed                                                           BasicBlock *TBB,
4198194612Sed                                                           BasicBlock *FBB) {
4199194612Sed
4200193323Sed  // If the condition was exit on true, convert the condition to exit on false
4201193323Sed  ICmpInst::Predicate Cond;
4202194612Sed  if (!L->contains(FBB))
4203193323Sed    Cond = ExitCond->getPredicate();
4204193323Sed  else
4205193323Sed    Cond = ExitCond->getInversePredicate();
4206193323Sed
4207193323Sed  // Handle common loops like: for (X = "string"; *X; ++X)
4208193323Sed  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
4209193323Sed    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
4210204642Srdivacky      BackedgeTakenInfo ItCnt =
4211193323Sed        ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond);
4212204642Srdivacky      if (ItCnt.hasAnyInfo())
4213204642Srdivacky        return ItCnt;
4214193323Sed    }
4215193323Sed
4216198090Srdivacky  const SCEV *LHS = getSCEV(ExitCond->getOperand(0));
4217198090Srdivacky  const SCEV *RHS = getSCEV(ExitCond->getOperand(1));
4218193323Sed
4219193323Sed  // Try to evaluate any dependencies out of the loop.
4220193323Sed  LHS = getSCEVAtScope(LHS, L);
4221193323Sed  RHS = getSCEVAtScope(RHS, L);
4222193323Sed
4223195098Sed  // At this point, we would like to compute for how many iterations of the
4224193323Sed  // loop the predicate will return true given these inputs.
4225218893Sdim  if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
4226193323Sed    // If there is a loop-invariant, force it into the RHS.
4227193323Sed    std::swap(LHS, RHS);
4228193323Sed    Cond = ICmpInst::getSwappedPredicate(Cond);
4229193323Sed  }
4230193323Sed
4231207618Srdivacky  // Simplify the operands before analyzing them.
4232207618Srdivacky  (void)SimplifyICmpOperands(Cond, LHS, RHS);
4233207618Srdivacky
4234193323Sed  // If we have a comparison of a chrec against a constant, try to use value
4235193323Sed  // ranges to answer this query.
4236193323Sed  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
4237193323Sed    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
4238193323Sed      if (AddRec->getLoop() == L) {
4239193323Sed        // Form the constant range.
4240193323Sed        ConstantRange CompRange(
4241193323Sed            ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
4242193323Sed
4243198090Srdivacky        const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
4244193323Sed        if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
4245193323Sed      }
4246193323Sed
4247193323Sed  switch (Cond) {
4248193323Sed  case ICmpInst::ICMP_NE: {                     // while (X != Y)
4249193323Sed    // Convert to: while (X-Y != 0)
4250221345Sdim    BackedgeTakenInfo BTI = HowFarToZero(getMinusSCEV(LHS, RHS), L);
4251204642Srdivacky    if (BTI.hasAnyInfo()) return BTI;
4252193323Sed    break;
4253193323Sed  }
4254198090Srdivacky  case ICmpInst::ICMP_EQ: {                     // while (X == Y)
4255198090Srdivacky    // Convert to: while (X-Y == 0)
4256204642Srdivacky    BackedgeTakenInfo BTI = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
4257204642Srdivacky    if (BTI.hasAnyInfo()) return BTI;
4258193323Sed    break;
4259193323Sed  }
4260193323Sed  case ICmpInst::ICMP_SLT: {
4261193323Sed    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, true);
4262193323Sed    if (BTI.hasAnyInfo()) return BTI;
4263193323Sed    break;
4264193323Sed  }
4265193323Sed  case ICmpInst::ICMP_SGT: {
4266193323Sed    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
4267193323Sed                                             getNotSCEV(RHS), L, true);
4268193323Sed    if (BTI.hasAnyInfo()) return BTI;
4269193323Sed    break;
4270193323Sed  }
4271193323Sed  case ICmpInst::ICMP_ULT: {
4272193323Sed    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, false);
4273193323Sed    if (BTI.hasAnyInfo()) return BTI;
4274193323Sed    break;
4275193323Sed  }
4276193323Sed  case ICmpInst::ICMP_UGT: {
4277193323Sed    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
4278193323Sed                                             getNotSCEV(RHS), L, false);
4279193323Sed    if (BTI.hasAnyInfo()) return BTI;
4280193323Sed    break;
4281193323Sed  }
4282193323Sed  default:
4283193323Sed#if 0
4284201360Srdivacky    dbgs() << "ComputeBackedgeTakenCount ";
4285193323Sed    if (ExitCond->getOperand(0)->getType()->isUnsigned())
4286201360Srdivacky      dbgs() << "[unsigned] ";
4287201360Srdivacky    dbgs() << *LHS << "   "
4288195098Sed         << Instruction::getOpcodeName(Instruction::ICmp)
4289193323Sed         << "   " << *RHS << "\n";
4290193323Sed#endif
4291193323Sed    break;
4292193323Sed  }
4293193323Sed  return
4294194612Sed    ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
4295193323Sed}
4296193323Sed
4297193323Sedstatic ConstantInt *
4298193323SedEvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
4299193323Sed                                ScalarEvolution &SE) {
4300198090Srdivacky  const SCEV *InVal = SE.getConstant(C);
4301198090Srdivacky  const SCEV *Val = AddRec->evaluateAtIteration(InVal, SE);
4302193323Sed  assert(isa<SCEVConstant>(Val) &&
4303193323Sed         "Evaluation of SCEV at constant didn't fold correctly?");
4304193323Sed  return cast<SCEVConstant>(Val)->getValue();
4305193323Sed}
4306193323Sed
4307193323Sed/// GetAddressedElementFromGlobal - Given a global variable with an initializer
4308193323Sed/// and a GEP expression (missing the pointer index) indexing into it, return
4309193323Sed/// the addressed element of the initializer or null if the index expression is
4310193323Sed/// invalid.
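///
/// For example, with Indices {1, 3} and a ConstantStruct initializer, the walk
/// below first takes struct operand 1 and, if that operand is a ConstantArray,
/// returns its element 3 (or null if the index is out of bounds).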
4311193323Sedstatic Constant *
4312199989SrdivackyGetAddressedElementFromGlobal(GlobalVariable *GV,
4313193323Sed                              const std::vector<ConstantInt*> &Indices) {
4314193323Sed  Constant *Init = GV->getInitializer();
4315193323Sed  for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
4316193323Sed    uint64_t Idx = Indices[i]->getZExtValue();
4317193323Sed    if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
4318193323Sed      assert(Idx < CS->getNumOperands() && "Bad struct index!");
4319193323Sed      Init = cast<Constant>(CS->getOperand(Idx));
4320193323Sed    } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
4321193323Sed      if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
4322193323Sed      Init = cast<Constant>(CA->getOperand(Idx));
4323193323Sed    } else if (isa<ConstantAggregateZero>(Init)) {
4324193323Sed      if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
4325193323Sed        assert(Idx < STy->getNumElements() && "Bad struct index!");
4326193323Sed        Init = Constant::getNullValue(STy->getElementType(Idx));
4327193323Sed      } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
4328193323Sed        if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
4329193323Sed        Init = Constant::getNullValue(ATy->getElementType());
4330193323Sed      } else {
4331198090Srdivacky        llvm_unreachable("Unknown constant aggregate type!");
4332193323Sed      }
4333193323Sed      return 0;
4334193323Sed    } else {
4335193323Sed      return 0; // Unknown initializer type
4336193323Sed    }
4337193323Sed  }
4338193323Sed  return Init;
4339193323Sed}
4340193323Sed
4341193323Sed/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of
4342193323Sed/// 'icmp op load X, cst', try to see if we can compute the backedge
4343193323Sed/// execution count.
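///
/// For illustration, if the varying index is the chrec {0,+,1} into a constant
/// string "abcd" (plus its nul terminator), the loaded characters are 'a', 'b',
/// 'c', 'd', 0, so a "*X != 0" style condition first evaluates false at
/// iteration 4 and the computed backedge-taken count is 4.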
4344204642SrdivackyScalarEvolution::BackedgeTakenInfo
4345195098SedScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount(
4346195098Sed                                                LoadInst *LI,
4347195098Sed                                                Constant *RHS,
4348195098Sed                                                const Loop *L,
4349195098Sed                                                ICmpInst::Predicate predicate) {
4350195340Sed  if (LI->isVolatile()) return getCouldNotCompute();
4351193323Sed
4352193323Sed  // Check to see if the loaded pointer is a getelementptr of a global.
4353204642Srdivacky  // TODO: Use SCEV instead of manually grubbing with GEPs.
4354193323Sed  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
4355195340Sed  if (!GEP) return getCouldNotCompute();
4356193323Sed
4357193323Sed  // Make sure that it is really a constant global we are gepping, with an
4358193323Sed  // initializer, and make sure the first IDX is really 0.
4359193323Sed  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
4360198090Srdivacky  if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() ||
4361193323Sed      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
4362193323Sed      !cast<Constant>(GEP->getOperand(1))->isNullValue())
4363195340Sed    return getCouldNotCompute();
4364193323Sed
4365193323Sed  // Okay, we allow one non-constant index into the GEP instruction.
4366193323Sed  Value *VarIdx = 0;
4367193323Sed  std::vector<ConstantInt*> Indexes;
4368193323Sed  unsigned VarIdxNum = 0;
4369193323Sed  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
4370193323Sed    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
4371193323Sed      Indexes.push_back(CI);
4372193323Sed    } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
4373195340Sed      if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
4374193323Sed      VarIdx = GEP->getOperand(i);
4375193323Sed      VarIdxNum = i-2;
4376193323Sed      Indexes.push_back(0);
4377193323Sed    }
4378193323Sed
4379193323Sed  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
4380193323Sed  // Check to see if X is a loop variant variable value now.
4381198090Srdivacky  const SCEV *Idx = getSCEV(VarIdx);
4382193323Sed  Idx = getSCEVAtScope(Idx, L);
4383193323Sed
4384193323Sed  // We can only recognize very limited forms of loop index expressions, in
4385193323Sed  // particular, only affine AddRec's like {C1,+,C2}.
4386193323Sed  const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
4387218893Sdim  if (!IdxExpr || !IdxExpr->isAffine() || isLoopInvariant(IdxExpr, L) ||
4388193323Sed      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
4389193323Sed      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
4390195340Sed    return getCouldNotCompute();
4391193323Sed
4392193323Sed  unsigned MaxSteps = MaxBruteForceIterations;
4393193323Sed  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
4394198090Srdivacky    ConstantInt *ItCst = ConstantInt::get(
4395198090Srdivacky                           cast<IntegerType>(IdxExpr->getType()), IterationNum);
4396193323Sed    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
4397193323Sed
4398193323Sed    // Form the GEP offset.
4399193323Sed    Indexes[VarIdxNum] = Val;
4400193323Sed
4401199989Srdivacky    Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
4402193323Sed    if (Result == 0) break;  // Cannot compute!
4403193323Sed
4404193323Sed    // Evaluate the condition for this iteration.
4405193323Sed    Result = ConstantExpr::getICmp(predicate, Result, RHS);
4406193323Sed    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
4407193323Sed    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
4408193323Sed#if 0
4409201360Srdivacky      dbgs() << "\n***\n*** Computed loop count " << *ItCst
4410193323Sed             << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
4411193323Sed             << "***\n";
4412193323Sed#endif
4413193323Sed      ++NumArrayLenItCounts;
4414193323Sed      return getConstant(ItCst);   // Found terminating iteration!
4415193323Sed    }
4416193323Sed  }
4417195340Sed  return getCouldNotCompute();
4418193323Sed}
4419193323Sed
4420193323Sed
4421193323Sed/// CanConstantFold - Return true if we can constant fold an instruction of the
4422193323Sed/// specified type, assuming that all operands were constants.
4423193323Sedstatic bool CanConstantFold(const Instruction *I) {
4424193323Sed  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
4425193323Sed      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
4426193323Sed    return true;
4427193323Sed
4428193323Sed  if (const CallInst *CI = dyn_cast<CallInst>(I))
4429193323Sed    if (const Function *F = CI->getCalledFunction())
4430193323Sed      return canConstantFoldCallTo(F);
4431193323Sed  return false;
4432193323Sed}
4433193323Sed
4434193323Sed/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
4435193323Sed/// in the loop that V is derived from.  We allow arbitrary operations along the
4436193323Sed/// way, but the operands of an operation must either be constants or a value
4437193323Sed/// derived from a constant PHI.  If this expression does not fit with these
4438193323Sed/// constraints, return null.
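///
/// For example, given an update expression like (PN * 3) + 1 where PN is a
/// header PHI, this returns PN; given PN1 + PN2 built from two different
/// header PHIs, it returns null.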
4439193323Sedstatic PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
4440193323Sed  // If this is not an instruction, or if this is an instruction outside of the
4441193323Sed  // loop, it can't be derived from a loop PHI.
4442193323Sed  Instruction *I = dyn_cast<Instruction>(V);
4443201360Srdivacky  if (I == 0 || !L->contains(I)) return 0;
4444193323Sed
4445193323Sed  if (PHINode *PN = dyn_cast<PHINode>(I)) {
4446193323Sed    if (L->getHeader() == I->getParent())
4447193323Sed      return PN;
4448193323Sed    else
4449193323Sed      // We don't currently keep track of the control flow needed to evaluate
4450193323Sed      // PHIs, so we cannot handle PHIs inside of loops.
4451193323Sed      return 0;
4452193323Sed  }
4453193323Sed
4454193323Sed  // If we won't be able to constant fold this expression even if the operands
4455193323Sed  // are constants, return early.
4456193323Sed  if (!CanConstantFold(I)) return 0;
4457193323Sed
4458193323Sed  // Otherwise, we can evaluate this instruction if all of its operands are
4459193323Sed  // constant or derived from a PHI node themselves.
4460193323Sed  PHINode *PHI = 0;
4461193323Sed  for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
4462210299Sed    if (!isa<Constant>(I->getOperand(Op))) {
4463193323Sed      PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
4464193323Sed      if (P == 0) return 0;  // Not evolving from PHI
4465193323Sed      if (PHI == 0)
4466193323Sed        PHI = P;
4467193323Sed      else if (PHI != P)
4468193323Sed        return 0;  // Evolving from multiple different PHIs.
4469193323Sed    }
4470193323Sed
4471193323Sed  // This is an expression evolving from a constant PHI!
4472193323Sed  return PHI;
4473193323Sed}
4474193323Sed
4475193323Sed/// EvaluateExpression - Given an expression that passes the
4476193323Sed/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
4477193323Sed/// in the loop has the value PHIVal.  If we can't fold this expression for some
4478193323Sed/// reason, return null.
4479199481Srdivackystatic Constant *EvaluateExpression(Value *V, Constant *PHIVal,
4480199481Srdivacky                                    const TargetData *TD) {
4481193323Sed  if (isa<PHINode>(V)) return PHIVal;
4482193323Sed  if (Constant *C = dyn_cast<Constant>(V)) return C;
4483193323Sed  Instruction *I = cast<Instruction>(V);
4484193323Sed
4485210299Sed  std::vector<Constant*> Operands(I->getNumOperands());
4486193323Sed
4487193323Sed  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
4488199481Srdivacky    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal, TD);
4489193323Sed    if (Operands[i] == 0) return 0;
4490193323Sed  }
4491193323Sed
4492193323Sed  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
4493199481Srdivacky    return ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
4494199481Srdivacky                                           Operands[1], TD);
4495199481Srdivacky  return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
4496199481Srdivacky                                  &Operands[0], Operands.size(), TD);
4497193323Sed}
4498193323Sed
4499193323Sed/// getConstantEvolutionLoopExitValue - If we know that the specified Phi is
4500193323Sed/// in the header of its containing loop, we know the loop executes a
4501193323Sed/// constant number of times, and the PHI node is just a recurrence
4502193323Sed/// involving constants, fold it.
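///
/// For example, a header PHI that starts at 1 and is updated as x = x*3 + 1
/// each iteration is not an add recurrence SCEV can express, but with a known
/// backedge-taken count of 4 it can still be folded here: the value evolves
/// 1 -> 4 -> 13 -> 40 -> 121, so the exit value is 121.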
4503195098SedConstant *
4504195098SedScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
4505201360Srdivacky                                                   const APInt &BEs,
4506195098Sed                                                   const Loop *L) {
4507223017Sdim  DenseMap<PHINode*, Constant*>::const_iterator I =
4508193323Sed    ConstantEvolutionLoopExitValue.find(PN);
4509193323Sed  if (I != ConstantEvolutionLoopExitValue.end())
4510193323Sed    return I->second;
4511193323Sed
4512207618Srdivacky  if (BEs.ugt(MaxBruteForceIterations))
4513193323Sed    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
4514193323Sed
4515193323Sed  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
4516193323Sed
4517193323Sed  // Since the loop is canonicalized, the PHI node must have two entries.  One
4518193323Sed  // entry must be a constant (coming in from outside of the loop), and the
4519193323Sed  // second must be derived from the same PHI.
4520193323Sed  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4521193323Sed  Constant *StartCST =
4522193323Sed    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
4523193323Sed  if (StartCST == 0)
4524193323Sed    return RetVal = 0;  // Must be a constant.
4525193323Sed
4526193323Sed  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
4527210299Sed  if (getConstantEvolvingPHI(BEValue, L) != PN &&
4528210299Sed      !isa<Constant>(BEValue))
4529193323Sed    return RetVal = 0;  // Not derived from same PHI.
4530193323Sed
4531193323Sed  // Execute the loop symbolically to determine the exit value.
4532193323Sed  if (BEs.getActiveBits() >= 32)
4533193323Sed    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
4534193323Sed
4535193323Sed  unsigned NumIterations = BEs.getZExtValue(); // must be in range
4536193323Sed  unsigned IterationNum = 0;
4537193323Sed  for (Constant *PHIVal = StartCST; ; ++IterationNum) {
4538193323Sed    if (IterationNum == NumIterations)
4539193323Sed      return RetVal = PHIVal;  // Got exit value!
4540193323Sed
4541193323Sed    // Compute the value of the PHI node for the next iteration.
4542199481Srdivacky    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD);
4543193323Sed    if (NextPHI == PHIVal)
4544193323Sed      return RetVal = NextPHI;  // Stopped evolving!
4545193323Sed    if (NextPHI == 0)
4546193323Sed      return 0;        // Couldn't evaluate!
4547193323Sed    PHIVal = NextPHI;
4548193323Sed  }
4549193323Sed}
4550193323Sed
4551198090Srdivacky/// ComputeBackedgeTakenCountExhaustively - If the loop is known to execute a
4552193323Sed/// constant number of times (the condition evolves only from constants),
4553193323Sed/// try to evaluate a few iterations of the loop until the exit
4554193323Sed/// condition gets a value of ExitWhen (true or false).  If we cannot
4555195340Sed/// evaluate the trip count of the loop, return getCouldNotCompute().
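///
/// As a rough sketch, for a loop like "x = 1; while (x != 81) x *= 3;" the
/// brute-force evaluation sees the condition hold for x = 1, 3, 9, 27 and
/// fail at x = 81, so it reports a backedge-taken count of 4.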
4556195098Sedconst SCEV *
4557195098SedScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L,
4558195098Sed                                                       Value *Cond,
4559195098Sed                                                       bool ExitWhen) {
4560193323Sed  PHINode *PN = getConstantEvolvingPHI(Cond, L);
4561195340Sed  if (PN == 0) return getCouldNotCompute();
4562193323Sed
4563210299Sed  // If the loop is canonicalized, the PHI will have exactly two entries.
4564210299Sed  // That's the only form we support here.
4565210299Sed  if (PN->getNumIncomingValues() != 2) return getCouldNotCompute();
4566210299Sed
4567210299Sed  // One entry must be a constant (coming in from outside of the loop), and the
4568193323Sed  // second must be derived from the same PHI.
4569193323Sed  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
4570193323Sed  Constant *StartCST =
4571193323Sed    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
4572195340Sed  if (StartCST == 0) return getCouldNotCompute();  // Must be a constant.
4573193323Sed
4574193323Sed  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
4575210299Sed  if (getConstantEvolvingPHI(BEValue, L) != PN &&
4576210299Sed      !isa<Constant>(BEValue))
4577210299Sed    return getCouldNotCompute();  // Not derived from same PHI.
4578193323Sed
4579193323Sed  // Okay, we found a PHI node that defines the trip count of this loop.  Execute
4580193323Sed  // the loop symbolically to determine when the condition gets a value of
4581193323Sed  // "ExitWhen".
4582193323Sed  unsigned IterationNum = 0;
4583193323Sed  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
4584193323Sed  for (Constant *PHIVal = StartCST;
4585193323Sed       IterationNum != MaxIterations; ++IterationNum) {
4586193323Sed    ConstantInt *CondVal =
4587199481Srdivacky      dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal, TD));
4588193323Sed
4589193323Sed    // Couldn't symbolically evaluate.
4590195340Sed    if (!CondVal) return getCouldNotCompute();
4591193323Sed
4592193323Sed    if (CondVal->getValue() == uint64_t(ExitWhen)) {
4593193323Sed      ++NumBruteForceTripCountsComputed;
4594198090Srdivacky      return getConstant(Type::getInt32Ty(getContext()), IterationNum);
4595193323Sed    }
4596193323Sed
4597193323Sed    // Compute the value of the PHI node for the next iteration.
4598199481Srdivacky    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal, TD);
4599193323Sed    if (NextPHI == 0 || NextPHI == PHIVal)
4600195340Sed      return getCouldNotCompute();// Couldn't evaluate or not making progress...
4601193323Sed    PHIVal = NextPHI;
4602193323Sed  }
4603193323Sed
4604193323Sed  // Too many iterations were needed to evaluate.
4605195340Sed  return getCouldNotCompute();
4606193323Sed}
4607193323Sed
4608198090Srdivacky/// getSCEVAtScope - Return a SCEV expression for the specified value
4609193323Sed/// at the specified scope in the program.  The L value specifies a loop
4610193323Sed/// nest to evaluate the expression at, where null is the top-level or a
4611193323Sed/// specified loop is immediately inside of the loop.
4612193323Sed///
4613193323Sed/// This method can be used to compute the exit value for a variable defined
4614193323Sed/// in a loop by querying what the value will hold in the parent loop.
4615193323Sed///
4616193323Sed/// In the case that a relevant loop exit value cannot be computed, the
4617193323Sed/// original value V is returned.
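///
/// For example, if an inner loop's induction variable is the chrec
/// {0,+,1}<InnerLoop> and that loop's backedge-taken count is known to be n,
/// evaluating the chrec at the scope of the parent loop yields n, the value
/// the variable holds once the inner loop has exited.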
4618198090Srdivackyconst SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
4619198090Srdivacky  // Check to see if we've folded this expression at this loop before.
4620198090Srdivacky  std::map<const Loop *, const SCEV *> &Values = ValuesAtScopes[V];
4621198090Srdivacky  std::pair<std::map<const Loop *, const SCEV *>::iterator, bool> Pair =
4622198090Srdivacky    Values.insert(std::make_pair(L, static_cast<const SCEV *>(0)));
4623198090Srdivacky  if (!Pair.second)
4624198090Srdivacky    return Pair.first->second ? Pair.first->second : V;
4625193323Sed
4626198090Srdivacky  // Otherwise compute it.
4627198090Srdivacky  const SCEV *C = computeSCEVAtScope(V, L);
4628198090Srdivacky  ValuesAtScopes[V][L] = C;
4629198090Srdivacky  return C;
4630198090Srdivacky}
4631198090Srdivacky
4632198090Srdivackyconst SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
4633193323Sed  if (isa<SCEVConstant>(V)) return V;
4634193323Sed
4635193323Sed  // If this instruction is evolved from a constant-evolving PHI, compute the
4636193323Sed  // exit value from the loop without using SCEVs.
4637193323Sed  if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
4638193323Sed    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
4639193323Sed      const Loop *LI = (*this->LI)[I->getParent()];
4640193323Sed      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
4641193323Sed        if (PHINode *PN = dyn_cast<PHINode>(I))
4642193323Sed          if (PN->getParent() == LI->getHeader()) {
4643193323Sed            // Okay, there is no closed form solution for the PHI node.  Check
4644193323Sed            // to see if the loop that contains it has a known backedge-taken
4645193323Sed            // count.  If so, we may be able to force computation of the exit
4646193323Sed            // value.
4647198090Srdivacky            const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
4648193323Sed            if (const SCEVConstant *BTCC =
4649193323Sed                  dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
4650193323Sed              // Okay, we know how many times the containing loop executes.  If
4651193323Sed              // this is a constant evolving PHI node, get the final value at
4652193323Sed              // the specified iteration number.
4653193323Sed              Constant *RV = getConstantEvolutionLoopExitValue(PN,
4654193323Sed                                                   BTCC->getValue()->getValue(),
4655193323Sed                                                               LI);
4656195340Sed              if (RV) return getSCEV(RV);
4657193323Sed            }
4658193323Sed          }
4659193323Sed
4660193323Sed      // Okay, this is an expression that we cannot symbolically evaluate
4661193323Sed      // into a SCEV.  Check to see if it's possible to symbolically evaluate
4662193323Sed      // the arguments into constants, and if so, try to constant propagate the
4663193323Sed      // result.  This is particularly useful for computing loop exit values.
4664193323Sed      if (CanConstantFold(I)) {
4665210299Sed        SmallVector<Constant *, 4> Operands;
4666210299Sed        bool MadeImprovement = false;
4667193323Sed        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
4668193323Sed          Value *Op = I->getOperand(i);
4669193323Sed          if (Constant *C = dyn_cast<Constant>(Op)) {
4670193323Sed            Operands.push_back(C);
4671210299Sed            continue;
4672210299Sed          }
4673193323Sed
4674210299Sed          // If the operand is non-constant and its type is neither an
4675210299Sed          // integer nor a pointer, don't even try to analyze it with
4676210299Sed          // scev techniques.
4677210299Sed          if (!isSCEVable(Op->getType()))
4678210299Sed            return V;
4679210299Sed
4680210299Sed          const SCEV *OrigV = getSCEV(Op);
4681210299Sed          const SCEV *OpV = getSCEVAtScope(OrigV, L);
4682210299Sed          MadeImprovement |= OrigV != OpV;
4683210299Sed
4684210299Sed          Constant *C = 0;
4685210299Sed          if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV))
4686210299Sed            C = SC->getValue();
4687210299Sed          if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV))
4688210299Sed            C = dyn_cast<Constant>(SU->getValue());
4689210299Sed          if (!C) return V;
4690210299Sed          if (C->getType() != Op->getType())
4691210299Sed            C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
4692210299Sed                                                              Op->getType(),
4693210299Sed                                                              false),
4694210299Sed                                      C, Op->getType());
4695210299Sed          Operands.push_back(C);
4696193323Sed        }
4697195098Sed
4698210299Sed        // Check to see if getSCEVAtScope actually made an improvement.
4699210299Sed        if (MadeImprovement) {
4700210299Sed          Constant *C = 0;
4701210299Sed          if (const CmpInst *CI = dyn_cast<CmpInst>(I))
4702210299Sed            C = ConstantFoldCompareInstOperands(CI->getPredicate(),
4703210299Sed                                                Operands[0], Operands[1], TD);
4704210299Sed          else
4705210299Sed            C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
4706210299Sed                                         &Operands[0], Operands.size(), TD);
4707210299Sed          if (!C) return V;
4708204642Srdivacky          return getSCEV(C);
4709210299Sed        }
4710193323Sed      }
4711193323Sed    }
4712193323Sed
4713193323Sed    // This is some other type of SCEVUnknown, just return it.
4714193323Sed    return V;
4715193323Sed  }
4716193323Sed
4717193323Sed  if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
4718193323Sed    // Avoid performing the look-up in the common case where the specified
4719193323Sed    // expression has no loop-variant portions.
4720193323Sed    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
4721198090Srdivacky      const SCEV *OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
4722193323Sed      if (OpAtScope != Comm->getOperand(i)) {
4723193323Sed        // Okay, at least one of these operands is loop variant but might be
4724193323Sed        // foldable.  Build a new instance of the folded commutative expression.
4725195098Sed        SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
4726195098Sed                                            Comm->op_begin()+i);
4727193323Sed        NewOps.push_back(OpAtScope);
4728193323Sed
4729193323Sed        for (++i; i != e; ++i) {
4730193323Sed          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
4731193323Sed          NewOps.push_back(OpAtScope);
4732193323Sed        }
4733193323Sed        if (isa<SCEVAddExpr>(Comm))
4734193323Sed          return getAddExpr(NewOps);
4735193323Sed        if (isa<SCEVMulExpr>(Comm))
4736193323Sed          return getMulExpr(NewOps);
4737193323Sed        if (isa<SCEVSMaxExpr>(Comm))
4738193323Sed          return getSMaxExpr(NewOps);
4739193323Sed        if (isa<SCEVUMaxExpr>(Comm))
4740193323Sed          return getUMaxExpr(NewOps);
4741198090Srdivacky        llvm_unreachable("Unknown commutative SCEV type!");
4742193323Sed      }
4743193323Sed    }
4744193323Sed    // If we got here, all operands are loop invariant.
4745193323Sed    return Comm;
4746193323Sed  }
4747193323Sed
4748193323Sed  if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
4749198090Srdivacky    const SCEV *LHS = getSCEVAtScope(Div->getLHS(), L);
4750198090Srdivacky    const SCEV *RHS = getSCEVAtScope(Div->getRHS(), L);
4751193323Sed    if (LHS == Div->getLHS() && RHS == Div->getRHS())
4752193323Sed      return Div;   // must be loop invariant
4753193323Sed    return getUDivExpr(LHS, RHS);
4754193323Sed  }
4755193323Sed
4756193323Sed  // If this is a loop recurrence for a loop that does not contain L, then we
4757193323Sed  // are dealing with the final value computed by the loop.
4758193323Sed  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
4759210299Sed    // First, attempt to evaluate each operand.
4760210299Sed    // Avoid performing the look-up in the common case where the specified
4761210299Sed    // expression has no loop-variant portions.
4762210299Sed    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
4763210299Sed      const SCEV *OpAtScope = getSCEVAtScope(AddRec->getOperand(i), L);
4764210299Sed      if (OpAtScope == AddRec->getOperand(i))
4765210299Sed        continue;
4766210299Sed
4767210299Sed      // Okay, at least one of these operands is loop variant but might be
4768210299Sed      // foldable.  Build a new instance of the folded add recurrence.
4769210299Sed      SmallVector<const SCEV *, 8> NewOps(AddRec->op_begin(),
4770210299Sed                                          AddRec->op_begin()+i);
4771210299Sed      NewOps.push_back(OpAtScope);
4772210299Sed      for (++i; i != e; ++i)
4773210299Sed        NewOps.push_back(getSCEVAtScope(AddRec->getOperand(i), L));
4774210299Sed
4775221345Sdim      const SCEV *FoldedRec =
4776221345Sdim        getAddRecExpr(NewOps, AddRec->getLoop(),
4777221345Sdim                      AddRec->getNoWrapFlags(SCEV::FlagNW));
4778221345Sdim      AddRec = dyn_cast<SCEVAddRecExpr>(FoldedRec);
4779221345Sdim      // The addrec may be folded to a nonrecurrence, for example, if the
4780221345Sdim      // induction variable is multiplied by zero after constant folding. Go
4781221345Sdim      // ahead and return the folded value.
4782221345Sdim      if (!AddRec)
4783221345Sdim        return FoldedRec;
4784210299Sed      break;
4785210299Sed    }
4786210299Sed
4787210299Sed    // If the scope is outside the addrec's loop, evaluate it by using the
4788210299Sed    // loop exit value of the addrec.
4789210299Sed    if (!AddRec->getLoop()->contains(L)) {
4790193323Sed      // To evaluate this recurrence, we need to know how many times the AddRec
4791193323Sed      // loop iterates.  Compute this now.
4792198090Srdivacky      const SCEV *BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
4793195340Sed      if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
4794193323Sed
4795193323Sed      // Then, evaluate the AddRec.
4796193323Sed      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
4797193323Sed    }
4798210299Sed
4799193323Sed    return AddRec;
4800193323Sed  }
4801193323Sed
4802193323Sed  if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
4803198090Srdivacky    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
4804193323Sed    if (Op == Cast->getOperand())
4805193323Sed      return Cast;  // must be loop invariant
4806193323Sed    return getZeroExtendExpr(Op, Cast->getType());
4807193323Sed  }
4808193323Sed
4809193323Sed  if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
4810198090Srdivacky    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
4811193323Sed    if (Op == Cast->getOperand())
4812193323Sed      return Cast;  // must be loop invariant
4813193323Sed    return getSignExtendExpr(Op, Cast->getType());
4814193323Sed  }
4815193323Sed
4816193323Sed  if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
4817198090Srdivacky    const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
4818193323Sed    if (Op == Cast->getOperand())
4819193323Sed      return Cast;  // must be loop invariant
4820193323Sed    return getTruncateExpr(Op, Cast->getType());
4821193323Sed  }
4822193323Sed
4823198090Srdivacky  llvm_unreachable("Unknown SCEV type!");
4824193323Sed  return 0;
4825193323Sed}
4826193323Sed
4827193323Sed/// getSCEVAtScope - This is a convenience function which does
4828193323Sed/// getSCEVAtScope(getSCEV(V), L).
4829198090Srdivackyconst SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
4830193323Sed  return getSCEVAtScope(getSCEV(V), L);
4831193323Sed}
4832193323Sed
4833193323Sed/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
4834193323Sed/// following equation:
4835193323Sed///
4836193323Sed///     A * X = B (mod N)
4837193323Sed///
4838193323Sed/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
4839193323Sed/// A and B isn't important.
4840193323Sed///
4841193323Sed/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
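///
/// As a worked example with BW = 8: for 4 * X = 12 (mod 256), D = gcd(4, 256)
/// = 4 divides 12, so a solution exists; dividing through gives X = 3 (mod 64),
/// and 3 is the minimum unsigned root returned (indeed 4 * 3 = 12).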
4842198090Srdivackystatic const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
4843193323Sed                                               ScalarEvolution &SE) {
4844193323Sed  uint32_t BW = A.getBitWidth();
4845193323Sed  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
4846193323Sed  assert(A != 0 && "A must be non-zero.");
4847193323Sed
4848193323Sed  // 1. D = gcd(A, N)
4849193323Sed  //
4850193323Sed  // The gcd of A and N may have only one prime factor: 2. The number of
4851193323Sed  // trailing zeros in A is its multiplicity
4852193323Sed  uint32_t Mult2 = A.countTrailingZeros();
4853193323Sed  // D = 2^Mult2
4854193323Sed
4855193323Sed  // 2. Check if B is divisible by D.
4856193323Sed  //
4857193323Sed  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
4858193323Sed  // is not less than multiplicity of this prime factor for D.
4859193323Sed  if (B.countTrailingZeros() < Mult2)
4860193323Sed    return SE.getCouldNotCompute();
4861193323Sed
4862193323Sed  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
4863193323Sed  // modulo (N / D).
4864193323Sed  //
4865193323Sed  // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
4866193323Sed  // bit width during computations.
4867193323Sed  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
4868193323Sed  APInt Mod(BW + 1, 0);
4869218893Sdim  Mod.setBit(BW - Mult2);  // Mod = N / D
4870193323Sed  APInt I = AD.multiplicativeInverse(Mod);
4871193323Sed
4872193323Sed  // 4. Compute the minimum unsigned root of the equation:
4873193323Sed  // I * (B / D) mod (N / D)
4874193323Sed  APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
4875193323Sed
4876193323Sed  // The result is guaranteed to be less than 2^BW so we may truncate it to BW
4877193323Sed  // bits.
4878193323Sed  return SE.getConstant(Result.trunc(BW));
4879193323Sed}
4880193323Sed
4881193323Sed/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
4882193323Sed/// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
4883193323Sed/// might be the same) or two SCEVCouldNotCompute objects.
4884193323Sed///
4885198090Srdivackystatic std::pair<const SCEV *,const SCEV *>
4886193323SedSolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
4887193323Sed  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
4888193323Sed  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
4889193323Sed  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
4890193323Sed  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
4891193323Sed
4892193323Sed  // We currently can only solve this if the coefficients are constants.
4893193323Sed  if (!LC || !MC || !NC) {
4894193323Sed    const SCEV *CNC = SE.getCouldNotCompute();
4895193323Sed    return std::make_pair(CNC, CNC);
4896193323Sed  }
4897193323Sed
4898193323Sed  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
4899193323Sed  const APInt &L = LC->getValue()->getValue();
4900193323Sed  const APInt &M = MC->getValue()->getValue();
4901193323Sed  const APInt &N = NC->getValue()->getValue();
4902193323Sed  APInt Two(BitWidth, 2);
4903193323Sed  APInt Four(BitWidth, 4);
4904193323Sed
4905195098Sed  {
4906193323Sed    using namespace APIntOps;
4907193323Sed    const APInt& C = L;
4908193323Sed    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
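    // (Evaluating the chrec {L,+,M,+,N} at iteration X gives
    //  L + M*X + N*X*(X-1)/2 = (N/2)*X^2 + (M - N/2)*X + L,
    //  which is where the A, B, and C coefficients below come from.)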
4909193323Sed    // The B coefficient is M-N/2
4910193323Sed    APInt B(M);
4911193323Sed    B -= sdiv(N,Two);
4912193323Sed
4913193323Sed    // The A coefficient is N/2
4914193323Sed    APInt A(N.sdiv(Two));
4915193323Sed
4916193323Sed    // Compute the B^2-4ac term.
4917193323Sed    APInt SqrtTerm(B);
4918193323Sed    SqrtTerm *= B;
4919193323Sed    SqrtTerm -= Four * (A * C);
4920193323Sed
4921193323Sed    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
4922193323Sed    // integer value or else APInt::sqrt() will assert.
4923193323Sed    APInt SqrtVal(SqrtTerm.sqrt());
4924193323Sed
4925195098Sed    // Compute the two solutions for the quadratic formula.
4926193323Sed    // The divisions must be performed as signed divisions.
4927193323Sed    APInt NegB(-B);
4928193323Sed    APInt TwoA( A << 1 );
4929193323Sed    if (TwoA.isMinValue()) {
4930193323Sed      const SCEV *CNC = SE.getCouldNotCompute();
4931193323Sed      return std::make_pair(CNC, CNC);
4932193323Sed    }
4933193323Sed
4934198090Srdivacky    LLVMContext &Context = SE.getContext();
4935193323Sed
4936198090Srdivacky    ConstantInt *Solution1 =
4937198090Srdivacky      ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
4938198090Srdivacky    ConstantInt *Solution2 =
4939198090Srdivacky      ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
4940198090Srdivacky
4941195098Sed    return std::make_pair(SE.getConstant(Solution1),
4942193323Sed                          SE.getConstant(Solution2));
4943193323Sed  } // end of scope using APIntOps
4944193323Sed}
4945193323Sed
4946193323Sed/// HowFarToZero - Return the number of times a backedge comparing the specified
4947193630Sed/// value to zero will execute.  If not computable, return CouldNotCompute.
4948221345Sdim///
4949221345Sdim/// This is only used for loops with a "x != y" exit test. The exit condition is
4950221345Sdim/// now expressed as a single expression, V = x-y. So the exit test is
4951221345Sdim/// effectively V != 0.  We know and take advantage of the fact that this
4952221345Sdim/// expression is only used in a comparison-with-zero context.
4953204642SrdivackyScalarEvolution::BackedgeTakenInfo
4954204642SrdivackyScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
4955193323Sed  // If the value is a constant
4956193323Sed  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
4957193323Sed    // If the value is already zero, the branch will execute zero times.
4958193323Sed    if (C->getValue()->isZero()) return C;
4959195340Sed    return getCouldNotCompute();  // Otherwise it will loop infinitely.
4960193323Sed  }
4961193323Sed
4962193323Sed  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
4963193323Sed  if (!AddRec || AddRec->getLoop() != L)
4964195340Sed    return getCouldNotCompute();
4965193323Sed
4966218893Sdim  // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
4967218893Sdim  // the quadratic equation to solve it.
4968218893Sdim  if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
4969218893Sdim    std::pair<const SCEV *,const SCEV *> Roots =
4970218893Sdim      SolveQuadraticEquation(AddRec, *this);
4971193323Sed    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
4972193323Sed    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
4973218893Sdim    if (R1 && R2) {
4974193323Sed#if 0
4975201360Srdivacky      dbgs() << "HFTZ: " << *V << " - sol#1: " << *R1
4976193323Sed             << "  sol#2: " << *R2 << "\n";
4977193323Sed#endif
4978193323Sed      // Pick the smallest positive root value.
4979193323Sed      if (ConstantInt *CB =
4980218893Sdim          dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT,
4981218893Sdim                                                      R1->getValue(),
4982218893Sdim                                                      R2->getValue()))) {
4983193323Sed        if (CB->getZExtValue() == false)
4984193323Sed          std::swap(R1, R2);   // R1 is the minimum root now.
4985221345Sdim
4986193323Sed        // We can only use this value if the chrec ends up with an exact zero
4987193323Sed        // value at this index.  When solving for "X*X != 5", for example, we
4988193323Sed        // should not accept a root of 2.
4989198090Srdivacky        const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
4990193323Sed        if (Val->isZero())
4991193323Sed          return R1;  // We found a quadratic root!
4992193323Sed      }
4993193323Sed    }
4994218893Sdim    return getCouldNotCompute();
4995193323Sed  }
4996193323Sed
4997218893Sdim  // Otherwise we can only handle this if it is affine.
4998218893Sdim  if (!AddRec->isAffine())
4999218893Sdim    return getCouldNotCompute();
5000218893Sdim
5001218893Sdim  // If this is an affine expression, the execution count of this branch is
5002218893Sdim  // the minimum unsigned root of the following equation:
5003218893Sdim  //
5004218893Sdim  //     Start + Step*N = 0 (mod 2^BW)
5005218893Sdim  //
5006218893Sdim  // equivalent to:
5007218893Sdim  //
5008218893Sdim  //             Step*N = -Start (mod 2^BW)
5009218893Sdim  //
5010218893Sdim  // where BW is the common bit width of Start and Step.
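  //
  // For example, the affine chrec {8,+,-2} satisfies -2*N = -8 (mod 2^BW) at
  // N = 4, matching the value sequence 8, 6, 4, 2, 0: the backedge is taken
  // four times before the value reaches zero.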
5011218893Sdim
5012218893Sdim  // Get the initial value for the loop.
5013218893Sdim  const SCEV *Start = getSCEVAtScope(AddRec->getStart(), L->getParentLoop());
5014218893Sdim  const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1), L->getParentLoop());
5015218893Sdim
5016218893Sdim  // For now we handle only constant steps.
5017221345Sdim  //
5018221345Sdim  // TODO: Handle a nonconstant Step given AddRec<NUW>. If the
5019221345Sdim  // AddRec is NUW, then (in an unsigned sense) it cannot be counting up to wrap
5020221345Sdim  // to 0, it must be counting down to equal 0. Consequently, N = Start / -Step.
5021221345Sdim  // We have not yet seen any such cases.
5022218893Sdim  const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step);
5023218893Sdim  if (StepC == 0)
5024218893Sdim    return getCouldNotCompute();
5025218893Sdim
5026221345Sdim  // For positive steps (counting up until unsigned overflow):
5027221345Sdim  //   N = -Start/Step (as unsigned)
5028221345Sdim  // For negative steps (counting down to zero):
5029221345Sdim  //   N = Start/-Step
5030221345Sdim  // First compute the unsigned distance from zero in the direction of Step.
5031221345Sdim  bool CountDown = StepC->getValue()->getValue().isNegative();
5032221345Sdim  const SCEV *Distance = CountDown ? Start : getNegativeSCEV(Start);
5033218893Sdim
5034221345Sdim  // Handle unitary steps, which cannot wrap around.
5035221345Sdim  // 1*N = -Start; -1*N = Start (mod 2^BW), so:
5036221345Sdim  //   N = Distance (as unsigned)
5037221345Sdim  if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue())
5038221345Sdim    return Distance;
5039221345Sdim
5040221345Sdim  // If the recurrence is known not to wrap around, unsigned divide computes the
5041221345Sdim  // back edge count. We know that the value will either become zero (and thus
5042221345Sdim  // the loop terminates), that the loop will terminate through some other exit
5043221345Sdim  // condition first, or that the loop has undefined behavior.  This means
5044221345Sdim  // we can't "miss" the exit value, even with nonunit stride.
5045221345Sdim  //
5046221345Sdim  // FIXME: Prove that loops always exhibit *acceptable* undefined
5047221345Sdim  // behavior. Loops must exhibit defined behavior until a wrapped value is
5048221345Sdim  // actually used. So the trip count computed by udiv could be smaller than the
5049221345Sdim  // number of well-defined iterations.
5050221345Sdim  if (AddRec->getNoWrapFlags(SCEV::FlagNW))
5051221345Sdim    // FIXME: We really want an "isexact" bit for udiv.
5052221345Sdim    return getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
5053221345Sdim
5054218893Sdim  // Then, try to solve the above equation provided that Start is constant.
5055218893Sdim  if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
5056218893Sdim    return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
5057218893Sdim                                        -StartC->getValue()->getValue(),
5058218893Sdim                                        *this);
5059195340Sed  return getCouldNotCompute();
5060193323Sed}
5061193323Sed
5062193323Sed/// HowFarToNonZero - Return the number of times a backedge checking the
5063193323Sed/// specified value for nonzero will execute.  If not computable, return
5064193630Sed/// CouldNotCompute
5065204642SrdivackyScalarEvolution::BackedgeTakenInfo
5066204642SrdivackyScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
5067193323Sed  // Loops that look like: while (X == 0) are very strange indeed.  We don't
5068193323Sed  // handle them yet except for the trivial case.  This could be expanded in the
5069193323Sed  // future as needed.
5070193323Sed
5071193323Sed  // If the value is a constant, check to see if it is known to be non-zero
5072193323Sed  // already.  If so, the backedge will execute zero times.
5073193323Sed  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
5074193323Sed    if (!C->getValue()->isNullValue())
5075207618Srdivacky      return getConstant(C->getType(), 0);
5076195340Sed    return getCouldNotCompute();  // Otherwise it will loop infinitely.
5077193323Sed  }
5078193323Sed
5079193323Sed  // We could implement others, but I really doubt anyone writes loops like
5080193323Sed  // this, and if they did, they would already be constant folded.
5081195340Sed  return getCouldNotCompute();
5082193323Sed}
5083193323Sed
5084193323Sed/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
5085193323Sed/// (which may not be an immediate predecessor) which has exactly one
5086193323Sed/// successor from which BB is reachable, or null if no such block is
5087193323Sed/// found.
5088193323Sed///
5089207618Srdivackystd::pair<BasicBlock *, BasicBlock *>
5090193323SedScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
5091193323Sed  // If the block has a unique predecessor, then there is no path from the
5092193323Sed  // predecessor to the block that does not go through the direct edge
5093193323Sed  // from the predecessor to the block.
5094193323Sed  if (BasicBlock *Pred = BB->getSinglePredecessor())
5095207618Srdivacky    return std::make_pair(Pred, BB);
5096193323Sed
5097193323Sed  // A loop's header is defined to be a block that dominates the loop.
5098193323Sed  // If the header has a unique predecessor outside the loop, it must be
5099193323Sed  // a block that has exactly one successor that can reach the loop.
5100193323Sed  if (Loop *L = LI->getLoopFor(BB))
5101210299Sed    return std::make_pair(L->getLoopPredecessor(), L->getHeader());
5102193323Sed
5103207618Srdivacky  return std::pair<BasicBlock *, BasicBlock *>();
5104193323Sed}
5105193323Sed
5106194612Sed/// HasSameValue - SCEV structural equivalence is usually sufficient for
5107194612Sed/// testing whether two expressions are equal; however, for the purposes of
5108194612Sed/// looking for a condition guarding a loop, it can be useful to be a little
5109194612Sed/// more general, since a front-end may have replicated the controlling
5110194612Sed/// expression.
5111194612Sed///
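/// For example (illustrative), two identical "add nsw i64 %n, 1" instructions
/// in different blocks compare equal here, while two identical loads do not,
/// because of the mayReadFromMemory check below.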
5112198090Srdivackystatic bool HasSameValue(const SCEV *A, const SCEV *B) {
5113194612Sed  // Quick check to see if they are the same SCEV.
5114194612Sed  if (A == B) return true;
5115194612Sed
5116194612Sed  // Otherwise, if they're both SCEVUnknown, it's possible that they hold
5117194612Sed  // two different instructions with the same value. Check for this case.
5118194612Sed  if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
5119194612Sed    if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
5120194612Sed      if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
5121194612Sed        if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
5122198090Srdivacky          if (AI->isIdenticalTo(BI) && !AI->mayReadFromMemory())
5123194612Sed            return true;
5124194612Sed
5125194612Sed  // Otherwise assume they may have a different value.
5126194612Sed  return false;
5127194612Sed}
5128194612Sed
5129207618Srdivacky/// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with
5130207618Srdivacky/// predicate Pred. Return true iff any changes were made.
5131207618Srdivacky///
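/// For example (illustrative), "5 u> %x" is canonicalized to "%x u< 5", and
/// "%x u<= 5" to "%x u< 6", per the transformations below.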
5132207618Srdivackybool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
5133207618Srdivacky                                           const SCEV *&LHS, const SCEV *&RHS) {
5134207618Srdivacky  bool Changed = false;
5135207618Srdivacky
5136207618Srdivacky  // Canonicalize a constant to the right side.
5137207618Srdivacky  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
5138207618Srdivacky    // Check for both operands constant.
5139207618Srdivacky    if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
5140207618Srdivacky      if (ConstantExpr::getICmp(Pred,
5141207618Srdivacky                                LHSC->getValue(),
5142207618Srdivacky                                RHSC->getValue())->isNullValue())
5143207618Srdivacky        goto trivially_false;
5144207618Srdivacky      else
5145207618Srdivacky        goto trivially_true;
5146207618Srdivacky    }
5147207618Srdivacky    // Otherwise swap the operands to put the constant on the right.
5148207618Srdivacky    std::swap(LHS, RHS);
5149207618Srdivacky    Pred = ICmpInst::getSwappedPredicate(Pred);
5150207618Srdivacky    Changed = true;
5151207618Srdivacky  }
5152207618Srdivacky
5153207618Srdivacky  // If we're comparing an addrec with a value which is loop-invariant in the
5154207618Srdivacky  // addrec's loop, put the addrec on the left. Also make a dominance check,
5155207618Srdivacky  // as both operands could be addrecs loop-invariant in each other's loop.
5156207618Srdivacky  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS)) {
5157207618Srdivacky    const Loop *L = AR->getLoop();
5158218893Sdim    if (isLoopInvariant(LHS, L) && properlyDominates(LHS, L->getHeader())) {
5159207618Srdivacky      std::swap(LHS, RHS);
5160207618Srdivacky      Pred = ICmpInst::getSwappedPredicate(Pred);
5161207618Srdivacky      Changed = true;
5162207618Srdivacky    }
5163207618Srdivacky  }
5164207618Srdivacky
5165207618Srdivacky  // If there's a constant operand, canonicalize comparisons with boundary
5166207618Srdivacky  // cases, and canonicalize *-or-equal comparisons to regular comparisons.
5167207618Srdivacky  if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {
5168207618Srdivacky    const APInt &RA = RC->getValue()->getValue();
5169207618Srdivacky    switch (Pred) {
5170207618Srdivacky    default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5171207618Srdivacky    case ICmpInst::ICMP_EQ:
5172207618Srdivacky    case ICmpInst::ICMP_NE:
5173207618Srdivacky      break;
5174207618Srdivacky    case ICmpInst::ICMP_UGE:
5175207618Srdivacky      if ((RA - 1).isMinValue()) {
5176207618Srdivacky        Pred = ICmpInst::ICMP_NE;
5177207618Srdivacky        RHS = getConstant(RA - 1);
5178207618Srdivacky        Changed = true;
5179207618Srdivacky        break;
5180207618Srdivacky      }
5181207618Srdivacky      if (RA.isMaxValue()) {
5182207618Srdivacky        Pred = ICmpInst::ICMP_EQ;
5183207618Srdivacky        Changed = true;
5184207618Srdivacky        break;
5185207618Srdivacky      }
5186207618Srdivacky      if (RA.isMinValue()) goto trivially_true;
5187207618Srdivacky
5188207618Srdivacky      Pred = ICmpInst::ICMP_UGT;
5189207618Srdivacky      RHS = getConstant(RA - 1);
5190207618Srdivacky      Changed = true;
5191207618Srdivacky      break;
5192207618Srdivacky    case ICmpInst::ICMP_ULE:
5193207618Srdivacky      if ((RA + 1).isMaxValue()) {
5194207618Srdivacky        Pred = ICmpInst::ICMP_NE;
5195207618Srdivacky        RHS = getConstant(RA + 1);
5196207618Srdivacky        Changed = true;
5197207618Srdivacky        break;
5198207618Srdivacky      }
5199207618Srdivacky      if (RA.isMinValue()) {
5200207618Srdivacky        Pred = ICmpInst::ICMP_EQ;
5201207618Srdivacky        Changed = true;
5202207618Srdivacky        break;
5203207618Srdivacky      }
5204207618Srdivacky      if (RA.isMaxValue()) goto trivially_true;
5205207618Srdivacky
5206207618Srdivacky      Pred = ICmpInst::ICMP_ULT;
5207207618Srdivacky      RHS = getConstant(RA + 1);
5208207618Srdivacky      Changed = true;
5209207618Srdivacky      break;
5210207618Srdivacky    case ICmpInst::ICMP_SGE:
5211207618Srdivacky      if ((RA - 1).isMinSignedValue()) {
5212207618Srdivacky        Pred = ICmpInst::ICMP_NE;
5213207618Srdivacky        RHS = getConstant(RA - 1);
5214207618Srdivacky        Changed = true;
5215207618Srdivacky        break;
5216207618Srdivacky      }
5217207618Srdivacky      if (RA.isMaxSignedValue()) {
5218207618Srdivacky        Pred = ICmpInst::ICMP_EQ;
5219207618Srdivacky        Changed = true;
5220207618Srdivacky        break;
5221207618Srdivacky      }
5222207618Srdivacky      if (RA.isMinSignedValue()) goto trivially_true;
5223207618Srdivacky
5224207618Srdivacky      Pred = ICmpInst::ICMP_SGT;
5225207618Srdivacky      RHS = getConstant(RA - 1);
5226207618Srdivacky      Changed = true;
5227207618Srdivacky      break;
5228207618Srdivacky    case ICmpInst::ICMP_SLE:
5229207618Srdivacky      if ((RA + 1).isMaxSignedValue()) {
5230207618Srdivacky        Pred = ICmpInst::ICMP_NE;
5231207618Srdivacky        RHS = getConstant(RA + 1);
5232207618Srdivacky        Changed = true;
5233207618Srdivacky        break;
5234207618Srdivacky      }
5235207618Srdivacky      if (RA.isMinSignedValue()) {
5236207618Srdivacky        Pred = ICmpInst::ICMP_EQ;
5237207618Srdivacky        Changed = true;
5238207618Srdivacky        break;
5239207618Srdivacky      }
5240207618Srdivacky      if (RA.isMaxSignedValue()) goto trivially_true;
5241207618Srdivacky
5242207618Srdivacky      Pred = ICmpInst::ICMP_SLT;
5243207618Srdivacky      RHS = getConstant(RA + 1);
5244207618Srdivacky      Changed = true;
5245207618Srdivacky      break;
5246207618Srdivacky    case ICmpInst::ICMP_UGT:
5247207618Srdivacky      if (RA.isMinValue()) {
5248207618Srdivacky        Pred = ICmpInst::ICMP_NE;
5249207618Srdivacky        Changed = true;
5250207618Srdivacky        break;
5251207618Srdivacky      }
5252207618Srdivacky      if ((RA + 1).isMaxValue()) {
5253207618Srdivacky        Pred = ICmpInst::ICMP_EQ;
5254207618Srdivacky        RHS = getConstant(RA + 1);
5255207618Srdivacky        Changed = true;
5256207618Srdivacky        break;
5257207618Srdivacky      }
5258207618Srdivacky      if (RA.isMaxValue()) goto trivially_false;
5259207618Srdivacky      break;
5260207618Srdivacky    case ICmpInst::ICMP_ULT:
5261207618Srdivacky      if (RA.isMaxValue()) {
5262207618Srdivacky        Pred = ICmpInst::ICMP_NE;
5263207618Srdivacky        Changed = true;
5264207618Srdivacky        break;
5265207618Srdivacky      }
5266207618Srdivacky      if ((RA - 1).isMinValue()) {
5267207618Srdivacky        Pred = ICmpInst::ICMP_EQ;
5268207618Srdivacky        RHS = getConstant(RA - 1);
5269207618Srdivacky        Changed = true;
5270207618Srdivacky        break;
5271207618Srdivacky      }
5272207618Srdivacky      if (RA.isMinValue()) goto trivially_false;
5273207618Srdivacky      break;
5274207618Srdivacky    case ICmpInst::ICMP_SGT:
5275207618Srdivacky      if (RA.isMinSignedValue()) {
5276207618Srdivacky        Pred = ICmpInst::ICMP_NE;
5277207618Srdivacky        Changed = true;
5278207618Srdivacky        break;
5279207618Srdivacky      }
5280207618Srdivacky      if ((RA + 1).isMaxSignedValue()) {
5281207618Srdivacky        Pred = ICmpInst::ICMP_EQ;
5282207618Srdivacky        RHS = getConstant(RA + 1);
5283207618Srdivacky        Changed = true;
5284207618Srdivacky        break;
5285207618Srdivacky      }
5286207618Srdivacky      if (RA.isMaxSignedValue()) goto trivially_false;
5287207618Srdivacky      break;
5288207618Srdivacky    case ICmpInst::ICMP_SLT:
5289207618Srdivacky      if (RA.isMaxSignedValue()) {
5290207618Srdivacky        Pred = ICmpInst::ICMP_NE;
5291207618Srdivacky        Changed = true;
5292207618Srdivacky        break;
5293207618Srdivacky      }
5294207618Srdivacky      if ((RA - 1).isMinSignedValue()) {
5295207618Srdivacky        Pred = ICmpInst::ICMP_EQ;
5296207618Srdivacky        RHS = getConstant(RA - 1);
5297207618Srdivacky        Changed = true;
5298207618Srdivacky        break;
5299207618Srdivacky      }
5300207618Srdivacky      if (RA.isMinSignedValue()) goto trivially_false;
5301207618Srdivacky      break;
5302207618Srdivacky    }
5303207618Srdivacky  }
5304207618Srdivacky
5305207618Srdivacky  // Check for obvious equality.
5306207618Srdivacky  if (HasSameValue(LHS, RHS)) {
5307207618Srdivacky    if (ICmpInst::isTrueWhenEqual(Pred))
5308207618Srdivacky      goto trivially_true;
5309207618Srdivacky    if (ICmpInst::isFalseWhenEqual(Pred))
5310207618Srdivacky      goto trivially_false;
5311207618Srdivacky  }
5312207618Srdivacky
5313207618Srdivacky  // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
5314207618Srdivacky  // adding or subtracting 1 from one of the operands.
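  // For example (illustrative): "%x s<= %y" becomes "%x s< %y + 1" when %y's
  // signed range shows the increment cannot overflow; otherwise the left
  // operand is decremented instead, giving "%x - 1 s< %y".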
5315207618Srdivacky  switch (Pred) {
5316207618Srdivacky  case ICmpInst::ICMP_SLE:
5317207618Srdivacky    if (!getSignedRange(RHS).getSignedMax().isMaxSignedValue()) {
5318207618Srdivacky      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5319221345Sdim                       SCEV::FlagNSW);
5320207618Srdivacky      Pred = ICmpInst::ICMP_SLT;
5321207618Srdivacky      Changed = true;
5322207618Srdivacky    } else if (!getSignedRange(LHS).getSignedMin().isMinSignedValue()) {
5323207618Srdivacky      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5324221345Sdim                       SCEV::FlagNSW);
5325207618Srdivacky      Pred = ICmpInst::ICMP_SLT;
5326207618Srdivacky      Changed = true;
5327207618Srdivacky    }
5328207618Srdivacky    break;
5329207618Srdivacky  case ICmpInst::ICMP_SGE:
5330207618Srdivacky    if (!getSignedRange(RHS).getSignedMin().isMinSignedValue()) {
5331207618Srdivacky      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5332221345Sdim                       SCEV::FlagNSW);
5333207618Srdivacky      Pred = ICmpInst::ICMP_SGT;
5334207618Srdivacky      Changed = true;
5335207618Srdivacky    } else if (!getSignedRange(LHS).getSignedMax().isMaxSignedValue()) {
5336207618Srdivacky      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5337221345Sdim                       SCEV::FlagNSW);
5338207618Srdivacky      Pred = ICmpInst::ICMP_SGT;
5339207618Srdivacky      Changed = true;
5340207618Srdivacky    }
5341207618Srdivacky    break;
5342207618Srdivacky  case ICmpInst::ICMP_ULE:
5343207618Srdivacky    if (!getUnsignedRange(RHS).getUnsignedMax().isMaxValue()) {
5344207618Srdivacky      RHS = getAddExpr(getConstant(RHS->getType(), 1, true), RHS,
5345221345Sdim                       SCEV::FlagNUW);
5346207618Srdivacky      Pred = ICmpInst::ICMP_ULT;
5347207618Srdivacky      Changed = true;
5348207618Srdivacky    } else if (!getUnsignedRange(LHS).getUnsignedMin().isMinValue()) {
5349207618Srdivacky      LHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), LHS,
5350221345Sdim                       SCEV::FlagNUW);
5351207618Srdivacky      Pred = ICmpInst::ICMP_ULT;
5352207618Srdivacky      Changed = true;
5353207618Srdivacky    }
5354207618Srdivacky    break;
5355207618Srdivacky  case ICmpInst::ICMP_UGE:
5356207618Srdivacky    if (!getUnsignedRange(RHS).getUnsignedMin().isMinValue()) {
5357207618Srdivacky      RHS = getAddExpr(getConstant(RHS->getType(), (uint64_t)-1, true), RHS,
5358221345Sdim                       SCEV::FlagNUW);
5359207618Srdivacky      Pred = ICmpInst::ICMP_UGT;
5360207618Srdivacky      Changed = true;
5361207618Srdivacky    } else if (!getUnsignedRange(LHS).getUnsignedMax().isMaxValue()) {
5362207618Srdivacky      LHS = getAddExpr(getConstant(RHS->getType(), 1, true), LHS,
5363221345Sdim                       SCEV::FlagNUW);
5364207618Srdivacky      Pred = ICmpInst::ICMP_UGT;
5365207618Srdivacky      Changed = true;
5366207618Srdivacky    }
5367207618Srdivacky    break;
5368207618Srdivacky  default:
5369207618Srdivacky    break;
5370207618Srdivacky  }
5371207618Srdivacky
5372207618Srdivacky  // TODO: More simplifications are possible here.
5373207618Srdivacky
5374207618Srdivacky  return Changed;
5375207618Srdivacky
5376207618Srdivackytrivially_true:
5377207618Srdivacky  // Return 0 == 0.
5378218893Sdim  LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
5379207618Srdivacky  Pred = ICmpInst::ICMP_EQ;
5380207618Srdivacky  return true;
5381207618Srdivacky
5382207618Srdivackytrivially_false:
5383207618Srdivacky  // Return 0 != 0.
5384218893Sdim  LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
5385207618Srdivacky  Pred = ICmpInst::ICMP_NE;
5386207618Srdivacky  return true;
5387207618Srdivacky}
5388207618Srdivacky
5389198090Srdivackybool ScalarEvolution::isKnownNegative(const SCEV *S) {
5390198090Srdivacky  return getSignedRange(S).getSignedMax().isNegative();
5391198090Srdivacky}
5392198090Srdivacky
5393198090Srdivackybool ScalarEvolution::isKnownPositive(const SCEV *S) {
5394198090Srdivacky  return getSignedRange(S).getSignedMin().isStrictlyPositive();
5395198090Srdivacky}
5396198090Srdivacky
5397198090Srdivackybool ScalarEvolution::isKnownNonNegative(const SCEV *S) {
5398198090Srdivacky  return !getSignedRange(S).getSignedMin().isNegative();
5399198090Srdivacky}
5400198090Srdivacky
5401198090Srdivackybool ScalarEvolution::isKnownNonPositive(const SCEV *S) {
5402198090Srdivacky  return !getSignedRange(S).getSignedMax().isStrictlyPositive();
5403198090Srdivacky}
5404198090Srdivacky
5405198090Srdivackybool ScalarEvolution::isKnownNonZero(const SCEV *S) {
5406198090Srdivacky  return isKnownNegative(S) || isKnownPositive(S);
5407198090Srdivacky}
5408198090Srdivacky
5409198090Srdivackybool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
5410198090Srdivacky                                       const SCEV *LHS, const SCEV *RHS) {
5411207618Srdivacky  // Canonicalize the inputs first.
5412207618Srdivacky  (void)SimplifyICmpOperands(Pred, LHS, RHS);
5413198090Srdivacky
5414207618Srdivacky  // If LHS or RHS is an addrec, check to see if the condition is true in
5415207618Srdivacky  // every iteration of the loop.
5416207618Srdivacky  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
5417207618Srdivacky    if (isLoopEntryGuardedByCond(
5418207618Srdivacky          AR->getLoop(), Pred, AR->getStart(), RHS) &&
5419207618Srdivacky        isLoopBackedgeGuardedByCond(
5420207618Srdivacky          AR->getLoop(), Pred, AR->getPostIncExpr(*this), RHS))
5421207618Srdivacky      return true;
5422207618Srdivacky  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(RHS))
5423207618Srdivacky    if (isLoopEntryGuardedByCond(
5424207618Srdivacky          AR->getLoop(), Pred, LHS, AR->getStart()) &&
5425207618Srdivacky        isLoopBackedgeGuardedByCond(
5426207618Srdivacky          AR->getLoop(), Pred, LHS, AR->getPostIncExpr(*this)))
5427207618Srdivacky      return true;
5428207618Srdivacky
5429207618Srdivacky  // Otherwise see what can be done with known constant ranges.
5430207618Srdivacky  return isKnownPredicateWithRanges(Pred, LHS, RHS);
5431207618Srdivacky}
5432207618Srdivacky
5433207618Srdivackybool
5434207618SrdivackyScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred,
5435207618Srdivacky                                            const SCEV *LHS, const SCEV *RHS) {
5436198090Srdivacky  if (HasSameValue(LHS, RHS))
5437198090Srdivacky    return ICmpInst::isTrueWhenEqual(Pred);
5438198090Srdivacky
5439207618Srdivacky  // This code is split out from isKnownPredicate because it is called from
5440207618Srdivacky  // within isLoopEntryGuardedByCond.
5441198090Srdivacky  switch (Pred) {
5442198090Srdivacky  default:
5443198090Srdivacky    llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5444198090Srdivacky    break;
5445198090Srdivacky  case ICmpInst::ICMP_SGT:
5446198090Srdivacky    Pred = ICmpInst::ICMP_SLT;
5447198090Srdivacky    std::swap(LHS, RHS);
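    // Fall through to the ICMP_SLT case.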
5448198090Srdivacky  case ICmpInst::ICMP_SLT: {
5449198090Srdivacky    ConstantRange LHSRange = getSignedRange(LHS);
5450198090Srdivacky    ConstantRange RHSRange = getSignedRange(RHS);
5451198090Srdivacky    if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
5452198090Srdivacky      return true;
5453198090Srdivacky    if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
5454198090Srdivacky      return false;
5455198090Srdivacky    break;
5456198090Srdivacky  }
5457198090Srdivacky  case ICmpInst::ICMP_SGE:
5458198090Srdivacky    Pred = ICmpInst::ICMP_SLE;
5459198090Srdivacky    std::swap(LHS, RHS);
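    // Fall through to the ICMP_SLE case.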
5460198090Srdivacky  case ICmpInst::ICMP_SLE: {
5461198090Srdivacky    ConstantRange LHSRange = getSignedRange(LHS);
5462198090Srdivacky    ConstantRange RHSRange = getSignedRange(RHS);
5463198090Srdivacky    if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
5464198090Srdivacky      return true;
5465198090Srdivacky    if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
5466198090Srdivacky      return false;
5467198090Srdivacky    break;
5468198090Srdivacky  }
5469198090Srdivacky  case ICmpInst::ICMP_UGT:
5470198090Srdivacky    Pred = ICmpInst::ICMP_ULT;
5471198090Srdivacky    std::swap(LHS, RHS);
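    // Fall through to the ICMP_ULT case.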
5472198090Srdivacky  case ICmpInst::ICMP_ULT: {
5473198090Srdivacky    ConstantRange LHSRange = getUnsignedRange(LHS);
5474198090Srdivacky    ConstantRange RHSRange = getUnsignedRange(RHS);
5475198090Srdivacky    if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
5476198090Srdivacky      return true;
5477198090Srdivacky    if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
5478198090Srdivacky      return false;
5479198090Srdivacky    break;
5480198090Srdivacky  }
5481198090Srdivacky  case ICmpInst::ICMP_UGE:
5482198090Srdivacky    Pred = ICmpInst::ICMP_ULE;
5483198090Srdivacky    std::swap(LHS, RHS);
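    // Fall through to the ICMP_ULE case.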
5484198090Srdivacky  case ICmpInst::ICMP_ULE: {
5485198090Srdivacky    ConstantRange LHSRange = getUnsignedRange(LHS);
5486198090Srdivacky    ConstantRange RHSRange = getUnsignedRange(RHS);
5487198090Srdivacky    if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
5488198090Srdivacky      return true;
5489198090Srdivacky    if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
5490198090Srdivacky      return false;
5491198090Srdivacky    break;
5492198090Srdivacky  }
5493198090Srdivacky  case ICmpInst::ICMP_NE: {
5494198090Srdivacky    if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
5495198090Srdivacky      return true;
5496198090Srdivacky    if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
5497198090Srdivacky      return true;
5498198090Srdivacky
5499198090Srdivacky    const SCEV *Diff = getMinusSCEV(LHS, RHS);
5500198090Srdivacky    if (isKnownNonZero(Diff))
5501198090Srdivacky      return true;
5502198090Srdivacky    break;
5503198090Srdivacky  }
5504198090Srdivacky  case ICmpInst::ICMP_EQ:
5505198090Srdivacky    // The check at the top of the function catches the case where
5506198090Srdivacky    // the values are known to be equal.
5507198090Srdivacky    break;
5508198090Srdivacky  }
5509198090Srdivacky  return false;
5510198090Srdivacky}
5511198090Srdivacky
5512198090Srdivacky/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
5513198090Srdivacky/// protected by a conditional between LHS and RHS.  This is used to
5514198090Srdivacky/// to eliminate casts.
5515198090Srdivackybool
5516198090SrdivackyScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
5517198090Srdivacky                                             ICmpInst::Predicate Pred,
5518198090Srdivacky                                             const SCEV *LHS, const SCEV *RHS) {
5519193323Sed  // Interpret a null as meaning no loop, where there is obviously no guard
5520193323Sed  // (interprocedural conditions notwithstanding).
5521198090Srdivacky  if (!L) return true;
5522198090Srdivacky
5523198090Srdivacky  BasicBlock *Latch = L->getLoopLatch();
5524198090Srdivacky  if (!Latch)
5525198090Srdivacky    return false;
5526198090Srdivacky
5527198090Srdivacky  BranchInst *LoopContinuePredicate =
5528198090Srdivacky    dyn_cast<BranchInst>(Latch->getTerminator());
5529198090Srdivacky  if (!LoopContinuePredicate ||
5530198090Srdivacky      LoopContinuePredicate->isUnconditional())
5531198090Srdivacky    return false;
5532198090Srdivacky
5533212904Sdim  return isImpliedCond(Pred, LHS, RHS,
5534212904Sdim                       LoopContinuePredicate->getCondition(),
5535198090Srdivacky                       LoopContinuePredicate->getSuccessor(0) != L->getHeader());
5536198090Srdivacky}
5537198090Srdivacky
5538207618Srdivacky/// isLoopEntryGuardedByCond - Test whether entry to the loop is protected
5539198090Srdivacky/// by a conditional between LHS and RHS.  This is used to help avoid max
5540198090Srdivacky/// expressions in loop trip counts, and to eliminate casts.
5541198090Srdivackybool
5542207618SrdivackyScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
5543207618Srdivacky                                          ICmpInst::Predicate Pred,
5544207618Srdivacky                                          const SCEV *LHS, const SCEV *RHS) {
5545198090Srdivacky  // Interpret a null as meaning no loop, where there is obviously no guard
5546198090Srdivacky  // (interprocedural conditions notwithstanding).
5547193323Sed  if (!L) return false;
5548193323Sed
5549193323Sed  // Starting at the loop predecessor, climb up the predecessor chain, as long
5550193323Sed  // as there are predecessors that can be found that have unique successors
5551193323Sed  // leading to the original header.
5552207618Srdivacky  for (std::pair<BasicBlock *, BasicBlock *>
5553210299Sed         Pair(L->getLoopPredecessor(), L->getHeader());
5554207618Srdivacky       Pair.first;
5555207618Srdivacky       Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
5556193323Sed
5557193323Sed    BranchInst *LoopEntryPredicate =
5558207618Srdivacky      dyn_cast<BranchInst>(Pair.first->getTerminator());
5559193323Sed    if (!LoopEntryPredicate ||
5560193323Sed        LoopEntryPredicate->isUnconditional())
5561193323Sed      continue;
5562193323Sed
5563212904Sdim    if (isImpliedCond(Pred, LHS, RHS,
5564212904Sdim                      LoopEntryPredicate->getCondition(),
5565207618Srdivacky                      LoopEntryPredicate->getSuccessor(0) != Pair.second))
5566195098Sed      return true;
5567195098Sed  }
5568193323Sed
5569195098Sed  return false;
5570195098Sed}
5571193323Sed
5572198090Srdivacky/// isImpliedCond - Test whether the condition described by Pred, LHS,
5573198090Srdivacky/// and RHS is true whenever the given Cond value evaluates to true.
5574212904Sdimbool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
5575198090Srdivacky                                    const SCEV *LHS, const SCEV *RHS,
5576212904Sdim                                    Value *FoundCondValue,
5577198090Srdivacky                                    bool Inverse) {
5578204642Srdivacky  // Recursively handle And and Or conditions.
5579212904Sdim  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
5580195098Sed    if (BO->getOpcode() == Instruction::And) {
5581195098Sed      if (!Inverse)
5582212904Sdim        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
5583212904Sdim               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
5584195098Sed    } else if (BO->getOpcode() == Instruction::Or) {
5585195098Sed      if (Inverse)
5586212904Sdim        return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
5587212904Sdim               isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
5588195098Sed    }
5589195098Sed  }
5590195098Sed
5591212904Sdim  ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
5592195098Sed  if (!ICI) return false;
5593195098Sed
5594198090Srdivacky  // Bail if the ICmp's operands' types are wider than the needed type
5595198090Srdivacky  // before attempting to call getSCEV on them. This avoids infinite
5596198090Srdivacky  // recursion, since the analysis of widening casts can require loop
5597198090Srdivacky  // exit condition information for overflow checking, which would
5598198090Srdivacky  // lead back here.
5599198090Srdivacky  if (getTypeSizeInBits(LHS->getType()) <
5600198090Srdivacky      getTypeSizeInBits(ICI->getOperand(0)->getType()))
5601198090Srdivacky    return false;
5602198090Srdivacky
5603195098Sed  // Now that we found a conditional branch that dominates the loop, check to
5604195098Sed  // see if it is the comparison we are looking for.
5605198090Srdivacky  ICmpInst::Predicate FoundPred;
5606195098Sed  if (Inverse)
5607198090Srdivacky    FoundPred = ICI->getInversePredicate();
5608195098Sed  else
5609198090Srdivacky    FoundPred = ICI->getPredicate();
5610195098Sed
5611198090Srdivacky  const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
5612198090Srdivacky  const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
5613198090Srdivacky
5614198090Srdivacky  // Balance the types. The case where FoundLHS' type is wider than
5615198090Srdivacky  // LHS' type is checked for above.
5616198090Srdivacky  if (getTypeSizeInBits(LHS->getType()) >
5617198090Srdivacky      getTypeSizeInBits(FoundLHS->getType())) {
5618198090Srdivacky    if (CmpInst::isSigned(Pred)) {
5619198090Srdivacky      FoundLHS = getSignExtendExpr(FoundLHS, LHS->getType());
5620198090Srdivacky      FoundRHS = getSignExtendExpr(FoundRHS, LHS->getType());
5621198090Srdivacky    } else {
5622198090Srdivacky      FoundLHS = getZeroExtendExpr(FoundLHS, LHS->getType());
5623198090Srdivacky      FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
5624198090Srdivacky    }
5625198090Srdivacky  }
5626198090Srdivacky
5627198090Srdivacky  // Canonicalize the query to match the way instcombine will have
5628198090Srdivacky  // canonicalized the comparison.
5629207618Srdivacky  if (SimplifyICmpOperands(Pred, LHS, RHS))
5630207618Srdivacky    if (LHS == RHS)
5631207618Srdivacky      return CmpInst::isTrueWhenEqual(Pred);
5632207618Srdivacky  if (SimplifyICmpOperands(FoundPred, FoundLHS, FoundRHS))
5633207618Srdivacky    if (FoundLHS == FoundRHS)
5634207618Srdivacky      return CmpInst::isFalseWhenEqual(Pred);
5635193323Sed
5636198090Srdivacky  // Check to see if we can make the LHS or RHS match.
5637198090Srdivacky  if (LHS == FoundRHS || RHS == FoundLHS) {
5638198090Srdivacky    if (isa<SCEVConstant>(RHS)) {
5639198090Srdivacky      std::swap(FoundLHS, FoundRHS);
5640198090Srdivacky      FoundPred = ICmpInst::getSwappedPredicate(FoundPred);
5641198090Srdivacky    } else {
5642198090Srdivacky      std::swap(LHS, RHS);
5643198090Srdivacky      Pred = ICmpInst::getSwappedPredicate(Pred);
5644198090Srdivacky    }
5645198090Srdivacky  }
5646193323Sed
5647198090Srdivacky  // Check whether the found predicate is the same as the desired predicate.
5648198090Srdivacky  if (FoundPred == Pred)
5649198090Srdivacky    return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
5650198090Srdivacky
5651198090Srdivacky  // Check whether swapping the found predicate makes it the same as the
5652198090Srdivacky  // desired predicate.
5653198090Srdivacky  if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
5654198090Srdivacky    if (isa<SCEVConstant>(RHS))
5655198090Srdivacky      return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
5656198090Srdivacky    else
5657198090Srdivacky      return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
5658198090Srdivacky                                   RHS, LHS, FoundLHS, FoundRHS);
5659198090Srdivacky  }
5660198090Srdivacky
5661198090Srdivacky  // Check whether the actual condition is beyond sufficient.
5662198090Srdivacky  if (FoundPred == ICmpInst::ICMP_EQ)
5663198090Srdivacky    if (ICmpInst::isTrueWhenEqual(Pred))
5664198090Srdivacky      if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
5665198090Srdivacky        return true;
5666198090Srdivacky  if (Pred == ICmpInst::ICMP_NE)
5667198090Srdivacky    if (!ICmpInst::isTrueWhenEqual(FoundPred))
5668198090Srdivacky      if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
5669198090Srdivacky        return true;
5670198090Srdivacky
5671198090Srdivacky  // Otherwise assume the worst.
5672198090Srdivacky  return false;
5673193323Sed}
5674193323Sed
5675198090Srdivacky/// isImpliedCondOperands - Test whether the condition described by Pred,
5676204642Srdivacky/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS,
5677198090Srdivacky/// and FoundRHS is true.
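/// The second, complemented query below relies on the identity ~x == -1 - x,
/// under which "FoundLHS pred FoundRHS" holds exactly when
/// "~FoundRHS pred ~FoundLHS" does (e.g. ~x < ~y --> y < x).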
5678198090Srdivackybool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
5679198090Srdivacky                                            const SCEV *LHS, const SCEV *RHS,
5680198090Srdivacky                                            const SCEV *FoundLHS,
5681198090Srdivacky                                            const SCEV *FoundRHS) {
5682198090Srdivacky  return isImpliedCondOperandsHelper(Pred, LHS, RHS,
5683198090Srdivacky                                     FoundLHS, FoundRHS) ||
5684198090Srdivacky         // ~x < ~y --> x > y
5685198090Srdivacky         isImpliedCondOperandsHelper(Pred, LHS, RHS,
5686198090Srdivacky                                     getNotSCEV(FoundRHS),
5687198090Srdivacky                                     getNotSCEV(FoundLHS));
5688198090Srdivacky}
5689198090Srdivacky
5690198090Srdivacky/// isImpliedCondOperandsHelper - Test whether the condition described by
5691204642Srdivacky/// Pred, LHS, and RHS is true whenever the condition described by Pred,
5692198090Srdivacky/// FoundLHS, and FoundRHS is true.
5693198090Srdivackybool
5694198090SrdivackyScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
5695198090Srdivacky                                             const SCEV *LHS, const SCEV *RHS,
5696198090Srdivacky                                             const SCEV *FoundLHS,
5697198090Srdivacky                                             const SCEV *FoundRHS) {
5698198090Srdivacky  switch (Pred) {
5699198090Srdivacky  default: llvm_unreachable("Unexpected ICmpInst::Predicate value!");
5700198090Srdivacky  case ICmpInst::ICMP_EQ:
5701198090Srdivacky  case ICmpInst::ICMP_NE:
5702198090Srdivacky    if (HasSameValue(LHS, FoundLHS) && HasSameValue(RHS, FoundRHS))
5703198090Srdivacky      return true;
5704198090Srdivacky    break;
5705198090Srdivacky  case ICmpInst::ICMP_SLT:
5706198090Srdivacky  case ICmpInst::ICMP_SLE:
5707207618Srdivacky    if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
5708207618Srdivacky        isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS))
5709198090Srdivacky      return true;
5710198090Srdivacky    break;
5711198090Srdivacky  case ICmpInst::ICMP_SGT:
5712198090Srdivacky  case ICmpInst::ICMP_SGE:
5713207618Srdivacky    if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
5714207618Srdivacky        isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS))
5715198090Srdivacky      return true;
5716198090Srdivacky    break;
5717198090Srdivacky  case ICmpInst::ICMP_ULT:
5718198090Srdivacky  case ICmpInst::ICMP_ULE:
5719207618Srdivacky    if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
5720207618Srdivacky        isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS))
5721198090Srdivacky      return true;
5722198090Srdivacky    break;
5723198090Srdivacky  case ICmpInst::ICMP_UGT:
5724198090Srdivacky  case ICmpInst::ICMP_UGE:
5725207618Srdivacky    if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
5726207618Srdivacky        isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS))
5727198090Srdivacky      return true;
5728198090Srdivacky    break;
5729198090Srdivacky  }
5730198090Srdivacky
5731198090Srdivacky  return false;
5732198090Srdivacky}
5733198090Srdivacky
5734194612Sed/// getBECount - Subtract the end and start values and divide by the step,
5735194612Sed/// rounding up, to get the number of times the backedge is executed. Return
5736194612Sed/// CouldNotCompute if an intermediate computation overflows.
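/// Conceptually the result is (End - Start + (Step - 1)) /u Step; for example
/// (illustrative), Start = 0, End = 10, Step = 4 gives (10 - 0 + 3) /u 4 = 3.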
5737198090Srdivackyconst SCEV *ScalarEvolution::getBECount(const SCEV *Start,
5738198090Srdivacky                                        const SCEV *End,
5739198090Srdivacky                                        const SCEV *Step,
5740198090Srdivacky                                        bool NoWrap) {
5741203954Srdivacky  assert(!isKnownNegative(Step) &&
5742203954Srdivacky         "This code doesn't handle negative strides yet!");
5743203954Srdivacky
5744194612Sed  const Type *Ty = Start->getType();
5745221345Sdim
5746221345Sdim  // When Start == End, we have an exact BECount == 0. Short-circuit this case
5747221345Sdim  // here because SCEV may not be able to determine that the unsigned division
5748221345Sdim  // after rounding is zero.
5749221345Sdim  if (Start == End)
5750221345Sdim    return getConstant(Ty, 0);
5751221345Sdim
5752207618Srdivacky  const SCEV *NegOne = getConstant(Ty, (uint64_t)-1);
5753198090Srdivacky  const SCEV *Diff = getMinusSCEV(End, Start);
5754198090Srdivacky  const SCEV *RoundUp = getAddExpr(Step, NegOne);
5755194612Sed
5756194612Sed  // Add an adjustment to the difference between End and Start so that
5757194612Sed  // the division will effectively round up.
5758198090Srdivacky  const SCEV *Add = getAddExpr(Diff, RoundUp);
5759194612Sed
5760198090Srdivacky  if (!NoWrap) {
5761198090Srdivacky    // Check Add for unsigned overflow.
5762198090Srdivacky    // TODO: More sophisticated things could be done here.
5763198090Srdivacky    const Type *WideTy = IntegerType::get(getContext(),
5764198090Srdivacky                                          getTypeSizeInBits(Ty) + 1);
5765198090Srdivacky    const SCEV *EDiff = getZeroExtendExpr(Diff, WideTy);
5766198090Srdivacky    const SCEV *ERoundUp = getZeroExtendExpr(RoundUp, WideTy);
5767198090Srdivacky    const SCEV *OperandExtendedAdd = getAddExpr(EDiff, ERoundUp);
5768198090Srdivacky    if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
5769198090Srdivacky      return getCouldNotCompute();
5770198090Srdivacky  }
5771194612Sed
5772194612Sed  return getUDivExpr(Add, Step);
5773194612Sed}
5774194612Sed
5775193323Sed/// HowManyLessThans - Return the number of times a backedge containing the
5776193323Sed/// specified less-than comparison will execute.  If not computable, return
5777193630Sed/// CouldNotCompute.
5778195098SedScalarEvolution::BackedgeTakenInfo
5779195098SedScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
5780195098Sed                                  const Loop *L, bool isSigned) {
5781193323Sed  // Only handle:  "ADDREC < LoopInvariant".
5782218893Sdim  if (!isLoopInvariant(RHS, L)) return getCouldNotCompute();
5783193323Sed
5784193323Sed  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
5785193323Sed  if (!AddRec || AddRec->getLoop() != L)
5786195340Sed    return getCouldNotCompute();
5787193323Sed
5788198090Srdivacky  // Check to see if we have a flag which makes analysis easy.
5789221345Sdim  bool NoWrap = isSigned ? AddRec->getNoWrapFlags(SCEV::FlagNSW) :
5790221345Sdim                           AddRec->getNoWrapFlags(SCEV::FlagNUW);
5791198090Srdivacky
5792193323Sed  if (AddRec->isAffine()) {
5793193323Sed    unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
5794198090Srdivacky    const SCEV *Step = AddRec->getStepRecurrence(*this);
5795193323Sed
5796203954Srdivacky    if (Step->isZero())
5797195340Sed      return getCouldNotCompute();
5798203954Srdivacky    if (Step->isOne()) {
5799193323Sed      // With unit stride, the iteration never steps past the limit value.
5800203954Srdivacky    } else if (isKnownPositive(Step)) {
5801203954Srdivacky      // Test whether a positive iteration can step past the limit
5802203954Srdivacky      // value and past the maximum value for its type in a single step.
5803203954Srdivacky      // Note that it's not sufficient to check NoWrap here, because even
5804203954Srdivacky      // though the value after a wrap is undefined, it's not undefined
5805203954Srdivacky      // behavior, so if wrap does occur, the loop could either terminate or
5806203954Srdivacky      // loop infinitely, but in either case, the loop is guaranteed to
5807203954Srdivacky      // iterate at least until the iteration where the wrapping occurs.
5808207618Srdivacky      const SCEV *One = getConstant(Step->getType(), 1);
5809203954Srdivacky      if (isSigned) {
5810203954Srdivacky        APInt Max = APInt::getSignedMaxValue(BitWidth);
5811203954Srdivacky        if ((Max - getSignedRange(getMinusSCEV(Step, One)).getSignedMax())
5812203954Srdivacky              .slt(getSignedRange(RHS).getSignedMax()))
5813203954Srdivacky          return getCouldNotCompute();
5814203954Srdivacky      } else {
5815203954Srdivacky        APInt Max = APInt::getMaxValue(BitWidth);
5816203954Srdivacky        if ((Max - getUnsignedRange(getMinusSCEV(Step, One)).getUnsignedMax())
5817203954Srdivacky              .ult(getUnsignedRange(RHS).getUnsignedMax()))
5818203954Srdivacky          return getCouldNotCompute();
5819203954Srdivacky      }
5820193323Sed    } else
5821203954Srdivacky      // TODO: Handle negative strides here and below.
5822195340Sed      return getCouldNotCompute();
5823193323Sed
5824193323Sed    // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
5825193323Sed    // m.  So, we count the number of iterations in which {n,+,s} < m is true.
5826193323Sed    // Note that we cannot simply return max(m-n,0)/s because it's not safe to
5827193323Sed    // treat m-n as signed nor unsigned due to overflow possibility.
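    //
    // For instance (illustrative, i8): with n = -100 and m = 100 the true
    // difference is 200, whose i8 bit pattern reads as -56 when treated as
    // signed, so max(m-n,0) would wrongly be 0; with n = 100 and m = -100 the
    // loop runs zero times, yet m-n treated as unsigned is 56. The guard and
    // range checks below avoid both traps.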
5828193323Sed
5829193323Sed    // First, we get the value of the LHS in the first iteration: n
5830198090Srdivacky    const SCEV *Start = AddRec->getOperand(0);
5831193323Sed
5832193323Sed    // Determine the minimum constant start value.
5833198090Srdivacky    const SCEV *MinStart = getConstant(isSigned ?
5834198090Srdivacky      getSignedRange(Start).getSignedMin() :
5835198090Srdivacky      getUnsignedRange(Start).getUnsignedMin());
5836193323Sed
5837193323Sed    // If we know that the condition is true in order to enter the loop,
5838193323Sed    // then we know that it will run exactly (m-n)/s times. Otherwise, we
5839193323Sed    // only know that it will execute (max(m,n)-n)/s times. In both cases,
5840193323Sed    // the division must round up.
5841198090Srdivacky    const SCEV *End = RHS;
5842207618Srdivacky    if (!isLoopEntryGuardedByCond(L,
5843207618Srdivacky                                  isSigned ? ICmpInst::ICMP_SLT :
5844207618Srdivacky                                             ICmpInst::ICMP_ULT,
5845207618Srdivacky                                  getMinusSCEV(Start, Step), RHS))
5846193323Sed      End = isSigned ? getSMaxExpr(RHS, Start)
5847193323Sed                     : getUMaxExpr(RHS, Start);
5848193323Sed
5849193323Sed    // Determine the maximum constant end value.
5850198090Srdivacky    const SCEV *MaxEnd = getConstant(isSigned ?
5851198090Srdivacky      getSignedRange(End).getSignedMax() :
5852198090Srdivacky      getUnsignedRange(End).getUnsignedMax());
5853193323Sed
5854203954Srdivacky    // If MaxEnd is within a step of the maximum integer value in its type,
5855203954Srdivacky    // adjust it down to the minimum value which would produce the same effect.
5856204642Srdivacky    // This allows the subsequent ceiling division of (N+(step-1))/step to
5857203954Srdivacky    // compute the correct value.
5858203954Srdivacky    const SCEV *StepMinusOne = getMinusSCEV(Step,
5859207618Srdivacky                                            getConstant(Step->getType(), 1));
5860203954Srdivacky    MaxEnd = isSigned ?
5861203954Srdivacky      getSMinExpr(MaxEnd,
5862203954Srdivacky                  getMinusSCEV(getConstant(APInt::getSignedMaxValue(BitWidth)),
5863203954Srdivacky                               StepMinusOne)) :
5864203954Srdivacky      getUMinExpr(MaxEnd,
5865203954Srdivacky                  getMinusSCEV(getConstant(APInt::getMaxValue(BitWidth)),
5866203954Srdivacky                               StepMinusOne));
5867203954Srdivacky
5868193323Sed    // Finally, we subtract these two values and divide, rounding up, to get
5869193323Sed    // the number of times the backedge is executed.
5870198090Srdivacky    const SCEV *BECount = getBECount(Start, End, Step, NoWrap);
5871193323Sed
5872193323Sed    // The maximum backedge count is similar, except using the minimum start
5873193323Sed    // value and the maximum end value.
5874221345Sdim    // If we already have an exact constant BECount, use it instead.
5875221345Sdim    const SCEV *MaxBECount = isa<SCEVConstant>(BECount) ? BECount
5876221345Sdim      : getBECount(MinStart, MaxEnd, Step, NoWrap);
5877193323Sed
5878221345Sdim    // If the stride is nonconstant, and NoWrap == true, then
5879221345Sdim    // getBECount(MinStart, MaxEnd) may not compute. This would result in an
5880221345Sdim    // exact BECount and invalid MaxBECount, which should be avoided to catch
5881221345Sdim    // more optimization opportunities.
5882221345Sdim    if (isa<SCEVCouldNotCompute>(MaxBECount))
5883221345Sdim      MaxBECount = BECount;
5884221345Sdim
5885193323Sed    return BackedgeTakenInfo(BECount, MaxBECount);
5886193323Sed  }
5887193323Sed
5888195340Sed  return getCouldNotCompute();
5889193323Sed}
5890193323Sed
5891193323Sed/// getNumIterationsInRange - Return the number of iterations of this loop that
5892193323Sed/// produce values in the specified constant range.  Another way of looking at
5893193323Sed/// this is that it returns the first iteration number where the value is not
5894193323Sed/// in the range, thus computing the exit count. If the iteration count can't
5895193323Sed/// be computed, an instance of SCEVCouldNotCompute is returned.
5896198090Srdivackyconst SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
5897195098Sed                                                    ScalarEvolution &SE) const {
5898193323Sed  if (Range.isFullSet())  // Infinite loop.
5899193323Sed    return SE.getCouldNotCompute();
5900193323Sed
5901193323Sed  // If the start is a non-zero constant, shift the range to simplify things.
5902193323Sed  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
5903193323Sed    if (!SC->getValue()->isZero()) {
5904198090Srdivacky      SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
5905207618Srdivacky      Operands[0] = SE.getConstant(SC->getType(), 0);
5906221345Sdim      const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
5907221345Sdim                                             getNoWrapFlags(FlagNW));
5908193323Sed      if (const SCEVAddRecExpr *ShiftedAddRec =
5909193323Sed            dyn_cast<SCEVAddRecExpr>(Shifted))
5910193323Sed        return ShiftedAddRec->getNumIterationsInRange(
5911193323Sed                           Range.subtract(SC->getValue()->getValue()), SE);
5912193323Sed      // This is strange and shouldn't happen.
5913193323Sed      return SE.getCouldNotCompute();
5914193323Sed    }
5915193323Sed
5916193323Sed  // The only time we can solve this is when we have all constant indices.
5917193323Sed  // Otherwise, we cannot determine the overflow conditions.
5918193323Sed  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
5919193323Sed    if (!isa<SCEVConstant>(getOperand(i)))
5920193323Sed      return SE.getCouldNotCompute();
5921193323Sed
5922193323Sed
5923193323Sed  // Okay at this point we know that all elements of the chrec are constants and
5924193323Sed  // that the start element is zero.
5925193323Sed
5926193323Sed  // First check to see if the range contains zero.  If not, the first
5927193323Sed  // iteration exits.
5928193323Sed  unsigned BitWidth = SE.getTypeSizeInBits(getType());
5929193323Sed  if (!Range.contains(APInt(BitWidth, 0)))
5930207618Srdivacky    return SE.getConstant(getType(), 0);
5931193323Sed
5932193323Sed  if (isAffine()) {
5933193323Sed    // If this is an affine expression then we have this situation:
5934193323Sed    //   Solve {0,+,A} in Range  ===  Ax in Range
5935193323Sed
5936193323Sed    // We know that zero is in the range.  If A is positive then we know that
5937193323Sed    // the upper value of the range must be the first possible exit value.
5938193323Sed    // If A is negative then the lower of the range is the last possible loop
5939193323Sed    // value.  Also note that we already checked for a full range.
5940193323Sed    APInt One(BitWidth,1);
5941193323Sed    APInt A     = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
5942193323Sed    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();
5943193323Sed
5944193323Sed    // The exit value should be (End+A)/A.
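    // For example (illustrative): with Range = [0, 100) and A = 7, End is 99
    // and ExitVal = (99 + 7) /u 7 = 15; iteration 15 produces 105, which is
    // outside the range, while iteration 14 produces 98, still inside.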
5945193323Sed    APInt ExitVal = (End + A).udiv(A);
5946198090Srdivacky    ConstantInt *ExitValue = ConstantInt::get(SE.getContext(), ExitVal);
5947193323Sed
5948193323Sed    // Evaluate at the exit value.  If we really did fall out of the valid
5949193323Sed    // range, then we computed our trip count, otherwise wrap around or other
5950193323Sed    // things must have happened.
5951193323Sed    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
5952193323Sed    if (Range.contains(Val->getValue()))
5953193323Sed      return SE.getCouldNotCompute();  // Something strange happened
5954193323Sed
5955193323Sed    // Ensure that the previous value is in the range.  This is a sanity check.
5956193323Sed    assert(Range.contains(
5957195098Sed           EvaluateConstantChrecAtConstant(this,
5958198090Srdivacky           ConstantInt::get(SE.getContext(), ExitVal - One), SE)->getValue()) &&
5959193323Sed           "Linear scev computation is off in a bad way!");
5960193323Sed    return SE.getConstant(ExitValue);
5961193323Sed  } else if (isQuadratic()) {
5962193323Sed    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
5963193323Sed    // quadratic equation to solve it.  To do this, we must frame our problem in
5964193323Sed    // terms of figuring out when zero is crossed, instead of when
5965193323Sed    // Range.getUpper() is crossed.
5966198090Srdivacky    SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
5967193323Sed    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
5968221345Sdim    const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(),
5969221345Sdim                                             // getNoWrapFlags(FlagNW)
5970221345Sdim                                             FlagAnyWrap);
5971193323Sed
5972193323Sed    // Next, solve the constructed addrec
5973198090Srdivacky    std::pair<const SCEV *,const SCEV *> Roots =
5974193323Sed      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
5975193323Sed    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
5976193323Sed    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
5977193323Sed    if (R1) {
5978193323Sed      // Pick the smallest positive root value.
5979193323Sed      if (ConstantInt *CB =
5980195098Sed          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
5981198090Srdivacky                         R1->getValue(), R2->getValue()))) {
5982193323Sed        if (!CB->getZExtValue())
5983193323Sed          std::swap(R1, R2);   // R1 is the minimum root now.
5984193323Sed
5985193323Sed        // Make sure the root is not off by one.  The returned iteration should
5986193323Sed        // not be in the range, but the previous one should be.  When solving
5987193323Sed        // for "X*X < 5", for example, we should not return a root of 2.
5988193323Sed        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
5989193323Sed                                                             R1->getValue(),
5990193323Sed                                                             SE);
5991193323Sed        if (Range.contains(R1Val->getValue())) {
5992193323Sed          // The next iteration must be out of the range...
5993198090Srdivacky          ConstantInt *NextVal =
5994198090Srdivacky                ConstantInt::get(SE.getContext(), R1->getValue()->getValue()+1);
5995193323Sed
5996193323Sed          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
5997193323Sed          if (!Range.contains(R1Val->getValue()))
5998193323Sed            return SE.getConstant(NextVal);
5999193323Sed          return SE.getCouldNotCompute();  // Something strange happened
6000193323Sed        }
6001193323Sed
6002193323Sed        // If R1 was not in the range, then it is a good return value.  Make
6003193323Sed        // sure that R1-1 WAS in the range though, just in case.
6004198090Srdivacky        ConstantInt *NextVal =
6005198090Srdivacky               ConstantInt::get(SE.getContext(), R1->getValue()->getValue()-1);
6006193323Sed        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
6007193323Sed        if (Range.contains(R1Val->getValue()))
6008193323Sed          return R1;
6009193323Sed        return SE.getCouldNotCompute();  // Something strange happened
6010193323Sed      }
6011193323Sed    }
6012193323Sed  }
6013193323Sed
6014193323Sed  return SE.getCouldNotCompute();
6015193323Sed}
6016193323Sed
6017193323Sed
6018193323Sed
6019193323Sed//===----------------------------------------------------------------------===//
6020193323Sed//                   SCEVCallbackVH Class Implementation
6021193323Sed//===----------------------------------------------------------------------===//
6022193323Sed
6023193323Sedvoid ScalarEvolution::SCEVCallbackVH::deleted() {
6024198090Srdivacky  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
6025193323Sed  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
6026193323Sed    SE->ConstantEvolutionLoopExitValue.erase(PN);
6027212904Sdim  SE->ValueExprMap.erase(getValPtr());
6028193323Sed  // this now dangles!
6029193323Sed}
6030193323Sed
6031212904Sdimvoid ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
6032198090Srdivacky  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
6033193323Sed
6034193323Sed  // Forget all the expressions associated with users of the old value,
6035193323Sed  // so that future queries will recompute the expressions using the new
6036193323Sed  // value.
6037212904Sdim  Value *Old = getValPtr();
6038193323Sed  SmallVector<User *, 16> Worklist;
6039198090Srdivacky  SmallPtrSet<User *, 8> Visited;
6040193323Sed  for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
6041193323Sed       UI != UE; ++UI)
6042193323Sed    Worklist.push_back(*UI);
6043193323Sed  while (!Worklist.empty()) {
6044193323Sed    User *U = Worklist.pop_back_val();
6045193323Sed    // Deleting the Old value will cause this to dangle. Postpone
6046193323Sed    // that until everything else is done.
6047212904Sdim    if (U == Old)
6048193323Sed      continue;
6049198090Srdivacky    if (!Visited.insert(U))
6050198090Srdivacky      continue;
6051193323Sed    if (PHINode *PN = dyn_cast<PHINode>(U))
6052193323Sed      SE->ConstantEvolutionLoopExitValue.erase(PN);
6053212904Sdim    SE->ValueExprMap.erase(U);
6054198090Srdivacky    for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
6055198090Srdivacky         UI != UE; ++UI)
6056198090Srdivacky      Worklist.push_back(*UI);
6057193323Sed  }
6058212904Sdim  // Delete the Old value.
6059212904Sdim  if (PHINode *PN = dyn_cast<PHINode>(Old))
6060212904Sdim    SE->ConstantEvolutionLoopExitValue.erase(PN);
6061212904Sdim  SE->ValueExprMap.erase(Old);
6062212904Sdim  // this now dangles!
6063193323Sed}
6064193323Sed
6065193323SedScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
6066193323Sed  : CallbackVH(V), SE(se) {}
6067193323Sed
6068193323Sed//===----------------------------------------------------------------------===//
6069193323Sed//                   ScalarEvolution Class Implementation
6070193323Sed//===----------------------------------------------------------------------===//
6071193323Sed
6072193323SedScalarEvolution::ScalarEvolution()
6073212904Sdim  : FunctionPass(ID), FirstUnknown(0) {
6074218893Sdim  initializeScalarEvolutionPass(*PassRegistry::getPassRegistry());
6075193323Sed}
6076193323Sed
6077193323Sedbool ScalarEvolution::runOnFunction(Function &F) {
6078193323Sed  this->F = &F;
6079193323Sed  LI = &getAnalysis<LoopInfo>();
6080204642Srdivacky  TD = getAnalysisIfAvailable<TargetData>();
6081202878Srdivacky  DT = &getAnalysis<DominatorTree>();
6082193323Sed  return false;
6083193323Sed}
6084193323Sed
6085193323Sedvoid ScalarEvolution::releaseMemory() {
6086212904Sdim  // Iterate through all the SCEVUnknown instances and call their
6087212904Sdim  // destructors, so that they release their references to their values.
6088212904Sdim  for (SCEVUnknown *U = FirstUnknown; U; U = U->Next)
6089212904Sdim    U->~SCEVUnknown();
6090212904Sdim  FirstUnknown = 0;
6091212904Sdim
6092212904Sdim  ValueExprMap.clear();
6093193323Sed  BackedgeTakenCounts.clear();
6094193323Sed  ConstantEvolutionLoopExitValue.clear();
6095193323Sed  ValuesAtScopes.clear();
6096218893Sdim  LoopDispositions.clear();
6097218893Sdim  BlockDispositions.clear();
6098218893Sdim  UnsignedRanges.clear();
6099218893Sdim  SignedRanges.clear();
6100195340Sed  UniqueSCEVs.clear();
6101195340Sed  SCEVAllocator.Reset();
6102193323Sed}
6103193323Sed
6104193323Sedvoid ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
6105193323Sed  AU.setPreservesAll();
6106193323Sed  AU.addRequiredTransitive<LoopInfo>();
6107202878Srdivacky  AU.addRequiredTransitive<DominatorTree>();
6108193323Sed}
6109193323Sed
6110193323Sedbool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
6111193323Sed  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
6112193323Sed}
6113193323Sed
6114193323Sedstatic void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
6115193323Sed                          const Loop *L) {
6116193323Sed  // Print all inner loops first
6117193323Sed  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
6118193323Sed    PrintLoopInfo(OS, SE, *I);
6119193323Sed
6120202375Srdivacky  OS << "Loop ";
6121202375Srdivacky  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
6122202375Srdivacky  OS << ": ";
6123193323Sed
6124201360Srdivacky  SmallVector<BasicBlock *, 8> ExitBlocks;
6125193323Sed  L->getExitBlocks(ExitBlocks);
6126193323Sed  if (ExitBlocks.size() != 1)
6127193323Sed    OS << "<multiple exits> ";
6128193323Sed
6129193323Sed  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
6130193323Sed    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
6131193323Sed  } else {
6132193323Sed    OS << "Unpredictable backedge-taken count. ";
6133193323Sed  }
6134193323Sed
6135202375Srdivacky  OS << "\n"
6136202375Srdivacky        "Loop ";
6137202375Srdivacky  WriteAsOperand(OS, L->getHeader(), /*PrintType=*/false);
6138202375Srdivacky  OS << ": ";
6139195098Sed
6140195098Sed  if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
6141195098Sed    OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
6142195098Sed  } else {
6143195098Sed    OS << "Unpredictable max backedge-taken count. ";
6144195098Sed  }
6145195098Sed
6146195098Sed  OS << "\n";
6147193323Sed}
6148193323Sed
6149201360Srdivackyvoid ScalarEvolution::print(raw_ostream &OS, const Module *) const {
6150204642Srdivacky  // ScalarEvolution's implementation of the print method prints the SCEV
6151193323Sed  // values of all interesting instructions (those with SCEVable types,
6152193323Sed  // excluding comparisons). Doing this potentially causes it to create new
6153193323Sed  // SCEV objects, which technically conflicts with the const qualifier.
6154198090Srdivacky  // This isn't observable from outside the class, though, so casting away
6155198090Srdivacky  // the const isn't dangerous.
6156201360Srdivacky  ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
6157193323Sed
6158202375Srdivacky  OS << "Classifying expressions for: ";
6159202375Srdivacky  WriteAsOperand(OS, F, /*PrintType=*/false);
6160202375Srdivacky  OS << "\n";
6161193323Sed  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
6162207618Srdivacky    if (isSCEVable(I->getType()) && !isa<CmpInst>(*I)) {
6163198090Srdivacky      OS << *I << '\n';
6164193323Sed      OS << "  -->  ";
6165198090Srdivacky      const SCEV *SV = SE.getSCEV(&*I);
6166193323Sed      SV->print(OS);
6167193323Sed
6168194612Sed      const Loop *L = LI->getLoopFor((*I).getParent());
6169194612Sed
6170198090Srdivacky      const SCEV *AtUse = SE.getSCEVAtScope(SV, L);
6171194612Sed      if (AtUse != SV) {
6172194612Sed        OS << "  -->  ";
6173194612Sed        AtUse->print(OS);
6174194612Sed      }
6175194612Sed
6176194612Sed      if (L) {
6177194612Sed        OS << "\t\t" "Exits: ";
6178198090Srdivacky        const SCEV *ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
6179218893Sdim        if (!SE.isLoopInvariant(ExitValue, L)) {
6180193323Sed          OS << "<<Unknown>>";
6181193323Sed        } else {
6182193323Sed          OS << *ExitValue;
6183193323Sed        }
6184193323Sed      }
6185193323Sed
6186193323Sed      OS << "\n";
6187193323Sed    }
6188193323Sed
6189202375Srdivacky  OS << "Determining loop execution counts for: ";
6190202375Srdivacky  WriteAsOperand(OS, F, /*PrintType=*/false);
6191202375Srdivacky  OS << "\n";
6192193323Sed  for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
6193193323Sed    PrintLoopInfo(OS, &SE, *I);
6194193323Sed}
6195193323Sed
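// Memoizing wrapper around computeLoopDisposition.  The insert call combines
// the cache lookup with the insertion of a conservative LoopVariant
// placeholder: if the (S, L) pair is already cached its disposition is
// returned, otherwise the disposition is computed and stored via a fresh
// LoopDispositions[S][L] lookup, since the recursive computation may have
// added entries of its own.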
6196218893SdimScalarEvolution::LoopDisposition
6197218893SdimScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
6198218893Sdim  std::map<const Loop *, LoopDisposition> &Values = LoopDispositions[S];
6199218893Sdim  std::pair<std::map<const Loop *, LoopDisposition>::iterator, bool> Pair =
6200218893Sdim    Values.insert(std::make_pair(L, LoopVariant));
6201218893Sdim  if (!Pair.second)
6202218893Sdim    return Pair.first->second;
6203218893Sdim
6204218893Sdim  LoopDisposition D = computeLoopDisposition(S, L);
6205218893Sdim  return LoopDispositions[S][L] = D;
6206218893Sdim}
6207218893Sdim
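// Structural rules for the three-way result:
//  - constants are LoopInvariant;
//  - casts take the disposition of their operand;
//  - an addrec is LoopComputable in its own loop, LoopVariant w.r.t. the null
//    loop or any loop containing its own loop, LoopInvariant w.r.t. loops its
//    own loop contains, and otherwise invariant iff all its operands are;
//  - add/mul/umax/smax and udiv are variant if any operand is variant,
//    computable if any operand is computable, and invariant otherwise;
//  - a SCEVUnknown over an instruction is invariant iff a non-null L does not
//    contain that instruction; all other values are invariant.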
6208218893SdimScalarEvolution::LoopDisposition
6209218893SdimScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
6210218893Sdim  switch (S->getSCEVType()) {
6211218893Sdim  case scConstant:
6212218893Sdim    return LoopInvariant;
6213218893Sdim  case scTruncate:
6214218893Sdim  case scZeroExtend:
6215218893Sdim  case scSignExtend:
6216218893Sdim    return getLoopDisposition(cast<SCEVCastExpr>(S)->getOperand(), L);
6217218893Sdim  case scAddRecExpr: {
6218218893Sdim    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
6219218893Sdim
6220218893Sdim    // If L is the addrec's loop, it's computable.
6221218893Sdim    if (AR->getLoop() == L)
6222218893Sdim      return LoopComputable;
6223218893Sdim
6224218893Sdim    // Add recurrences are never invariant in the function-body (null loop).
6225218893Sdim    if (!L)
6226218893Sdim      return LoopVariant;
6227218893Sdim
6228218893Sdim    // This recurrence is variant w.r.t. L if L contains AR's loop.
6229218893Sdim    if (L->contains(AR->getLoop()))
6230218893Sdim      return LoopVariant;
6231218893Sdim
6232218893Sdim    // This recurrence is invariant w.r.t. L if AR's loop contains L.
6233218893Sdim    if (AR->getLoop()->contains(L))
6234218893Sdim      return LoopInvariant;
6235218893Sdim
6236218893Sdim    // This recurrence is variant w.r.t. L if any of its operands
6237218893Sdim    // are variant.
6238218893Sdim    for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
6239218893Sdim         I != E; ++I)
6240218893Sdim      if (!isLoopInvariant(*I, L))
6241218893Sdim        return LoopVariant;
6242218893Sdim
6243218893Sdim    // Otherwise it's loop-invariant.
6244218893Sdim    return LoopInvariant;
6245218893Sdim  }
6246218893Sdim  case scAddExpr:
6247218893Sdim  case scMulExpr:
6248218893Sdim  case scUMaxExpr:
6249218893Sdim  case scSMaxExpr: {
6250218893Sdim    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6251218893Sdim    bool HasVarying = false;
6252218893Sdim    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6253218893Sdim         I != E; ++I) {
6254218893Sdim      LoopDisposition D = getLoopDisposition(*I, L);
6255218893Sdim      if (D == LoopVariant)
6256218893Sdim        return LoopVariant;
6257218893Sdim      if (D == LoopComputable)
6258218893Sdim        HasVarying = true;
6259218893Sdim    }
6260218893Sdim    return HasVarying ? LoopComputable : LoopInvariant;
6261218893Sdim  }
6262218893Sdim  case scUDivExpr: {
6263218893Sdim    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6264218893Sdim    LoopDisposition LD = getLoopDisposition(UDiv->getLHS(), L);
6265218893Sdim    if (LD == LoopVariant)
6266218893Sdim      return LoopVariant;
6267218893Sdim    LoopDisposition RD = getLoopDisposition(UDiv->getRHS(), L);
6268218893Sdim    if (RD == LoopVariant)
6269218893Sdim      return LoopVariant;
6270218893Sdim    return (LD == LoopInvariant && RD == LoopInvariant) ?
6271218893Sdim           LoopInvariant : LoopComputable;
6272218893Sdim  }
6273218893Sdim  case scUnknown:
6274218893Sdim    // All non-instruction values are loop invariant.  All instructions are loop
6275218893Sdim    // invariant if they are not contained in the specified loop.
6276218893Sdim    // Instructions are never considered invariant in the function body
6277218893Sdim    // (null loop) because they are defined within the "loop".
6278218893Sdim    if (Instruction *I = dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue()))
6279218893Sdim      return (L && !L->contains(I)) ? LoopInvariant : LoopVariant;
6280218893Sdim    return LoopInvariant;
6281218893Sdim  case scCouldNotCompute:
6282218893Sdim    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6283218893Sdim    return LoopVariant;
6284218893Sdim  default: break;
6285218893Sdim  }
6286218893Sdim  llvm_unreachable("Unknown SCEV kind!");
6287218893Sdim  return LoopVariant;
6288218893Sdim}
6289218893Sdim
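// Thin predicate wrappers over getLoopDisposition.  A hypothetical client
// sketch, assuming SE, S, and L come from the caller's context:
//   if (SE.isLoopInvariant(S, L))
//     ; // S evaluates to the same value on every iteration of L.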
6290218893Sdimbool ScalarEvolution::isLoopInvariant(const SCEV *S, const Loop *L) {
6291218893Sdim  return getLoopDisposition(S, L) == LoopInvariant;
6292218893Sdim}
6293218893Sdim
6294218893Sdimbool ScalarEvolution::hasComputableLoopEvolution(const SCEV *S, const Loop *L) {
6295218893Sdim  return getLoopDisposition(S, L) == LoopComputable;
6296218893Sdim}
6297218893Sdim
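// Memoizing wrapper around computeBlockDisposition, following the same
// pattern as getLoopDisposition: a conservative DoesNotDominateBlock
// placeholder doubles as the lookup, and the computed disposition then
// replaces it.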
6298218893SdimScalarEvolution::BlockDisposition
6299218893SdimScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
6300218893Sdim  std::map<const BasicBlock *, BlockDisposition> &Values = BlockDispositions[S];
6301218893Sdim  std::pair<std::map<const BasicBlock *, BlockDisposition>::iterator, bool>
6302218893Sdim    Pair = Values.insert(std::make_pair(BB, DoesNotDominateBlock));
6303218893Sdim  if (!Pair.second)
6304218893Sdim    return Pair.first->second;
6305218893Sdim
6306218893Sdim  BlockDisposition D = computeBlockDisposition(S, BB);
6307218893Sdim  return BlockDispositions[S][BB] = D;
6308218893Sdim}
6309218893Sdim
6310218893SdimScalarEvolution::BlockDisposition
6311218893SdimScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
6312218893Sdim  switch (S->getSCEVType()) {
6313218893Sdim  case scConstant:
6314218893Sdim    return ProperlyDominatesBlock;
6315218893Sdim  case scTruncate:
6316218893Sdim  case scZeroExtend:
6317218893Sdim  case scSignExtend:
6318218893Sdim    return getBlockDisposition(cast<SCEVCastExpr>(S)->getOperand(), BB);
6319218893Sdim  case scAddRecExpr: {
6320218893Sdim    // This uses a "dominates" query instead of a "properly dominates" query
6321218893Sdim    // even when testing for proper dominance, because the instruction which
6322218893Sdim    // produces the addrec's value is a PHI, and a PHI effectively properly
6323218893Sdim    // dominates its entire containing block.
6324218893Sdim    const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);
6325218893Sdim    if (!DT->dominates(AR->getLoop()->getHeader(), BB))
6326218893Sdim      return DoesNotDominateBlock;
6327218893Sdim  }
6328218893Sdim  // FALL THROUGH into SCEVNAryExpr handling.
6329218893Sdim  case scAddExpr:
6330218893Sdim  case scMulExpr:
6331218893Sdim  case scUMaxExpr:
6332218893Sdim  case scSMaxExpr: {
6333218893Sdim    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6334218893Sdim    bool Proper = true;
6335218893Sdim    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6336218893Sdim         I != E; ++I) {
6337218893Sdim      BlockDisposition D = getBlockDisposition(*I, BB);
6338218893Sdim      if (D == DoesNotDominateBlock)
6339218893Sdim        return DoesNotDominateBlock;
6340218893Sdim      if (D == DominatesBlock)
6341218893Sdim        Proper = false;
6342218893Sdim    }
6343218893Sdim    return Proper ? ProperlyDominatesBlock : DominatesBlock;
6344218893Sdim  }
6345218893Sdim  case scUDivExpr: {
6346218893Sdim    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6347218893Sdim    const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
6348218893Sdim    BlockDisposition LD = getBlockDisposition(LHS, BB);
6349218893Sdim    if (LD == DoesNotDominateBlock)
6350218893Sdim      return DoesNotDominateBlock;
6351218893Sdim    BlockDisposition RD = getBlockDisposition(RHS, BB);
6352218893Sdim    if (RD == DoesNotDominateBlock)
6353218893Sdim      return DoesNotDominateBlock;
6354218893Sdim    return (LD == ProperlyDominatesBlock && RD == ProperlyDominatesBlock) ?
6355218893Sdim      ProperlyDominatesBlock : DominatesBlock;
6356218893Sdim  }
6357218893Sdim  case scUnknown:
6358218893Sdim    if (Instruction *I =
6359218893Sdim          dyn_cast<Instruction>(cast<SCEVUnknown>(S)->getValue())) {
6360218893Sdim      if (I->getParent() == BB)
6361218893Sdim        return DominatesBlock;
6362218893Sdim      if (DT->properlyDominates(I->getParent(), BB))
6363218893Sdim        return ProperlyDominatesBlock;
6364218893Sdim      return DoesNotDominateBlock;
6365218893Sdim    }
6366218893Sdim    return ProperlyDominatesBlock;
6367218893Sdim  case scCouldNotCompute:
6368218893Sdim    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6369218893Sdim    return DoesNotDominateBlock;
6370218893Sdim  default: break;
6371218893Sdim  }
6372218893Sdim  llvm_unreachable("Unknown SCEV kind!");
6373218893Sdim  return DoesNotDominateBlock;
6374218893Sdim}
6375218893Sdim
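// These predicates rely on the enum ordering DoesNotDominateBlock <
// DominatesBlock < ProperlyDominatesBlock, so ">= DominatesBlock" reads as
// "dominates, properly or not".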
6376218893Sdimbool ScalarEvolution::dominates(const SCEV *S, const BasicBlock *BB) {
6377218893Sdim  return getBlockDisposition(S, BB) >= DominatesBlock;
6378218893Sdim}
6379218893Sdim
6380218893Sdimbool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {
6381218893Sdim  return getBlockDisposition(S, BB) == ProperlyDominatesBlock;
6382218893Sdim}
6383218893Sdim
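// hasOperand is a plain recursive walk over the expression tree.  Unlike the
// disposition queries above it is not memoized, so shared subexpressions may
// be revisited.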
6384218893Sdimbool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
6385218893Sdim  switch (S->getSCEVType()) {
6386218893Sdim  case scConstant:
6387218893Sdim    return false;
6388218893Sdim  case scTruncate:
6389218893Sdim  case scZeroExtend:
6390218893Sdim  case scSignExtend: {
6391218893Sdim    const SCEVCastExpr *Cast = cast<SCEVCastExpr>(S);
6392218893Sdim    const SCEV *CastOp = Cast->getOperand();
6393218893Sdim    return Op == CastOp || hasOperand(CastOp, Op);
6394218893Sdim  }
6395218893Sdim  case scAddRecExpr:
6396218893Sdim  case scAddExpr:
6397218893Sdim  case scMulExpr:
6398218893Sdim  case scUMaxExpr:
6399218893Sdim  case scSMaxExpr: {
6400218893Sdim    const SCEVNAryExpr *NAry = cast<SCEVNAryExpr>(S);
6401218893Sdim    for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
6402218893Sdim         I != E; ++I) {
6403218893Sdim      const SCEV *NAryOp = *I;
6404218893Sdim      if (NAryOp == Op || hasOperand(NAryOp, Op))
6405218893Sdim        return true;
6406218893Sdim    }
6407218893Sdim    return false;
6408218893Sdim  }
6409218893Sdim  case scUDivExpr: {
6410218893Sdim    const SCEVUDivExpr *UDiv = cast<SCEVUDivExpr>(S);
6411218893Sdim    const SCEV *LHS = UDiv->getLHS(), *RHS = UDiv->getRHS();
6412218893Sdim    return LHS == Op || hasOperand(LHS, Op) ||
6413218893Sdim           RHS == Op || hasOperand(RHS, Op);
6414218893Sdim  }
6415218893Sdim  case scUnknown:
6416218893Sdim    return false;
6417218893Sdim  case scCouldNotCompute:
6418218893Sdim    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
6419218893Sdim    return false;
6420218893Sdim  default: break;
6421218893Sdim  }
6422218893Sdim  llvm_unreachable("Unknown SCEV kind!");
6423218893Sdim  return false;
6424218893Sdim}
6425218893Sdim
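// Drop every cache entry keyed directly by the SCEV pointer S.  Mappings
// keyed by Values or Loops (ValueExprMap, BackedgeTakenCounts) are cleaned up
// by their own invalidation paths (e.g. forgetValue and forgetLoop).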
6426218893Sdimvoid ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
6427218893Sdim  ValuesAtScopes.erase(S);
6428218893Sdim  LoopDispositions.erase(S);
6429218893Sdim  BlockDispositions.erase(S);
6430218893Sdim  UnsignedRanges.erase(S);
6431218893Sdim  SignedRanges.erase(S);
6432218893Sdim}
6433