ScalarEvolution.cpp revision 195340
//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the scalar evolution analysis
// engine, which is used primarily to analyze expressions involving induction
// variables in loops.
//
// There are several aspects to this library.  First is the representation of
// scalar expressions, which are represented as subclasses of the SCEV class.
// These classes are used to represent certain types of subexpressions that we
// can handle.  These classes are reference counted, managed by the const SCEV*
// class.  We only create one SCEV of a particular shape, so pointer-comparisons
// for equality are legal.
//
// One important aspect of the SCEV objects is that they are never cyclic, even
// if there is a cycle in the dataflow for an expression (i.e., a PHI node).  If
// the PHI node is one of the idioms that we can represent (e.g., a polynomial
// recurrence) then we represent it directly as a recurrence node, otherwise we
// represent it as a SCEVUnknown node.
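//
// For example, the canonical induction variable of a loop
// "for (i = 0; i != n; ++i)" is represented as the recurrence {0,+,1}<loop>:
// it starts at 0 and advances by 1 on each iteration.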
//
// In addition to being able to represent expressions of various types, we also
// have folders that are used to build the *canonical* representation for a
// particular expression.  These folders are capable of using a variety of
// rewrite rules to simplify the expressions.
//
// Once the folders are defined, we can implement the more interesting
// higher-level code, such as the code that recognizes PHI nodes of various
// types, computes the execution count of a loop, etc.
//
// TODO: We should use these routines and value representations to implement
// dependence analysis!
//
//===----------------------------------------------------------------------===//
//
// There are several good references for the techniques used in this analysis.
//
//  Chains of recurrences -- a method to expedite the evaluation
//  of closed-form functions
//  Olaf Bachmann, Paul S. Wang, Eugene V. Zima
//
//  On computational properties of chains of recurrences
//  Eugene V. Zima
//
//  Symbolic Evaluation of Chains of Recurrences for Loop Optimization
//  Robert A. van Engelen
//
//  Efficient Symbolic Analysis for Optimizing Compilers
//  Robert A. van Engelen
//
//  Using the chains of recurrences algebra for data dependence testing and
//  induction variable substitution
//  MS Thesis, Johnie Birch
//
//===----------------------------------------------------------------------===//

#define DEBUG_TYPE "scalar-evolution"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Constants.h"
#include "llvm/DerivedTypes.h"
#include "llvm/GlobalVariable.h"
#include "llvm/Instructions.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/Dominators.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Assembly/Writer.h"
#include "llvm/Target/TargetData.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Compiler.h"
#include "llvm/Support/ConstantRange.h"
#include "llvm/Support/GetElementPtrTypeIterator.h"
#include "llvm/Support/InstIterator.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
using namespace llvm;

STATISTIC(NumArrayLenItCounts,
          "Number of trip counts computed with array length");
STATISTIC(NumTripCountsComputed,
          "Number of loops with predictable loop counts");
STATISTIC(NumTripCountsNotComputed,
          "Number of loops without predictable loop counts");
STATISTIC(NumBruteForceTripCountsComputed,
          "Number of loops with trip counts computed by force");

static cl::opt<unsigned>
MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                        cl::desc("Maximum number of iterations SCEV will "
                                 "symbolically execute a constant "
                                 "derived loop"),
                        cl::init(100));

static RegisterPass<ScalarEvolution>
R("scalar-evolution", "Scalar Evolution Analysis", false, true);
char ScalarEvolution::ID = 0;

//===----------------------------------------------------------------------===//
//                           SCEV class definitions
//===----------------------------------------------------------------------===//

//===----------------------------------------------------------------------===//
// Implementation of the SCEV class.
//

SCEV::~SCEV() {}

void SCEV::dump() const {
  print(errs());
  errs() << '\n';
}

void SCEV::print(std::ostream &o) const {
  raw_os_ostream OS(o);
  print(OS);
}

bool SCEV::isZero() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isZero();
  return false;
}

bool SCEV::isOne() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isOne();
  return false;
}

bool SCEV::isAllOnesValue() const {
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(this))
    return SC->getValue()->isAllOnesValue();
  return false;
}

SCEVCouldNotCompute::SCEVCouldNotCompute() :
  SCEV(scCouldNotCompute) {}

void SCEVCouldNotCompute::Profile(FoldingSetNodeID &ID) const {
  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
}

bool SCEVCouldNotCompute::isLoopInvariant(const Loop *L) const {
  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

const Type *SCEVCouldNotCompute::getType() const {
  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
  return 0;
}

bool SCEVCouldNotCompute::hasComputableLoopEvolution(const Loop *L) const {
  assert(0 && "Attempt to use a SCEVCouldNotCompute object!");
  return false;
}

const SCEV *
SCEVCouldNotCompute::replaceSymbolicValuesWithConcrete(
                                                    const SCEV *Sym,
                                                    const SCEV *Conc,
                                                    ScalarEvolution &SE) const {
  return this;
}

void SCEVCouldNotCompute::print(raw_ostream &OS) const {
  OS << "***COULDNOTCOMPUTE***";
}

bool SCEVCouldNotCompute::classof(const SCEV *S) {
  return S->getSCEVType() == scCouldNotCompute;
}

const SCEV* ScalarEvolution::getConstant(ConstantInt *V) {
  FoldingSetNodeID ID;
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVConstant>();
  new (S) SCEVConstant(V);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV* ScalarEvolution::getConstant(const APInt& Val) {
  return getConstant(ConstantInt::get(Val));
}

const SCEV*
ScalarEvolution::getConstant(const Type *Ty, uint64_t V, bool isSigned) {
  return getConstant(ConstantInt::get(cast<IntegerType>(Ty), V, isSigned));
}

void SCEVConstant::Profile(FoldingSetNodeID &ID) const {
  ID.AddInteger(scConstant);
  ID.AddPointer(V);
}

const Type *SCEVConstant::getType() const { return V->getType(); }

void SCEVConstant::print(raw_ostream &OS) const {
  WriteAsOperand(OS, V, false);
}

SCEVCastExpr::SCEVCastExpr(unsigned SCEVTy,
                           const SCEV* op, const Type *ty)
  : SCEV(SCEVTy), Op(op), Ty(ty) {}

void SCEVCastExpr::Profile(FoldingSetNodeID &ID) const {
  ID.AddInteger(getSCEVType());
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
}

bool SCEVCastExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return Op->dominates(BB, DT);
}

SCEVTruncateExpr::SCEVTruncateExpr(const SCEV* op, const Type *ty)
  : SCEVCastExpr(scTruncate, op, ty) {
  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot truncate non-integer value!");
}

void SCEVTruncateExpr::print(raw_ostream &OS) const {
  OS << "(trunc " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

SCEVZeroExtendExpr::SCEVZeroExtendExpr(const SCEV* op, const Type *ty)
  : SCEVCastExpr(scZeroExtend, op, ty) {
  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot zero extend non-integer value!");
}

void SCEVZeroExtendExpr::print(raw_ostream &OS) const {
  OS << "(zext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

SCEVSignExtendExpr::SCEVSignExtendExpr(const SCEV* op, const Type *ty)
  : SCEVCastExpr(scSignExtend, op, ty) {
  assert((Op->getType()->isInteger() || isa<PointerType>(Op->getType())) &&
         (Ty->isInteger() || isa<PointerType>(Ty)) &&
         "Cannot sign extend non-integer value!");
}

void SCEVSignExtendExpr::print(raw_ostream &OS) const {
  OS << "(sext " << *Op->getType() << " " << *Op << " to " << *Ty << ")";
}

void SCEVCommutativeExpr::print(raw_ostream &OS) const {
  assert(Operands.size() > 1 && "This plus expr shouldn't exist!");
  const char *OpStr = getOperationStr();
  OS << "(" << *Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
    OS << OpStr << *Operands[i];
  OS << ")";
}

const SCEV *
SCEVCommutativeExpr::replaceSymbolicValuesWithConcrete(
                                                    const SCEV *Sym,
                                                    const SCEV *Conc,
                                                    ScalarEvolution &SE) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    const SCEV* H =
      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
    if (H != getOperand(i)) {
      SmallVector<const SCEV*, 8> NewOps;
      NewOps.reserve(getNumOperands());
      for (unsigned j = 0; j != i; ++j)
        NewOps.push_back(getOperand(j));
      NewOps.push_back(H);
      for (++i; i != e; ++i)
        NewOps.push_back(getOperand(i)->
                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));

      if (isa<SCEVAddExpr>(this))
        return SE.getAddExpr(NewOps);
      else if (isa<SCEVMulExpr>(this))
        return SE.getMulExpr(NewOps);
      else if (isa<SCEVSMaxExpr>(this))
        return SE.getSMaxExpr(NewOps);
      else if (isa<SCEVUMaxExpr>(this))
        return SE.getUMaxExpr(NewOps);
      else
        assert(0 && "Unknown commutative expr!");
    }
  }
  return this;
}

void SCEVNAryExpr::Profile(FoldingSetNodeID &ID) const {
  ID.AddInteger(getSCEVType());
  ID.AddInteger(Operands.size());
  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
    ID.AddPointer(Operands[i]);
}

bool SCEVNAryExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    if (!getOperand(i)->dominates(BB, DT))
      return false;
  }
  return true;
}

void SCEVUDivExpr::Profile(FoldingSetNodeID &ID) const {
  ID.AddInteger(scUDivExpr);
  ID.AddPointer(LHS);
  ID.AddPointer(RHS);
}

bool SCEVUDivExpr::dominates(BasicBlock *BB, DominatorTree *DT) const {
  return LHS->dominates(BB, DT) && RHS->dominates(BB, DT);
}

void SCEVUDivExpr::print(raw_ostream &OS) const {
  OS << "(" << *LHS << " /u " << *RHS << ")";
}

const Type *SCEVUDivExpr::getType() const {
  // In most cases the types of LHS and RHS will be the same, but in some
  // crazy cases one or the other may be a pointer. ScalarEvolution doesn't
  // depend on the type for correctness, but handling types carefully can
  // avoid extra casts in the SCEVExpander. The LHS is more likely to be
  // a pointer type than the RHS, so use the RHS' type here.
  return RHS->getType();
}

void SCEVAddRecExpr::Profile(FoldingSetNodeID &ID) const {
  ID.AddInteger(scAddRecExpr);
  ID.AddInteger(Operands.size());
  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
    ID.AddPointer(Operands[i]);
  ID.AddPointer(L);
}

const SCEV *
SCEVAddRecExpr::replaceSymbolicValuesWithConcrete(const SCEV *Sym,
                                                  const SCEV *Conc,
                                                  ScalarEvolution &SE) const {
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i) {
    const SCEV* H =
      getOperand(i)->replaceSymbolicValuesWithConcrete(Sym, Conc, SE);
    if (H != getOperand(i)) {
      SmallVector<const SCEV*, 8> NewOps;
      NewOps.reserve(getNumOperands());
      for (unsigned j = 0; j != i; ++j)
        NewOps.push_back(getOperand(j));
      NewOps.push_back(H);
      for (++i; i != e; ++i)
        NewOps.push_back(getOperand(i)->
                         replaceSymbolicValuesWithConcrete(Sym, Conc, SE));

      return SE.getAddRecExpr(NewOps, L);
    }
  }
  return this;
}


bool SCEVAddRecExpr::isLoopInvariant(const Loop *QueryLoop) const {
  // Add recurrences are never invariant in the function-body (null loop).
  if (!QueryLoop)
    return false;

  // This recurrence is variant w.r.t. QueryLoop if QueryLoop contains L.
  if (QueryLoop->contains(L->getHeader()))
    return false;

  // This recurrence is variant w.r.t. QueryLoop if any of its operands
  // are variant.
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
    if (!getOperand(i)->isLoopInvariant(QueryLoop))
      return false;

  // Otherwise it's loop-invariant.
  return true;
}

void SCEVAddRecExpr::print(raw_ostream &OS) const {
  OS << "{" << *Operands[0];
  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
    OS << ",+," << *Operands[i];
  OS << "}<" << L->getHeader()->getName() + ">";
}

void SCEVUnknown::Profile(FoldingSetNodeID &ID) const {
  ID.AddInteger(scUnknown);
  ID.AddPointer(V);
}

bool SCEVUnknown::isLoopInvariant(const Loop *L) const {
  // All non-instruction values are loop invariant.  All instructions are loop
  // invariant if they are not contained in the specified loop.
  // Instructions are never considered invariant in the function body
  // (null loop) because they are defined within the "loop".
  if (Instruction *I = dyn_cast<Instruction>(V))
    return L && !L->contains(I->getParent());
  return true;
}

bool SCEVUnknown::dominates(BasicBlock *BB, DominatorTree *DT) const {
  if (Instruction *I = dyn_cast<Instruction>(getValue()))
    return DT->dominates(I->getParent(), BB);
  return true;
}

const Type *SCEVUnknown::getType() const {
  return V->getType();
}

void SCEVUnknown::print(raw_ostream &OS) const {
  WriteAsOperand(OS, V, false);
}

//===----------------------------------------------------------------------===//
//                               SCEV Utilities
//===----------------------------------------------------------------------===//

namespace {
  /// SCEVComplexityCompare - Return true if the complexity of the LHS is less
  /// than the complexity of the RHS.  This comparator is used to canonicalize
  /// expressions.
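  /// For example, under this ordering constants always sort before
  /// SCEVUnknown values, so getAddExpr canonicalizes both (%x + 2) and
  /// (2 + %x) to the same add expression.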
  class VISIBILITY_HIDDEN SCEVComplexityCompare {
    LoopInfo *LI;
  public:
    explicit SCEVComplexityCompare(LoopInfo *li) : LI(li) {}

    bool operator()(const SCEV *LHS, const SCEV *RHS) const {
      // Primarily, sort the SCEVs by their getSCEVType().
      if (LHS->getSCEVType() != RHS->getSCEVType())
        return LHS->getSCEVType() < RHS->getSCEVType();

      // Aside from the getSCEVType() ordering, the particular ordering
      // isn't very important except that it's beneficial to be consistent,
      // so that (a + b) and (b + a) don't end up as different expressions.

      // Sort SCEVUnknown values with some loose heuristics. TODO: This is
      // not as complete as it could be.
      if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS)) {
        const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);

        // Order pointer values after integer values. This helps SCEVExpander
        // form GEPs.
        if (isa<PointerType>(LU->getType()) && !isa<PointerType>(RU->getType()))
          return false;
        if (isa<PointerType>(RU->getType()) && !isa<PointerType>(LU->getType()))
          return true;

        // Compare getValueID values.
        if (LU->getValue()->getValueID() != RU->getValue()->getValueID())
          return LU->getValue()->getValueID() < RU->getValue()->getValueID();

        // Sort arguments by their position.
        if (const Argument *LA = dyn_cast<Argument>(LU->getValue())) {
          const Argument *RA = cast<Argument>(RU->getValue());
          return LA->getArgNo() < RA->getArgNo();
        }

        // For instructions, compare their loop depth, and their opcode.
        // This is pretty loose.
        if (Instruction *LV = dyn_cast<Instruction>(LU->getValue())) {
          Instruction *RV = cast<Instruction>(RU->getValue());

          // Compare loop depths.
          if (LI->getLoopDepth(LV->getParent()) !=
              LI->getLoopDepth(RV->getParent()))
            return LI->getLoopDepth(LV->getParent()) <
                   LI->getLoopDepth(RV->getParent());

          // Compare opcodes.
          if (LV->getOpcode() != RV->getOpcode())
            return LV->getOpcode() < RV->getOpcode();

          // Compare the number of operands.
          if (LV->getNumOperands() != RV->getNumOperands())
            return LV->getNumOperands() < RV->getNumOperands();
        }

        return false;
      }

      // Compare constant values.
      if (const SCEVConstant *LC = dyn_cast<SCEVConstant>(LHS)) {
        const SCEVConstant *RC = cast<SCEVConstant>(RHS);
        return LC->getValue()->getValue().ult(RC->getValue()->getValue());
      }

      // Compare addrec loop depths.
      if (const SCEVAddRecExpr *LA = dyn_cast<SCEVAddRecExpr>(LHS)) {
        const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS);
        if (LA->getLoop()->getLoopDepth() != RA->getLoop()->getLoopDepth())
          return LA->getLoop()->getLoopDepth() < RA->getLoop()->getLoopDepth();
      }

      // Lexicographically compare n-ary expressions.
      if (const SCEVNAryExpr *LC = dyn_cast<SCEVNAryExpr>(LHS)) {
        const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS);
        for (unsigned i = 0, e = LC->getNumOperands(); i != e; ++i) {
          if (i >= RC->getNumOperands())
            return false;
          if (operator()(LC->getOperand(i), RC->getOperand(i)))
            return true;
          if (operator()(RC->getOperand(i), LC->getOperand(i)))
            return false;
        }
        return LC->getNumOperands() < RC->getNumOperands();
      }

      // Lexicographically compare udiv expressions.
      if (const SCEVUDivExpr *LC = dyn_cast<SCEVUDivExpr>(LHS)) {
        const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
        if (operator()(LC->getLHS(), RC->getLHS()))
          return true;
        if (operator()(RC->getLHS(), LC->getLHS()))
          return false;
        if (operator()(LC->getRHS(), RC->getRHS()))
          return true;
        if (operator()(RC->getRHS(), LC->getRHS()))
          return false;
        return false;
      }

      // Compare cast expressions by operand.
      if (const SCEVCastExpr *LC = dyn_cast<SCEVCastExpr>(LHS)) {
        const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
        return operator()(LC->getOperand(), RC->getOperand());
      }

      assert(0 && "Unknown SCEV kind!");
      return false;
    }
  };
}

/// GroupByComplexity - Given a list of SCEV objects, order them by their
/// complexity, and group objects of the same complexity together by value.
/// When this routine is finished, we know that any duplicates in the vector are
/// consecutive and that complexity is monotonically increasing.
///
/// Note that we take special precautions to ensure that we get deterministic
/// results from this routine.  In other words, we don't want the results of
/// this to depend on where the addresses of various SCEV objects happened to
/// land in memory.
///
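/// For example, given the operand list {%a, 2, %b, %a}, one possible result
/// is {2, %a, %a, %b}: the constant sorts first, and the two identical %a
/// values become adjacent.
///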
static void GroupByComplexity(SmallVectorImpl<const SCEV*> &Ops,
                              LoopInfo *LI) {
  if (Ops.size() < 2) return;  // Noop
  if (Ops.size() == 2) {
    // This is the common case, which also happens to be trivially simple.
    // Special case it.
    if (SCEVComplexityCompare(LI)(Ops[1], Ops[0]))
      std::swap(Ops[0], Ops[1]);
    return;
  }

  // Do the rough sort by complexity.
  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI));

  // Now that we are sorted by complexity, group elements of the same
  // complexity.  Note that this is, at worst, N^2, but the vector is likely to
  // be extremely short in practice.  Note that we take this approach because we
  // do not want to depend on the addresses of the objects we are grouping.
  for (unsigned i = 0, e = Ops.size(); i != e-2; ++i) {
    const SCEV *S = Ops[i];
    unsigned Complexity = S->getSCEVType();

    // If there are any objects of the same complexity and same value as this
    // one, group them.
    for (unsigned j = i+1; j != e && Ops[j]->getSCEVType() == Complexity; ++j) {
      if (Ops[j] == S) { // Found a duplicate.
        // Move it to immediately after i'th element.
        std::swap(Ops[i+1], Ops[j]);
        ++i;   // no need to rescan it.
        if (i == e-2) return;  // Done!
      }
    }
  }
}



//===----------------------------------------------------------------------===//
//                      Simple SCEV method implementations
//===----------------------------------------------------------------------===//

/// BinomialCoefficient - Compute BC(It, K).  The result has width W.
/// Assumes K > 0.
static const SCEV* BinomialCoefficient(const SCEV* It, unsigned K,
                                      ScalarEvolution &SE,
                                      const Type* ResultTy) {
  // Handle the simplest case efficiently.
  if (K == 1)
    return SE.getTruncateOrZeroExtend(It, ResultTy);

  // We are using the following formula for BC(It, K):
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / K!
  //
  // Suppose W is the bitwidth of the return value.  We must be prepared for
  // overflow.  Hence, we must ensure that the result of our computation is
  // equal to the accurate one modulo 2^W.  Unfortunately, division isn't
  // safe in modular arithmetic.
  //
  // However, this code doesn't use exactly that formula; the formula it uses
  // is something like the following, where T is the number of factors of 2 in
  // K! (i.e. trailing zeros in the binary representation of K!), and ^ is
  // exponentiation:
  //
  //   BC(It, K) = (It * (It - 1) * ... * (It - K + 1)) / 2^T / (K! / 2^T)
  //
  // This formula is trivially equivalent to the previous formula.  However,
  // this formula can be implemented much more efficiently.  The trick is that
  // K! / 2^T is odd, and exact division by an odd number *is* safe in modular
  // arithmetic.  To do exact division in modular arithmetic, all we have
  // to do is multiply by the inverse.  Therefore, this step can be done at
  // width W.
  //
  // The next issue is how to safely do the division by 2^T.  The way this
  // is done is by doing the multiplication step at a width of at least W + T
  // bits.  This way, the bottom W+T bits of the product are accurate. Then,
  // when we perform the division by 2^T (which is equivalent to a right shift
  // by T), the bottom W bits are accurate.  Extra bits are okay; they'll get
  // truncated out after the division by 2^T.
  //
  // In comparison to just directly using the first formula, this technique
  // is much more efficient; using the first formula requires W * K bits,
  // but this formula needs less than W + K bits. Also, the first formula requires
  // a division step, whereas this formula only requires multiplies and shifts.
  //
  // It doesn't matter whether the subtraction step is done in the calculation
  // width or the input iteration count's width; if the subtraction overflows,
  // the result must be zero anyway.  We prefer here to do it in the width of
  // the induction variable because it helps a lot for certain cases; CodeGen
  // isn't smart enough to ignore the overflow, which leads to much less
  // efficient code if the width of the subtraction is wider than the native
  // register width.
  //
  // (It's possible to not widen at all by pulling out factors of 2 before
  // the multiplication; for example, K=2 can be calculated as
  // It/2*(It+(It*INT_MIN/INT_MIN)+-1). However, it requires
  // extra arithmetic, so it's not an obvious win, and it gets
  // much more complicated for K > 3.)
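  //
  // As a concrete illustration, worked by hand rather than taken from the
  // code below: for K = 3 and W = 32, K! = 6 = 2^1 * 3, so T = 1 and
  // K!/2^T = 3, whose multiplicative inverse mod 2^32 is 0xAAAAAAAB.  The
  // product It*(It-1)*(It-2) is then computed at W+T = 33 bits.  For It = 4:
  // (4*3*2) >> 1 = 12, and 12 * 0xAAAAAAAB mod 2^32 = 4, which is BC(4, 3).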

  // Protection from insane SCEVs; this bound is conservative,
  // but it probably doesn't matter.
  if (K > 1000)
    return SE.getCouldNotCompute();

  unsigned W = SE.getTypeSizeInBits(ResultTy);

  // Calculate K! / 2^T and T; we divide out the factors of two before
  // multiplying for calculating K! / 2^T to avoid overflow.
  // Other overflow doesn't matter because we only care about the bottom
  // W bits of the result.
  APInt OddFactorial(W, 1);
  unsigned T = 1;
  for (unsigned i = 3; i <= K; ++i) {
    APInt Mult(W, i);
    unsigned TwoFactors = Mult.countTrailingZeros();
    T += TwoFactors;
    Mult = Mult.lshr(TwoFactors);
    OddFactorial *= Mult;
  }

  // We need at least W + T bits for the multiplication step
  unsigned CalculationBits = W + T;

  // Calculate 2^T, at width T+W.
  APInt DivFactor = APInt(CalculationBits, 1).shl(T);

  // Calculate the multiplicative inverse of K! / 2^T;
  // this multiplication factor will perform the exact division by
  // K! / 2^T.
  APInt Mod = APInt::getSignedMinValue(W+1);
  APInt MultiplyFactor = OddFactorial.zext(W+1);
  MultiplyFactor = MultiplyFactor.multiplicativeInverse(Mod);
  MultiplyFactor = MultiplyFactor.trunc(W);

  // Calculate the product, at width T+W
  const IntegerType *CalculationTy = IntegerType::get(CalculationBits);
  const SCEV* Dividend = SE.getTruncateOrZeroExtend(It, CalculationTy);
  for (unsigned i = 1; i != K; ++i) {
    const SCEV* S = SE.getMinusSCEV(It, SE.getIntegerSCEV(i, It->getType()));
    Dividend = SE.getMulExpr(Dividend,
                             SE.getTruncateOrZeroExtend(S, CalculationTy));
  }

  // Divide by 2^T
  const SCEV* DivResult = SE.getUDivExpr(Dividend, SE.getConstant(DivFactor));

  // Truncate the result, and divide by K! / 2^T.

  return SE.getMulExpr(SE.getConstant(MultiplyFactor),
                       SE.getTruncateOrZeroExtend(DivResult, ResultTy));
}

/// evaluateAtIteration - Return the value of this chain of recurrences at
/// the specified iteration number.  We can evaluate this recurrence by
/// multiplying each element in the chain by the binomial coefficient
/// corresponding to it.  In other words, we can evaluate {A,+,B,+,C,+,D} as:
///
///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
///
/// where BC(It, k) stands for binomial coefficient.
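///
/// For example, evaluating {5,+,3} at iteration It = 4 gives
/// 5*BC(4, 0) + 3*BC(4, 1) = 5 + 3*4 = 17.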
///
const SCEV* SCEVAddRecExpr::evaluateAtIteration(const SCEV* It,
                                               ScalarEvolution &SE) const {
  const SCEV* Result = getStart();
  for (unsigned i = 1, e = getNumOperands(); i != e; ++i) {
    // The computation is correct in the face of overflow provided that the
    // multiplication is performed _after_ the evaluation of the binomial
    // coefficient.
    const SCEV* Coeff = BinomialCoefficient(It, i, SE, getType());
    if (isa<SCEVCouldNotCompute>(Coeff))
      return Coeff;

    Result = SE.getAddExpr(Result, SE.getMulExpr(getOperand(i), Coeff));
  }
  return Result;
}

//===----------------------------------------------------------------------===//
//                    SCEV Expression folder implementations
//===----------------------------------------------------------------------===//

const SCEV* ScalarEvolution::getTruncateExpr(const SCEV* Op,
                                            const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
         "This is not a truncating conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    return getConstant(
      cast<ConstantInt>(ConstantExpr::getTrunc(SC->getValue(), Ty)));

  // trunc(trunc(x)) --> trunc(x)
  if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op))
    return getTruncateExpr(ST->getOperand(), Ty);

  // trunc(sext(x)) --> sext(x) if widening or trunc(x) if narrowing
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getTruncateOrSignExtend(SS->getOperand(), Ty);

  // trunc(zext(x)) --> zext(x) if widening or trunc(x) if narrowing
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getTruncateOrZeroExtend(SZ->getOperand(), Ty);

  // If the input value is a chrec scev, truncate the chrec's operands.
  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) {
    SmallVector<const SCEV*, 4> Operands;
    for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
      Operands.push_back(getTruncateExpr(AddRec->getOperand(i), Ty));
    return getAddRecExpr(Operands, AddRec->getLoop());
  }

  FoldingSetNodeID ID;
  ID.AddInteger(scTruncate);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVTruncateExpr>();
  new (S) SCEVTruncateExpr(Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV* ScalarEvolution::getZeroExtendExpr(const SCEV* Op,
                                              const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
    const Type *IntTy = getEffectiveSCEVType(Ty);
    Constant *C = ConstantExpr::getZExt(SC->getValue(), IntTy);
    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
    return getConstant(cast<ConstantInt>(C));
  }

  // zext(zext(x)) --> zext(x)
  if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
    return getZeroExtendExpr(SZ->getOperand(), Ty);

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can zero extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (unsigned char X = 0; X < 100; ++X) { int Y = X; }
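  //
  // In that example X is the 8-bit recurrence {0,+,1}; because it provably
  // never exceeds 100, the zero extension folds to the 32-bit recurrence
  // {0,+,1} instead of remaining a zext wrapped around the entire addrec.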
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV* MaxBECount = getMaxBackedgeTakenCount(AR->getLoop());
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.
        const SCEV* Start = AR->getStart();
        const SCEV* Step = AR->getStepRecurrence(*this);

        // Check whether the backedge-taken count can be losslessly cast to
        // the addrec's type. The count is always unsigned.
        const SCEV* CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV* RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          const Type *WideTy =
            IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
          // Check whether Start+Step*MaxBECount has no unsigned overflow.
          const SCEV* ZMul =
            getMulExpr(CastedMaxBECount,
                       getTruncateOrZeroExtend(Step, Start->getType()));
          const SCEV* Add = getAddExpr(Start, ZMul);
          const SCEV* OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getZeroExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getZeroExtendExpr(Step, Ty),
                                 AR->getLoop());

          // Similar to above, only this time treat the step value as signed.
          // This covers loops that count down.
          const SCEV* SMul =
            getMulExpr(CastedMaxBECount,
                       getTruncateOrSignExtend(Step, Start->getType()));
          Add = getAddExpr(Start, SMul);
          OperandExtendedAdd =
            getAddExpr(getZeroExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getZeroExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getZeroExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 AR->getLoop());
        }
      }
    }

  FoldingSetNodeID ID;
  ID.AddInteger(scZeroExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVZeroExtendExpr>();
  new (S) SCEVZeroExtendExpr(Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

const SCEV* ScalarEvolution::getSignExtendExpr(const SCEV* Op,
                                              const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Fold if the operand is constant.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op)) {
    const Type *IntTy = getEffectiveSCEVType(Ty);
    Constant *C = ConstantExpr::getSExt(SC->getValue(), IntTy);
    if (IntTy != Ty) C = ConstantExpr::getIntToPtr(C, Ty);
    return getConstant(cast<ConstantInt>(C));
  }

  // sext(sext(x)) --> sext(x)
  if (const SCEVSignExtendExpr *SS = dyn_cast<SCEVSignExtendExpr>(Op))
    return getSignExtendExpr(SS->getOperand(), Ty);

  // If the input value is a chrec scev, and we can prove that the value
  // did not overflow the old, smaller, value, we can sign extend all of the
  // operands (often constants).  This allows analysis of something like
  // this:  for (signed char X = 0; X < 100; ++X) { int Y = X; }
  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op))
    if (AR->isAffine()) {
      // Check whether the backedge-taken count is SCEVCouldNotCompute.
      // Note that this serves two purposes: It filters out loops that are
      // simply not analyzable, and it covers the case where this code is
      // being called from within backedge-taken count analysis, such that
      // attempting to ask for the backedge-taken count would likely result
      // in infinite recursion. In the latter case, the analysis code will
      // cope with a conservative value, and it will take care to purge
      // that value once it has finished.
      const SCEV* MaxBECount = getMaxBackedgeTakenCount(AR->getLoop());
      if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
        // Manually compute the final value for AR, checking for
        // overflow.
        const SCEV* Start = AR->getStart();
        const SCEV* Step = AR->getStepRecurrence(*this);

        // Check whether the backedge-taken count can be losslessly cast to
        // the addrec's type. The count is always unsigned.
        const SCEV* CastedMaxBECount =
          getTruncateOrZeroExtend(MaxBECount, Start->getType());
        const SCEV* RecastedMaxBECount =
          getTruncateOrZeroExtend(CastedMaxBECount, MaxBECount->getType());
        if (MaxBECount == RecastedMaxBECount) {
          const Type *WideTy =
            IntegerType::get(getTypeSizeInBits(Start->getType()) * 2);
          // Check whether Start+Step*MaxBECount has no signed overflow.
          const SCEV* SMul =
            getMulExpr(CastedMaxBECount,
                       getTruncateOrSignExtend(Step, Start->getType()));
          const SCEV* Add = getAddExpr(Start, SMul);
          const SCEV* OperandExtendedAdd =
            getAddExpr(getSignExtendExpr(Start, WideTy),
                       getMulExpr(getZeroExtendExpr(CastedMaxBECount, WideTy),
                                  getSignExtendExpr(Step, WideTy)));
          if (getSignExtendExpr(Add, WideTy) == OperandExtendedAdd)
            // Return the expression with the addrec on the outside.
            return getAddRecExpr(getSignExtendExpr(Start, Ty),
                                 getSignExtendExpr(Step, Ty),
                                 AR->getLoop());
        }
      }
    }

  FoldingSetNodeID ID;
  ID.AddInteger(scSignExtend);
  ID.AddPointer(Op);
  ID.AddPointer(Ty);
  void *IP = 0;
  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
  SCEV *S = SCEVAllocator.Allocate<SCEVSignExtendExpr>();
  new (S) SCEVSignExtendExpr(Op, Ty);
  UniqueSCEVs.InsertNode(S, IP);
  return S;
}

/// getAnyExtendExpr - Return a SCEV for the given operand extended with
/// unspecified bits out to the given type.
///
const SCEV* ScalarEvolution::getAnyExtendExpr(const SCEV* Op,
                                             const Type *Ty) {
  assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
         "This is not an extending conversion!");
  assert(isSCEVable(Ty) &&
         "This is not a conversion to a SCEVable type!");
  Ty = getEffectiveSCEVType(Ty);

  // Sign-extend negative constants.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(Op))
    if (SC->getValue()->getValue().isNegative())
      return getSignExtendExpr(Op, Ty);

  // Peel off a truncate cast.
  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Op)) {
    const SCEV* NewOp = T->getOperand();
    if (getTypeSizeInBits(NewOp->getType()) < getTypeSizeInBits(Ty))
      return getAnyExtendExpr(NewOp, Ty);
    return getTruncateOrNoop(NewOp, Ty);
  }

  // Next try a zext cast. If the cast is folded, use it.
  const SCEV* ZExt = getZeroExtendExpr(Op, Ty);
  if (!isa<SCEVZeroExtendExpr>(ZExt))
    return ZExt;

  // Next try a sext cast. If the cast is folded, use it.
  const SCEV* SExt = getSignExtendExpr(Op, Ty);
  if (!isa<SCEVSignExtendExpr>(SExt))
    return SExt;

  // If the expression is obviously signed, use the sext cast value.
  if (isa<SCEVSMaxExpr>(Op))
    return SExt;

  // Absent any other information, use the zext cast value.
  return ZExt;
}

/// CollectAddOperandsWithScales - Process the given Ops list, which is
/// a list of operands to be added under the given scale, and update the
/// given map. This is a helper function for getAddExpr. As an example of
/// what it does, given a sequence of operands that would form an add
/// expression like this:
///
///    m + n + 13 + (A * (o + p + (B * q + m + 29))) + r + (-1 * r)
///
/// where A and B are constants, update the map with these values:
///
///    (m, 1+A*B), (n, 1), (o, A), (p, A), (q, A*B), (r, 0)
///
/// and add 13 + A*B*29 to AccumulatedConstant.
/// This will allow getAddExpr to produce this:
///
///    13+A*B*29 + n + (m * (1+A*B)) + ((o + p) * A) + (q * A*B)
///
/// This form often exposes folding opportunities that are hidden in
/// the original operand list.
///
/// Return true iff it appears that any interesting folding opportunities
/// may be exposed. This helps getAddExpr short-circuit extra work in
/// the common case where no interesting opportunities are present, and
/// is also used as a check to avoid infinite recursion.
///
static bool
CollectAddOperandsWithScales(DenseMap<const SCEV*, APInt> &M,
                             SmallVector<const SCEV*, 8> &NewOps,
                             APInt &AccumulatedConstant,
                             const SmallVectorImpl<const SCEV*> &Ops,
                             const APInt &Scale,
                             ScalarEvolution &SE) {
  bool Interesting = false;

  // Iterate over the add operands.
  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
    const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[i]);
    if (Mul && isa<SCEVConstant>(Mul->getOperand(0))) {
      APInt NewScale =
        Scale * cast<SCEVConstant>(Mul->getOperand(0))->getValue()->getValue();
      if (Mul->getNumOperands() == 2 && isa<SCEVAddExpr>(Mul->getOperand(1))) {
        // A multiplication of a constant with another add; recurse.
        Interesting |=
          CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                       cast<SCEVAddExpr>(Mul->getOperand(1))
                                         ->getOperands(),
                                       NewScale, SE);
      } else {
        // A multiplication of a constant with some other value. Update
        // the map.
        SmallVector<const SCEV*, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
        const SCEV* Key = SE.getMulExpr(MulOps);
        std::pair<DenseMap<const SCEV*, APInt>::iterator, bool> Pair =
          M.insert(std::make_pair(Key, NewScale));
        if (Pair.second) {
          NewOps.push_back(Pair.first->first);
        } else {
          Pair.first->second += NewScale;
          // The map already had an entry for this value, which may indicate
          // a folding opportunity.
          Interesting = true;
        }
      }
    } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
      // Pull a buried constant out to the outside.
      if (Scale != 1 || AccumulatedConstant != 0 || C->isZero())
        Interesting = true;
      AccumulatedConstant += Scale * C->getValue()->getValue();
    } else {
      // An ordinary operand. Update the map.
      std::pair<DenseMap<const SCEV*, APInt>::iterator, bool> Pair =
        M.insert(std::make_pair(Ops[i], Scale));
      if (Pair.second) {
        NewOps.push_back(Pair.first->first);
      } else {
        Pair.first->second += Scale;
        // The map already had an entry for this value, which may indicate
        // a folding opportunity.
        Interesting = true;
      }
    }
  }

  return Interesting;
}

namespace {
  struct APIntCompare {
    bool operator()(const APInt &LHS, const APInt &RHS) const {
      return LHS.ult(RHS);
    }
  };
}

/// getAddExpr - Get a canonical add expression, or something simpler if
/// possible.
const SCEV* ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV*> &Ops) {
  assert(!Ops.empty() && "Cannot get empty add!");
  if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
           getEffectiveSCEVType(Ops[0]->getType()) &&
           "SCEVAddExpr operand types don't match!");
#endif

  // Sort by complexity, this groups all similar expression types together.
  GroupByComplexity(Ops, LI);

  // If there are any constants, fold them together.
  unsigned Idx = 0;
  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
    ++Idx;
    assert(Idx < Ops.size());
    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
      // We found two constants, fold them together!
      Ops[0] = getConstant(LHSC->getValue()->getValue() +
                           RHSC->getValue()->getValue());
      if (Ops.size() == 2) return Ops[0];
      Ops.erase(Ops.begin()+1);  // Erase the folded element
      LHSC = cast<SCEVConstant>(Ops[0]);
    }

    // If we are left with a constant zero being added, strip it off.
    if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
      Ops.erase(Ops.begin());
      --Idx;
    }
  }

  if (Ops.size() == 1) return Ops[0];

  // Okay, check to see if the same value occurs in the operand list twice.  If
  // so, merge them together into a multiply expression.  Since we sorted the
  // list, these values are required to be adjacent.
  const Type *Ty = Ops[0]->getType();
  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
    if (Ops[i] == Ops[i+1]) {      //  X + Y + Y  -->  X + Y*2
      // Found a match, merge the two values into a multiply, and add any
      // remaining values to the result.
      const SCEV* Two = getIntegerSCEV(2, Ty);
      const SCEV* Mul = getMulExpr(Ops[i], Two);
      if (Ops.size() == 2)
        return Mul;
      Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
      Ops.push_back(Mul);
      return getAddExpr(Ops);
    }

  // Check for truncates. If all the operands are truncated from the same
  // type, see if factoring out the truncate would permit the result to be
  // folded. e.g., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
  // if the contents of the resulting outer trunc fold to something simple.
  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
    const Type *DstType = Trunc->getType();
    const Type *SrcType = Trunc->getOperand()->getType();
    SmallVector<const SCEV*, 8> LargeOps;
    bool Ok = true;
    // Check all the operands to see if they can be represented in the
    // source type of the truncate.
    for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
      if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(Ops[i])) {
        if (T->getOperand()->getType() != SrcType) {
          Ok = false;
          break;
        }
        LargeOps.push_back(T->getOperand());
      } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(Ops[i])) {
        // This could be either sign or zero extension, but sign extension
        // is much more likely to be foldable here.
        LargeOps.push_back(getSignExtendExpr(C, SrcType));
      } else if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(Ops[i])) {
        SmallVector<const SCEV*, 8> LargeMulOps;
        for (unsigned j = 0, f = M->getNumOperands(); j != f && Ok; ++j) {
          if (const SCEVTruncateExpr *T =
                dyn_cast<SCEVTruncateExpr>(M->getOperand(j))) {
            if (T->getOperand()->getType() != SrcType) {
              Ok = false;
              break;
            }
            LargeMulOps.push_back(T->getOperand());
          } else if (const SCEVConstant *C =
                       dyn_cast<SCEVConstant>(M->getOperand(j))) {
            // This could be either sign or zero extension, but sign extension
            // is much more likely to be foldable here.
            LargeMulOps.push_back(getSignExtendExpr(C, SrcType));
          } else {
            Ok = false;
            break;
          }
        }
        if (Ok)
          LargeOps.push_back(getMulExpr(LargeMulOps));
      } else {
        Ok = false;
        break;
      }
    }
    if (Ok) {
      // Evaluate the expression in the larger type.
      const SCEV* Fold = getAddExpr(LargeOps);
      // If it folds to something simple, use it. Otherwise, don't.
      if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
        return getTruncateExpr(Fold, DstType);
    }
  }

  // Skip past any other cast SCEVs.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddExpr)
    ++Idx;

  // If there are add operands they would be next.
  if (Idx < Ops.size()) {
    bool DeletedAdd = false;
    while (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[Idx])) {
      // If we have an add, expand the add operands onto the end of the operands
      // list.
      Ops.insert(Ops.end(), Add->op_begin(), Add->op_end());
      Ops.erase(Ops.begin()+Idx);
      DeletedAdd = true;
    }

    // If we deleted at least one add, we added operands to the end of the list,
    // and they are not necessarily sorted.  Recurse to resort and resimplify
    // any operands we just acquired.
    if (DeletedAdd)
      return getAddExpr(Ops);
  }

  // Skip over the add expression until we get to a multiply.
  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
    ++Idx;

  // Check to see if there are any folding opportunities present with
  // operands multiplied by constant values.
  if (Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx])) {
    uint64_t BitWidth = getTypeSizeInBits(Ty);
    DenseMap<const SCEV*, APInt> M;
    SmallVector<const SCEV*, 8> NewOps;
    APInt AccumulatedConstant(BitWidth, 0);
    if (CollectAddOperandsWithScales(M, NewOps, AccumulatedConstant,
                                     Ops, APInt(BitWidth, 1), *this)) {
      // Some interesting folding opportunity is present, so it's worthwhile to
1246      // re-generate the operands list. Group the operands by constant scale,
1247      // to avoid multiplying by the same constant scale multiple times.
1248      std::map<APInt, SmallVector<const SCEV*, 4>, APIntCompare> MulOpLists;
1249      for (SmallVector<const SCEV*, 8>::iterator I = NewOps.begin(),
1250           E = NewOps.end(); I != E; ++I)
1251        MulOpLists[M.find(*I)->second].push_back(*I);
1252      // Re-generate the operands list.
1253      Ops.clear();
1254      if (AccumulatedConstant != 0)
1255        Ops.push_back(getConstant(AccumulatedConstant));
1256      for (std::map<APInt, SmallVector<const SCEV *, 4>, APIntCompare>::iterator
1257           I = MulOpLists.begin(), E = MulOpLists.end(); I != E; ++I)
1258        if (I->first != 0)
1259          Ops.push_back(getMulExpr(getConstant(I->first),
1260                                   getAddExpr(I->second)));
1261      if (Ops.empty())
1262        return getIntegerSCEV(0, Ty);
1263      if (Ops.size() == 1)
1264        return Ops[0];
1265      return getAddExpr(Ops);
1266    }
1267  }
1268
1269  // If we are adding something to a multiply expression, make sure the
1270  // something is not already an operand of the multiply.  If so, merge it into
1271  // the multiply.
1272  for (; Idx < Ops.size() && isa<SCEVMulExpr>(Ops[Idx]); ++Idx) {
1273    const SCEVMulExpr *Mul = cast<SCEVMulExpr>(Ops[Idx]);
1274    for (unsigned MulOp = 0, e = Mul->getNumOperands(); MulOp != e; ++MulOp) {
1275      const SCEV *MulOpSCEV = Mul->getOperand(MulOp);
1276      for (unsigned AddOp = 0, e = Ops.size(); AddOp != e; ++AddOp)
1277        if (MulOpSCEV == Ops[AddOp] && !isa<SCEVConstant>(Ops[AddOp])) {
1278          // Fold W + X + (X * Y * Z)  -->  W + (X * ((Y*Z)+1))
1279          const SCEV* InnerMul = Mul->getOperand(MulOp == 0);
1280          if (Mul->getNumOperands() != 2) {
1281            // If the multiply has more than two operands, we must get the
1282            // Y*Z term.
1283            SmallVector<const SCEV*, 4> MulOps(Mul->op_begin(), Mul->op_end());
1284            MulOps.erase(MulOps.begin()+MulOp);
1285            InnerMul = getMulExpr(MulOps);
1286          }
1287          const SCEV* One = getIntegerSCEV(1, Ty);
1288          const SCEV* AddOne = getAddExpr(InnerMul, One);
1289          const SCEV* OuterMul = getMulExpr(AddOne, Ops[AddOp]);
1290          if (Ops.size() == 2) return OuterMul;
1291          if (AddOp < Idx) {
1292            Ops.erase(Ops.begin()+AddOp);
1293            Ops.erase(Ops.begin()+Idx-1);
1294          } else {
1295            Ops.erase(Ops.begin()+Idx);
1296            Ops.erase(Ops.begin()+AddOp-1);
1297          }
1298          Ops.push_back(OuterMul);
1299          return getAddExpr(Ops);
1300        }
1301
1302      // Check this multiply against other multiplies being added together.
1303      for (unsigned OtherMulIdx = Idx+1;
1304           OtherMulIdx < Ops.size() && isa<SCEVMulExpr>(Ops[OtherMulIdx]);
1305           ++OtherMulIdx) {
1306        const SCEVMulExpr *OtherMul = cast<SCEVMulExpr>(Ops[OtherMulIdx]);
1307        // If MulOp occurs in OtherMul, we can fold the two multiplies
1308        // together.
1309        for (unsigned OMulOp = 0, e = OtherMul->getNumOperands();
1310             OMulOp != e; ++OMulOp)
1311          if (OtherMul->getOperand(OMulOp) == MulOpSCEV) {
1312            // Fold X + (A*B*C) + (A*D*E) --> X + (A*(B*C+D*E))
1313            const SCEV* InnerMul1 = Mul->getOperand(MulOp == 0);
1314            if (Mul->getNumOperands() != 2) {
1315              SmallVector<const SCEV *, 4> MulOps(Mul->op_begin(),
1316                                                  Mul->op_end());
1317              MulOps.erase(MulOps.begin()+MulOp);
1318              InnerMul1 = getMulExpr(MulOps);
1319            }
1320            const SCEV* InnerMul2 = OtherMul->getOperand(OMulOp == 0);
1321            if (OtherMul->getNumOperands() != 2) {
1322              SmallVector<const SCEV *, 4> MulOps(OtherMul->op_begin(),
1323                                                  OtherMul->op_end());
1324              MulOps.erase(MulOps.begin()+OMulOp);
1325              InnerMul2 = getMulExpr(MulOps);
1326            }
1327            const SCEV* InnerMulSum = getAddExpr(InnerMul1,InnerMul2);
1328            const SCEV* OuterMul = getMulExpr(MulOpSCEV, InnerMulSum);
1329            if (Ops.size() == 2) return OuterMul;
1330            Ops.erase(Ops.begin()+Idx);
1331            Ops.erase(Ops.begin()+OtherMulIdx-1);
1332            Ops.push_back(OuterMul);
1333            return getAddExpr(Ops);
1334          }
1335      }
1336    }
1337  }
1338
1339  // If there are any add recurrences in the operands list, see if any other
1340  // added values are loop invariant.  If so, we can fold them into the
1341  // recurrence.
1342  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1343    ++Idx;
1344
1345  // Scan over all recurrences, trying to fold loop invariants into them.
1346  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1347    // Scan all of the other operands to this add and add them to the vector if
1348    // they are loop invariant w.r.t. the recurrence.
1349    SmallVector<const SCEV*, 8> LIOps;
1350    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1351    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1352      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1353        LIOps.push_back(Ops[i]);
1354        Ops.erase(Ops.begin()+i);
1355        --i; --e;
1356      }
1357
1358    // If we found some loop invariants, fold them into the recurrence.
1359    if (!LIOps.empty()) {
1360      //  NLI + LI + {Start,+,Step}  -->  NLI + {LI+Start,+,Step}
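          //  For example, if n is invariant in the loop of {0,+,1}, then
          //  x + n + {0,+,1} folds to x + {n,+,1}.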
1361      LIOps.push_back(AddRec->getStart());
1362
1363      SmallVector<const SCEV*, 4> AddRecOps(AddRec->op_begin(),
1364                                           AddRec->op_end());
1365      AddRecOps[0] = getAddExpr(LIOps);
1366
1367      const SCEV* NewRec = getAddRecExpr(AddRecOps, AddRec->getLoop());
1368      // If all of the other operands were loop invariant, we are done.
1369      if (Ops.size() == 1) return NewRec;
1370
1371      // Otherwise, add the folded AddRec to the non-liv parts.
1372      for (unsigned i = 0;; ++i)
1373        if (Ops[i] == AddRec) {
1374          Ops[i] = NewRec;
1375          break;
1376        }
1377      return getAddExpr(Ops);
1378    }
1379
1380    // Okay, if there weren't any loop invariants to be folded, check to see if
1381    // there are multiple AddRecs with the same loop induction variable being
1382    // added together.  If so, we can fold them.
1383    for (unsigned OtherIdx = Idx+1;
1384         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1385      if (OtherIdx != Idx) {
1386        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1387        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1388          // Other + {A,+,B} + {C,+,D}  -->  Other + {A+C,+,B+D}
1389          SmallVector<const SCEV *, 4> NewOps(AddRec->op_begin(),
1390                                              AddRec->op_end());
1391          for (unsigned i = 0, e = OtherAddRec->getNumOperands(); i != e; ++i) {
1392            if (i >= NewOps.size()) {
1393              NewOps.insert(NewOps.end(), OtherAddRec->op_begin()+i,
1394                            OtherAddRec->op_end());
1395              break;
1396            }
1397            NewOps[i] = getAddExpr(NewOps[i], OtherAddRec->getOperand(i));
1398          }
1399          const SCEV* NewAddRec = getAddRecExpr(NewOps, AddRec->getLoop());
1400
1401          if (Ops.size() == 2) return NewAddRec;
1402
1403          Ops.erase(Ops.begin()+Idx);
1404          Ops.erase(Ops.begin()+OtherIdx-1);
1405          Ops.push_back(NewAddRec);
1406          return getAddExpr(Ops);
1407        }
1408      }
1409
1410    // Otherwise we couldn't fold anything into this recurrence.  Move on to
1411    // the next one.
1412  }
1413
1414  // Okay, it looks like we really DO need an add expr.  Check to see if we
1415  // already have one, otherwise create a new one.
1416  FoldingSetNodeID ID;
1417  ID.AddInteger(scAddExpr);
1418  ID.AddInteger(Ops.size());
1419  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1420    ID.AddPointer(Ops[i]);
1421  void *IP = 0;
1422  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1423  SCEV *S = SCEVAllocator.Allocate<SCEVAddExpr>();
1424  new (S) SCEVAddExpr(Ops);
1425  UniqueSCEVs.InsertNode(S, IP);
1426  return S;
1427}
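    // An illustrative usage sketch (not from this file; assumes a
    // ScalarEvolution &SE and two Values A and B of the same integer type):
    //   SmallVector<const SCEV*, 2> Sum;
    //   Sum.push_back(SE.getSCEV(A));
    //   Sum.push_back(SE.getSCEV(B));
    //   const SCEV* S = SE.getAddExpr(Sum);  // canonical form of A + B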
1428
1429
1430/// getMulExpr - Get a canonical multiply expression, or something simpler if
1431/// possible.
1432const SCEV* ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV*> &Ops) {
1433  assert(!Ops.empty() && "Cannot get empty mul!");
1434#ifndef NDEBUG
1435  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1436    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1437           getEffectiveSCEVType(Ops[0]->getType()) &&
1438           "SCEVMulExpr operand types don't match!");
1439#endif
1440
1441  // Sort by complexity; this groups all similar expression types together.
1442  GroupByComplexity(Ops, LI);
1443
1444  // If there are any constants, fold them together.
1445  unsigned Idx = 0;
1446  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1447
1448    // C1*(C2+V) -> C1*C2 + C1*V
1449    if (Ops.size() == 2)
1450      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
1451        if (Add->getNumOperands() == 2 &&
1452            isa<SCEVConstant>(Add->getOperand(0)))
1453          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)),
1454                            getMulExpr(LHSC, Add->getOperand(1)));
1455
1456
1457    ++Idx;
1458    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1459      // We found two constants, fold them together!
1460      ConstantInt *Fold = ConstantInt::get(LHSC->getValue()->getValue() *
1461                                           RHSC->getValue()->getValue());
1462      Ops[0] = getConstant(Fold);
1463      Ops.erase(Ops.begin()+1);  // Erase the folded element
1464      if (Ops.size() == 1) return Ops[0];
1465      LHSC = cast<SCEVConstant>(Ops[0]);
1466    }
1467
1468    // If we are left with a constant one being multiplied, strip it off.
1469    if (cast<SCEVConstant>(Ops[0])->getValue()->equalsInt(1)) {
1470      Ops.erase(Ops.begin());
1471      --Idx;
1472    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
1473      // If we have a multiply of zero, it will always be zero.
1474      return Ops[0];
1475    }
1476  }
1477
1478  // Skip over the add expression until we get to a multiply.
1479  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scMulExpr)
1480    ++Idx;
1481
1482  if (Ops.size() == 1)
1483    return Ops[0];
1484
1485  // If there are mul operands, inline them all into this expression.
1486  if (Idx < Ops.size()) {
1487    bool DeletedMul = false;
1488    while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
1489      // If we have a mul, expand the mul operands onto the end of the operands
1490      // list.
1491      Ops.insert(Ops.end(), Mul->op_begin(), Mul->op_end());
1492      Ops.erase(Ops.begin()+Idx);
1493      DeletedMul = true;
1494    }
1495
1496    // If we deleted at least one mul, we added operands to the end of the list,
1497    // and they are not necessarily sorted.  Recurse to re-sort and re-simplify
1498    // any operands we just acquired.
1499    if (DeletedMul)
1500      return getMulExpr(Ops);
1501  }
1502
1503  // If there are any add recurrences in the operands list, see if any other
1504  // multiplied values are loop invariant.  If so, we can fold them into the
1505  // recurrence.
1506  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scAddRecExpr)
1507    ++Idx;
1508
1509  // Scan over all recurrences, trying to fold loop invariants into them.
1510  for (; Idx < Ops.size() && isa<SCEVAddRecExpr>(Ops[Idx]); ++Idx) {
1511    // Scan all of the other operands to this mul and add them to the vector if
1512    // they are loop invariant w.r.t. the recurrence.
1513    SmallVector<const SCEV*, 8> LIOps;
1514    const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(Ops[Idx]);
1515    for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1516      if (Ops[i]->isLoopInvariant(AddRec->getLoop())) {
1517        LIOps.push_back(Ops[i]);
1518        Ops.erase(Ops.begin()+i);
1519        --i; --e;
1520      }
1521
1522    // If we found some loop invariants, fold them into the recurrence.
1523    if (!LIOps.empty()) {
1524      //  NLI * LI * {Start,+,Step}  -->  NLI * {LI*Start,+,LI*Step}
1525      SmallVector<const SCEV*, 4> NewOps;
1526      NewOps.reserve(AddRec->getNumOperands());
1527      if (LIOps.size() == 1) {
1528        const SCEV *Scale = LIOps[0];
1529        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i)
1530          NewOps.push_back(getMulExpr(Scale, AddRec->getOperand(i)));
1531      } else {
1532        for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
1533          SmallVector<const SCEV*, 4> MulOps(LIOps.begin(), LIOps.end());
1534          MulOps.push_back(AddRec->getOperand(i));
1535          NewOps.push_back(getMulExpr(MulOps));
1536        }
1537      }
1538
1539      const SCEV* NewRec = getAddRecExpr(NewOps, AddRec->getLoop());
1540
1541      // If all of the other operands were loop invariant, we are done.
1542      if (Ops.size() == 1) return NewRec;
1543
1544      // Otherwise, multiply the folded AddRec by the non-liv parts.
1545      for (unsigned i = 0;; ++i)
1546        if (Ops[i] == AddRec) {
1547          Ops[i] = NewRec;
1548          break;
1549        }
1550      return getMulExpr(Ops);
1551    }
1552
1553    // Okay, if there weren't any loop invariants to be folded, check to see if
1554    // there are multiple AddRecs with the same loop induction variable being
1555    // multiplied together.  If so, we can fold them.
1556    for (unsigned OtherIdx = Idx+1;
1557         OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);++OtherIdx)
1558      if (OtherIdx != Idx) {
1559        const SCEVAddRecExpr *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
1560        if (AddRec->getLoop() == OtherAddRec->getLoop()) {
1561          // F * G  -->  {A,+,B} * {C,+,D}  -->  {A*C,+,F*D + G*B + B*D}
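              // To see why: with F(n+1) = F(n) + B and G(n+1) = G(n) + D,
              //   F(n+1)*G(n+1) - F(n)*G(n) = F(n)*D + G(n)*B + B*D,
              // which is exactly the step computed below.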
1562          const SCEVAddRecExpr *F = AddRec, *G = OtherAddRec;
1563          const SCEV* NewStart = getMulExpr(F->getStart(),
1564                                            G->getStart());
1565          const SCEV* B = F->getStepRecurrence(*this);
1566          const SCEV* D = G->getStepRecurrence(*this);
1567          const SCEV* NewStep = getAddExpr(getMulExpr(F, D),
1568                                           getMulExpr(G, B),
1569                                           getMulExpr(B, D));
1570          const SCEV* NewAddRec = getAddRecExpr(NewStart, NewStep,
1571                                                F->getLoop());
1572          if (Ops.size() == 2) return NewAddRec;
1573
1574          Ops.erase(Ops.begin()+Idx);
1575          Ops.erase(Ops.begin()+OtherIdx-1);
1576          Ops.push_back(NewAddRec);
1577          return getMulExpr(Ops);
1578        }
1579      }
1580
1581    // Otherwise we couldn't fold anything into this recurrence.  Move on to
1582    // the next one.
1583  }
1584
1585  // Okay, it looks like we really DO need a mul expr.  Check to see if we
1586  // already have one, otherwise create a new one.
1587  FoldingSetNodeID ID;
1588  ID.AddInteger(scMulExpr);
1589  ID.AddInteger(Ops.size());
1590  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1591    ID.AddPointer(Ops[i]);
1592  void *IP = 0;
1593  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1594  SCEV *S = SCEVAllocator.Allocate<SCEVMulExpr>();
1595  new (S) SCEVMulExpr(Ops);
1596  UniqueSCEVs.InsertNode(S, IP);
1597  return S;
1598}
1599
1600/// getUDivExpr - Get a canonical unsigned division expression, or something
1601/// simpler if possible.
1602const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
1603                                         const SCEV *RHS) {
1604  assert(getEffectiveSCEVType(LHS->getType()) ==
1605         getEffectiveSCEVType(RHS->getType()) &&
1606         "SCEVUDivExpr operand types don't match!");
1607
1608  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
1609    if (RHSC->getValue()->equalsInt(1))
1610      return LHS;                            // X udiv 1 --> X
1611    if (RHSC->isZero())
1612      return getIntegerSCEV(0, LHS->getType()); // value is undefined
1613
1614    // Determine if the division can be folded into the operands of
1615    // the dividend.
1616    // TODO: Generalize this to non-constants by using known-bits information.
1617    const Type *Ty = LHS->getType();
1618    unsigned LZ = RHSC->getValue()->getValue().countLeadingZeros();
1619    unsigned MaxShiftAmt = getTypeSizeInBits(Ty) - LZ;
1620    // For non-power-of-two values, effectively round the value up to the
1621    // nearest power of two.
1622    if (!RHSC->getValue()->getValue().isPowerOf2())
1623      ++MaxShiftAmt;
1624    const IntegerType *ExtTy =
1625      IntegerType::get(getTypeSizeInBits(Ty) + MaxShiftAmt);
1626    // {X,+,N}/C --> {X/C,+,N/C} if safe and N/C can be folded.
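        // For example, {0,+,8}<L> /u 4 becomes {0,+,2}<L> once the
        // zero-extend comparison below proves the recurrence cannot wrap.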
1627    if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS))
1628      if (const SCEVConstant *Step =
1629            dyn_cast<SCEVConstant>(AR->getStepRecurrence(*this)))
1630        if (!Step->getValue()->getValue()
1631              .urem(RHSC->getValue()->getValue()) &&
1632            getZeroExtendExpr(AR, ExtTy) ==
1633            getAddRecExpr(getZeroExtendExpr(AR->getStart(), ExtTy),
1634                          getZeroExtendExpr(Step, ExtTy),
1635                          AR->getLoop())) {
1636          SmallVector<const SCEV*, 4> Operands;
1637          for (unsigned i = 0, e = AR->getNumOperands(); i != e; ++i)
1638            Operands.push_back(getUDivExpr(AR->getOperand(i), RHS));
1639          return getAddRecExpr(Operands, AR->getLoop());
1640        }
1641    // (A*B)/C --> A*(B/C) if safe and B/C can be folded.
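        // For example, (X * 6) /u 2 becomes X * 3, provided the zero-extend
        // comparison below shows that the multiply cannot wrap.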
1642    if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(LHS)) {
1643      SmallVector<const SCEV*, 4> Operands;
1644      for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i)
1645        Operands.push_back(getZeroExtendExpr(M->getOperand(i), ExtTy));
1646      if (getZeroExtendExpr(M, ExtTy) == getMulExpr(Operands))
1647        // Find an operand that's safely divisible.
1648        for (unsigned i = 0, e = M->getNumOperands(); i != e; ++i) {
1649          const SCEV* Op = M->getOperand(i);
1650          const SCEV* Div = getUDivExpr(Op, RHSC);
1651          if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
1652            const SmallVectorImpl<const SCEV*> &MOperands = M->getOperands();
1653            Operands = SmallVector<const SCEV*, 4>(MOperands.begin(),
1654                                                  MOperands.end());
1655            Operands[i] = Div;
1656            return getMulExpr(Operands);
1657          }
1658        }
1659    }
1660    // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
1661    if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(LHS)) {
1662      SmallVector<const SCEV*, 4> Operands;
1663      for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i)
1664        Operands.push_back(getZeroExtendExpr(A->getOperand(i), ExtTy));
1665      if (getZeroExtendExpr(A, ExtTy) == getAddExpr(Operands)) {
1666        Operands.clear();
1667        for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
1668          const SCEV* Op = getUDivExpr(A->getOperand(i), RHS);
1669          if (isa<SCEVUDivExpr>(Op) || getMulExpr(Op, RHS) != A->getOperand(i))
1670            break;
1671          Operands.push_back(Op);
1672        }
1673        if (Operands.size() == A->getNumOperands())
1674          return getAddExpr(Operands);
1675      }
1676    }
1677
1678    // Fold if both operands are constant.
1679    if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(LHS)) {
1680      Constant *LHSCV = LHSC->getValue();
1681      Constant *RHSCV = RHSC->getValue();
1682      return getConstant(cast<ConstantInt>(ConstantExpr::getUDiv(LHSCV,
1683                                                                 RHSCV)));
1684    }
1685  }
1686
1687  FoldingSetNodeID ID;
1688  ID.AddInteger(scUDivExpr);
1689  ID.AddPointer(LHS);
1690  ID.AddPointer(RHS);
1691  void *IP = 0;
1692  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1693  SCEV *S = SCEVAllocator.Allocate<SCEVUDivExpr>();
1694  new (S) SCEVUDivExpr(LHS, RHS);
1695  UniqueSCEVs.InsertNode(S, IP);
1696  return S;
1697}
1698
1699
1700/// getAddRecExpr - Get an add recurrence expression for the specified loop.
1701/// Simplify the expression as much as possible.
1702const SCEV* ScalarEvolution::getAddRecExpr(const SCEV* Start,
1703                                           const SCEV* Step, const Loop *L) {
1704  SmallVector<const SCEV*, 4> Operands;
1705  Operands.push_back(Start);
1706  if (const SCEVAddRecExpr *StepChrec = dyn_cast<SCEVAddRecExpr>(Step))
1707    if (StepChrec->getLoop() == L) {
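          // The step is itself an addrec on the same loop, so fold it in;
          // e.g. getAddRecExpr(X, {A,+,B}<L>, L) yields {X,+,A,+,B}<L>.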
1708      Operands.insert(Operands.end(), StepChrec->op_begin(),
1709                      StepChrec->op_end());
1710      return getAddRecExpr(Operands, L);
1711    }
1712
1713  Operands.push_back(Step);
1714  return getAddRecExpr(Operands, L);
1715}
1716
1717/// getAddRecExpr - Get an add recurrence expression for the specified loop.
1718/// Simplify the expression as much as possible.
1719const SCEV *
1720ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV*> &Operands,
1721                               const Loop *L) {
1722  if (Operands.size() == 1) return Operands[0];
1723#ifndef NDEBUG
1724  for (unsigned i = 1, e = Operands.size(); i != e; ++i)
1725    assert(getEffectiveSCEVType(Operands[i]->getType()) ==
1726           getEffectiveSCEVType(Operands[0]->getType()) &&
1727           "SCEVAddRecExpr operand types don't match!");
1728#endif
1729
1730  if (Operands.back()->isZero()) {
1731    Operands.pop_back();
1732    return getAddRecExpr(Operands, L);             // {X,+,0}  -->  X
1733  }
1734
1735  // Canonicalize nested AddRecs by nesting them in order of loop depth.
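      // For example, with Outer enclosing Inner, {{A,+,B}<Inner>,+,C}<Outer>
      // becomes {{A,+,C}<Outer>,+,B}<Inner>, provided the loop-invariance
      // checks below succeed.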
1736  if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
1737    const Loop* NestedLoop = NestedAR->getLoop();
1738    if (L->getLoopDepth() < NestedLoop->getLoopDepth()) {
1739      SmallVector<const SCEV*, 4> NestedOperands(NestedAR->op_begin(),
1740                                                NestedAR->op_end());
1741      Operands[0] = NestedAR->getStart();
1742      // AddRecs require their operands be loop-invariant with respect to their
1743      // loops. Don't perform this transformation if it would break this
1744      // requirement.
1745      bool AllInvariant = true;
1746      for (unsigned i = 0, e = Operands.size(); i != e; ++i)
1747        if (!Operands[i]->isLoopInvariant(L)) {
1748          AllInvariant = false;
1749          break;
1750        }
1751      if (AllInvariant) {
1752        NestedOperands[0] = getAddRecExpr(Operands, L);
1753        AllInvariant = true;
1754        for (unsigned i = 0, e = NestedOperands.size(); i != e; ++i)
1755          if (!NestedOperands[i]->isLoopInvariant(NestedLoop)) {
1756            AllInvariant = false;
1757            break;
1758          }
1759        if (AllInvariant)
1760          // Ok, both add recurrences are valid after the transformation.
1761          return getAddRecExpr(NestedOperands, NestedLoop);
1762      }
1763      // Reset Operands to its original state.
1764      Operands[0] = NestedAR;
1765    }
1766  }
1767
1768  FoldingSetNodeID ID;
1769  ID.AddInteger(scAddRecExpr);
1770  ID.AddInteger(Operands.size());
1771  for (unsigned i = 0, e = Operands.size(); i != e; ++i)
1772    ID.AddPointer(Operands[i]);
1773  ID.AddPointer(L);
1774  void *IP = 0;
1775  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1776  SCEV *S = SCEVAllocator.Allocate<SCEVAddRecExpr>();
1777  new (S) SCEVAddRecExpr(Operands, L);
1778  UniqueSCEVs.InsertNode(S, IP);
1779  return S;
1780}
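    // An illustrative usage sketch (not from this file; assumes a
    // ScalarEvolution &SE, a SCEVable integer type Ty, and a Loop *L):
    //   const SCEV* IV = SE.getAddRecExpr(SE.getIntegerSCEV(0, Ty),
    //                                     SE.getIntegerSCEV(1, Ty), L);
    // builds the canonical induction variable {0,+,1}<L>.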
1781
1782const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS,
1783                                         const SCEV *RHS) {
1784  SmallVector<const SCEV*, 2> Ops;
1785  Ops.push_back(LHS);
1786  Ops.push_back(RHS);
1787  return getSMaxExpr(Ops);
1788}
1789
1790const SCEV*
1791ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {
1792  assert(!Ops.empty() && "Cannot get empty smax!");
1793  if (Ops.size() == 1) return Ops[0];
1794#ifndef NDEBUG
1795  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1796    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1797           getEffectiveSCEVType(Ops[0]->getType()) &&
1798           "SCEVSMaxExpr operand types don't match!");
1799#endif
1800
1801  // Sort by complexity; this groups all similar expression types together.
1802  GroupByComplexity(Ops, LI);
1803
1804  // If there are any constants, fold them together.
1805  unsigned Idx = 0;
1806  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1807    ++Idx;
1808    assert(Idx < Ops.size());
1809    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1810      // We found two constants, fold them together!
1811      ConstantInt *Fold = ConstantInt::get(
1812                              APIntOps::smax(LHSC->getValue()->getValue(),
1813                                             RHSC->getValue()->getValue()));
1814      Ops[0] = getConstant(Fold);
1815      Ops.erase(Ops.begin()+1);  // Erase the folded element
1816      if (Ops.size() == 1) return Ops[0];
1817      LHSC = cast<SCEVConstant>(Ops[0]);
1818    }
1819
1820    // If we are left with a constant minimum-int, strip it off.
1821    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(true)) {
1822      Ops.erase(Ops.begin());
1823      --Idx;
1824    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(true)) {
1825      // If we have an smax with a constant maximum-int, it will always be
1826      // maximum-int.
1827      return Ops[0];
1828    }
1829  }
1830
1831  if (Ops.size() == 1) return Ops[0];
1832
1833  // Find the first SMax
1834  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scSMaxExpr)
1835    ++Idx;
1836
1837  // Check to see if one of the operands is an SMax. If so, expand its operands
1838  // onto our operand list, and recurse to simplify.
1839  if (Idx < Ops.size()) {
1840    bool DeletedSMax = false;
1841    while (const SCEVSMaxExpr *SMax = dyn_cast<SCEVSMaxExpr>(Ops[Idx])) {
1842      Ops.insert(Ops.end(), SMax->op_begin(), SMax->op_end());
1843      Ops.erase(Ops.begin()+Idx);
1844      DeletedSMax = true;
1845    }
1846
1847    if (DeletedSMax)
1848      return getSMaxExpr(Ops);
1849  }
1850
1851  // Okay, check to see if the same value occurs in the operand list twice.  If
1852  // so, delete one.  Since we sorted the list, these values are required to
1853  // be adjacent.
1854  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1855    if (Ops[i] == Ops[i+1]) {      //  X smax Y smax Y  -->  X smax Y
1856      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1857      --i; --e;
1858    }
1859
1860  if (Ops.size() == 1) return Ops[0];
1861
1862  assert(!Ops.empty() && "Reduced smax down to nothing!");
1863
1864  // Okay, it looks like we really DO need an smax expr.  Check to see if we
1865  // already have one, otherwise create a new one.
1866  FoldingSetNodeID ID;
1867  ID.AddInteger(scSMaxExpr);
1868  ID.AddInteger(Ops.size());
1869  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1870    ID.AddPointer(Ops[i]);
1871  void *IP = 0;
1872  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1873  SCEV *S = SCEVAllocator.Allocate<SCEVSMaxExpr>();
1874  new (S) SCEVSMaxExpr(Ops);
1875  UniqueSCEVs.InsertNode(S, IP);
1876  return S;
1877}
1878
1879const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS,
1880                                         const SCEV *RHS) {
1881  SmallVector<const SCEV*, 2> Ops;
1882  Ops.push_back(LHS);
1883  Ops.push_back(RHS);
1884  return getUMaxExpr(Ops);
1885}
1886
1887const SCEV*
1888ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV*> &Ops) {
1889  assert(!Ops.empty() && "Cannot get empty umax!");
1890  if (Ops.size() == 1) return Ops[0];
1891#ifndef NDEBUG
1892  for (unsigned i = 1, e = Ops.size(); i != e; ++i)
1893    assert(getEffectiveSCEVType(Ops[i]->getType()) ==
1894           getEffectiveSCEVType(Ops[0]->getType()) &&
1895           "SCEVUMaxExpr operand types don't match!");
1896#endif
1897
1898  // Sort by complexity; this groups all similar expression types together.
1899  GroupByComplexity(Ops, LI);
1900
1901  // If there are any constants, fold them together.
1902  unsigned Idx = 0;
1903  if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
1904    ++Idx;
1905    assert(Idx < Ops.size());
1906    while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
1907      // We found two constants, fold them together!
1908      ConstantInt *Fold = ConstantInt::get(
1909                              APIntOps::umax(LHSC->getValue()->getValue(),
1910                                             RHSC->getValue()->getValue()));
1911      Ops[0] = getConstant(Fold);
1912      Ops.erase(Ops.begin()+1);  // Erase the folded element
1913      if (Ops.size() == 1) return Ops[0];
1914      LHSC = cast<SCEVConstant>(Ops[0]);
1915    }
1916
1917    // If we are left with a constant minimum-int, strip it off.
1918    if (cast<SCEVConstant>(Ops[0])->getValue()->isMinValue(false)) {
1919      Ops.erase(Ops.begin());
1920      --Idx;
1921    } else if (cast<SCEVConstant>(Ops[0])->getValue()->isMaxValue(false)) {
1922      // If we have an umax with a constant maximum-int, it will always be
1923      // maximum-int.
1924      return Ops[0];
1925    }
1926  }
1927
1928  if (Ops.size() == 1) return Ops[0];
1929
1930  // Find the first UMax
1931  while (Idx < Ops.size() && Ops[Idx]->getSCEVType() < scUMaxExpr)
1932    ++Idx;
1933
1934  // Check to see if one of the operands is a UMax. If so, expand its operands
1935  // onto our operand list, and recurse to simplify.
1936  if (Idx < Ops.size()) {
1937    bool DeletedUMax = false;
1938    while (const SCEVUMaxExpr *UMax = dyn_cast<SCEVUMaxExpr>(Ops[Idx])) {
1939      Ops.insert(Ops.end(), UMax->op_begin(), UMax->op_end());
1940      Ops.erase(Ops.begin()+Idx);
1941      DeletedUMax = true;
1942    }
1943
1944    if (DeletedUMax)
1945      return getUMaxExpr(Ops);
1946  }
1947
1948  // Okay, check to see if the same value occurs in the operand list twice.  If
1949  // so, delete one.  Since we sorted the list, these values are required to
1950  // be adjacent.
1951  for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
1952    if (Ops[i] == Ops[i+1]) {      //  X umax Y umax Y  -->  X umax Y
1953      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
1954      --i; --e;
1955    }
1956
1957  if (Ops.size() == 1) return Ops[0];
1958
1959  assert(!Ops.empty() && "Reduced umax down to nothing!");
1960
1961  // Okay, it looks like we really DO need a umax expr.  Check to see if we
1962  // already have one, otherwise create a new one.
1963  FoldingSetNodeID ID;
1964  ID.AddInteger(scUMaxExpr);
1965  ID.AddInteger(Ops.size());
1966  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
1967    ID.AddPointer(Ops[i]);
1968  void *IP = 0;
1969  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1970  SCEV *S = SCEVAllocator.Allocate<SCEVUMaxExpr>();
1971  new (S) SCEVUMaxExpr(Ops);
1972  UniqueSCEVs.InsertNode(S, IP);
1973  return S;
1974}
1975
1976const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
1977                                         const SCEV *RHS) {
1978  // ~smax(~x, ~y) == smin(x, y).
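      // This holds because ~z == -1 - z reverses the signed order, so the
      // smax of the complements is the complement of the smaller value.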
1979  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
1980}
1981
1982const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
1983                                         const SCEV *RHS) {
1984  // ~umax(~x, ~y) == umin(x, y)
1985  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
1986}
1987
1988const SCEV* ScalarEvolution::getUnknown(Value *V) {
1989  // Don't attempt to do anything other than create a SCEVUnknown object
1990  // here.  createSCEV only calls getUnknown after checking for all other
1991  // interesting possibilities, and any other code that calls getUnknown
1992  // is doing so in order to hide a value from SCEV canonicalization.
1993
1994  FoldingSetNodeID ID;
1995  ID.AddInteger(scUnknown);
1996  ID.AddPointer(V);
1997  void *IP = 0;
1998  if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
1999  SCEV *S = SCEVAllocator.Allocate<SCEVUnknown>();
2000  new (S) SCEVUnknown(V);
2001  UniqueSCEVs.InsertNode(S, IP);
2002  return S;
2003}
2004
2005//===----------------------------------------------------------------------===//
2006//            Basic SCEV Analysis and PHI Idiom Recognition Code
2007//
2008
2009/// isSCEVable - Test if values of the given type are analyzable within
2010/// the SCEV framework. This primarily includes integer types, and it
2011/// can optionally include pointer types if the ScalarEvolution class
2012/// has access to target-specific information.
2013bool ScalarEvolution::isSCEVable(const Type *Ty) const {
2014  // Integers are always SCEVable.
2015  if (Ty->isInteger())
2016    return true;
2017
2018  // Pointers are SCEVable if TargetData information is available
2019  // to provide pointer size information.
2020  if (isa<PointerType>(Ty))
2021    return TD != NULL;
2022
2023  // Otherwise it's not SCEVable.
2024  return false;
2025}
2026
2027/// getTypeSizeInBits - Return the size in bits of the specified type,
2028/// for which isSCEVable must return true.
2029uint64_t ScalarEvolution::getTypeSizeInBits(const Type *Ty) const {
2030  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2031
2032  // If we have a TargetData, use it!
2033  if (TD)
2034    return TD->getTypeSizeInBits(Ty);
2035
2036  // Otherwise, we support only integer types.
2037  assert(Ty->isInteger() && "isSCEVable permitted a non-SCEVable type!");
2038  return Ty->getPrimitiveSizeInBits();
2039}
2040
2041/// getEffectiveSCEVType - Return a type with the same bitwidth as
2042/// the given type and which represents how SCEV will treat the given
2043/// type, for which isSCEVable must return true. For pointer types,
2044/// this is the pointer-sized integer type.
2045const Type *ScalarEvolution::getEffectiveSCEVType(const Type *Ty) const {
2046  assert(isSCEVable(Ty) && "Type is not SCEVable!");
2047
2048  if (Ty->isInteger())
2049    return Ty;
2050
2051  assert(isa<PointerType>(Ty) && "Unexpected non-pointer non-integer type!");
2052  return TD->getIntPtrType();
2053}
2054
2055const SCEV* ScalarEvolution::getCouldNotCompute() {
2056  return &CouldNotCompute;
2057}
2058
2059/// hasSCEV - Return true if the SCEV for this value has already been
2060/// computed.
2061bool ScalarEvolution::hasSCEV(Value *V) const {
2062  return Scalars.count(V);
2063}
2064
2065/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the
2066/// expression and create a new one.
2067const SCEV* ScalarEvolution::getSCEV(Value *V) {
2068  assert(isSCEVable(V->getType()) && "Value is not SCEVable!");
2069
2070  std::map<SCEVCallbackVH, const SCEV*>::iterator I = Scalars.find(V);
2071  if (I != Scalars.end()) return I->second;
2072  const SCEV* S = createSCEV(V);
2073  Scalars.insert(std::make_pair(SCEVCallbackVH(V, this), S));
2074  return S;
2075}
2076
2077/// getIntegerSCEV - Given a SCEVable type, create a constant for the
2078/// specified signed integer value and return a SCEV for the constant.
2079const SCEV* ScalarEvolution::getIntegerSCEV(int Val, const Type *Ty) {
2080  const IntegerType *ITy = cast<IntegerType>(getEffectiveSCEVType(Ty));
2081  return getConstant(ConstantInt::get(ITy, Val));
2082}
2083
2084/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V
2085///
2086const SCEV* ScalarEvolution::getNegativeSCEV(const SCEV* V) {
2087  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2088    return getConstant(cast<ConstantInt>(ConstantExpr::getNeg(VC->getValue())));
2089
2090  const Type *Ty = V->getType();
2091  Ty = getEffectiveSCEVType(Ty);
2092  return getMulExpr(V, getConstant(ConstantInt::getAllOnesValue(Ty)));
2093}
2094
2095/// getNotSCEV - Return a SCEV corresponding to ~V = -1-V
2096const SCEV* ScalarEvolution::getNotSCEV(const SCEV* V) {
2097  if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
2098    return getConstant(cast<ConstantInt>(ConstantExpr::getNot(VC->getValue())));
2099
2100  const Type *Ty = V->getType();
2101  Ty = getEffectiveSCEVType(Ty);
2102  const SCEV* AllOnes = getConstant(ConstantInt::getAllOnesValue(Ty));
2103  return getMinusSCEV(AllOnes, V);
2104}
2105
2106/// getMinusSCEV - Return a SCEV corresponding to LHS - RHS.
2107///
2108const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS,
2109                                          const SCEV *RHS) {
2110  // X - Y --> X + -Y
2111  return getAddExpr(LHS, getNegativeSCEV(RHS));
2112}
2113
2114/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the
2115/// input value to the specified type.  If the type must be extended, it is zero
2116/// extended.
2117const SCEV*
2118ScalarEvolution::getTruncateOrZeroExtend(const SCEV* V,
2119                                         const Type *Ty) {
2120  const Type *SrcTy = V->getType();
2121  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2122         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2123         "Cannot truncate or zero extend with non-integer arguments!");
2124  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2125    return V;  // No conversion
2126  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2127    return getTruncateExpr(V, Ty);
2128  return getZeroExtendExpr(V, Ty);
2129}
2130
2131/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the
2132/// input value to the specified type.  If the type must be extended, it is sign
2133/// extended.
2134const SCEV*
2135ScalarEvolution::getTruncateOrSignExtend(const SCEV* V,
2136                                         const Type *Ty) {
2137  const Type *SrcTy = V->getType();
2138  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2139         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2140         "Cannot truncate or zero extend with non-integer arguments!");
2141  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2142    return V;  // No conversion
2143  if (getTypeSizeInBits(SrcTy) > getTypeSizeInBits(Ty))
2144    return getTruncateExpr(V, Ty);
2145  return getSignExtendExpr(V, Ty);
2146}
2147
2148/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the
2149/// input value to the specified type.  If the type must be extended, it is zero
2150/// extended.  The conversion must not be narrowing.
2151const SCEV*
2152ScalarEvolution::getNoopOrZeroExtend(const SCEV* V, const Type *Ty) {
2153  const Type *SrcTy = V->getType();
2154  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2155         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2156         "Cannot noop or zero extend with non-integer arguments!");
2157  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2158         "getNoopOrZeroExtend cannot truncate!");
2159  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2160    return V;  // No conversion
2161  return getZeroExtendExpr(V, Ty);
2162}
2163
2164/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the
2165/// input value to the specified type.  If the type must be extended, it is sign
2166/// extended.  The conversion must not be narrowing.
2167const SCEV*
2168ScalarEvolution::getNoopOrSignExtend(const SCEV* V, const Type *Ty) {
2169  const Type *SrcTy = V->getType();
2170  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2171         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2172         "Cannot noop or sign extend with non-integer arguments!");
2173  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2174         "getNoopOrSignExtend cannot truncate!");
2175  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2176    return V;  // No conversion
2177  return getSignExtendExpr(V, Ty);
2178}
2179
2180/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of
2181/// the input value to the specified type. If the type must be extended,
2182/// it is extended with unspecified bits. The conversion must not be
2183/// narrowing.
2184const SCEV*
2185ScalarEvolution::getNoopOrAnyExtend(const SCEV* V, const Type *Ty) {
2186  const Type *SrcTy = V->getType();
2187  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2188         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2189         "Cannot noop or any extend with non-integer arguments!");
2190  assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
2191         "getNoopOrAnyExtend cannot truncate!");
2192  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2193    return V;  // No conversion
2194  return getAnyExtendExpr(V, Ty);
2195}
2196
2197/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the
2198/// input value to the specified type.  The conversion must not be widening.
2199const SCEV*
2200ScalarEvolution::getTruncateOrNoop(const SCEV* V, const Type *Ty) {
2201  const Type *SrcTy = V->getType();
2202  assert((SrcTy->isInteger() || (TD && isa<PointerType>(SrcTy))) &&
2203         (Ty->isInteger() || (TD && isa<PointerType>(Ty))) &&
2204         "Cannot truncate or noop with non-integer arguments!");
2205  assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
2206         "getTruncateOrNoop cannot extend!");
2207  if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
2208    return V;  // No conversion
2209  return getTruncateExpr(V, Ty);
2210}
2211
2212/// getUMaxFromMismatchedTypes - Promote the operands to the wider of
2213/// the types using zero-extension, and then perform a umax operation
2214/// with them.
2215const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
2216                                                        const SCEV *RHS) {
2217  const SCEV* PromotedLHS = LHS;
2218  const SCEV* PromotedRHS = RHS;
2219
2220  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2221    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2222  else
2223    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2224
2225  return getUMaxExpr(PromotedLHS, PromotedRHS);
2226}
2227
2228/// getUMinFromMismatchedTypes - Promote the operands to the wider of
2229/// the types using zero-extension, and then perform a umin operation
2230/// with them.
2231const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
2232                                                        const SCEV *RHS) {
2233  const SCEV* PromotedLHS = LHS;
2234  const SCEV* PromotedRHS = RHS;
2235
2236  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
2237    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
2238  else
2239    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
2240
2241  return getUMinExpr(PromotedLHS, PromotedRHS);
2242}
2243
2244/// ReplaceSymbolicValueWithConcrete - This looks up the computed SCEV value for
2245/// the specified instruction and replaces any references to the symbolic value
2246/// SymName with the specified value.  This is used during PHI resolution.
2247void
2248ScalarEvolution::ReplaceSymbolicValueWithConcrete(Instruction *I,
2249                                                  const SCEV *SymName,
2250                                                  const SCEV *NewVal) {
2251  std::map<SCEVCallbackVH, const SCEV*>::iterator SI =
2252    Scalars.find(SCEVCallbackVH(I, this));
2253  if (SI == Scalars.end()) return;
2254
2255  const SCEV* NV =
2256    SI->second->replaceSymbolicValuesWithConcrete(SymName, NewVal, *this);
2257  if (NV == SI->second) return;  // No change.
2258
2259  SI->second = NV;       // Update the scalars map!
2260
2261  // Any instruction values that use this instruction might also need to be
2262  // updated!
2263  for (Value::use_iterator UI = I->use_begin(), E = I->use_end();
2264       UI != E; ++UI)
2265    ReplaceSymbolicValueWithConcrete(cast<Instruction>(*UI), SymName, NewVal);
2266}
2267
2268/// createNodeForPHI - PHI nodes have two cases.  Either the PHI node exists in
2269/// a loop header, making it a potential recurrence, or it doesn't.
2270///
2271const SCEV* ScalarEvolution::createNodeForPHI(PHINode *PN) {
2272  if (PN->getNumIncomingValues() == 2)  // The loops have been canonicalized.
2273    if (const Loop *L = LI->getLoopFor(PN->getParent()))
2274      if (L->getHeader() == PN->getParent()) {
2275        // If it lives in the loop header, it has two incoming values, one
2276        // from outside the loop, and one from inside.
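            // If incoming block 0 is inside the loop, it is the back edge,
            // so the entry value arrives on edge 1 (and vice versa).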
2277        unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
2278        unsigned BackEdge     = IncomingEdge^1;
2279
2280        // While we are analyzing this PHI node, handle its value symbolically.
2281        const SCEV* SymbolicName = getUnknown(PN);
2282        assert(Scalars.find(PN) == Scalars.end() &&
2283               "PHI node already processed?");
2284        Scalars.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName));
2285
2286        // Using this symbolic name for the PHI, analyze the value coming around
2287        // the back-edge.
2288        const SCEV* BEValue = getSCEV(PN->getIncomingValue(BackEdge));
2289
2290        // NOTE: If BEValue is loop invariant, we know that the PHI node just
2291        // has a special value for the first iteration of the loop.
2292
2293        // If the value coming around the backedge is an add with the symbolic
2294        // value we just inserted, then we found a simple induction variable!
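            // For example, given
            //   %i = phi i32 [ 0, %preheader ], [ %i.next, %latch ]
            //   %i.next = add i32 %i, 1
            // BEValue is SymbolicName + 1, Accum is 1, and the PHI becomes
            // the recurrence {0,+,1}<L>.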
2295        if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(BEValue)) {
2296          // If there is a single occurrence of the symbolic value, replace it
2297          // with a recurrence.
2298          unsigned FoundIndex = Add->getNumOperands();
2299          for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2300            if (Add->getOperand(i) == SymbolicName)
2301              if (FoundIndex == e) {
2302                FoundIndex = i;
2303                break;
2304              }
2305
2306          if (FoundIndex != Add->getNumOperands()) {
2307            // Create an add with everything but the specified operand.
2308            SmallVector<const SCEV*, 8> Ops;
2309            for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
2310              if (i != FoundIndex)
2311                Ops.push_back(Add->getOperand(i));
2312            const SCEV* Accum = getAddExpr(Ops);
2313
2314            // This is not a valid addrec if the step amount is varying each
2315            // loop iteration, but is not itself an addrec in this loop.
2316            if (Accum->isLoopInvariant(L) ||
2317                (isa<SCEVAddRecExpr>(Accum) &&
2318                 cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
2319              const SCEV *StartVal =
2320                getSCEV(PN->getIncomingValue(IncomingEdge));
2321              const SCEV *PHISCEV =
2322                getAddRecExpr(StartVal, Accum, L);
2323
2324              // Okay, for the entire analysis of this edge we assumed the PHI
2325              // to be symbolic.  We now need to go back and update all of the
2326              // entries for the scalars that use the PHI (except for the PHI
2327              // itself) to use the new analyzed value instead of the "symbolic"
2328              // value.
2329              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
2330              return PHISCEV;
2331            }
2332          }
2333        } else if (const SCEVAddRecExpr *AddRec =
2334                     dyn_cast<SCEVAddRecExpr>(BEValue)) {
2335          // Otherwise, this could be a loop like this:
2336          //     i = 0;  for (j = 1; ..; ++j) { ....  i = j; }
2337          // In this case, j = {1,+,1}  and BEValue is j.
2338          // Because the other in-value of i (0) fits the evolution of
2339          // BEValue, i really is an addrec evolution.
2340          if (AddRec->getLoop() == L && AddRec->isAffine()) {
2341            const SCEV* StartVal = getSCEV(PN->getIncomingValue(IncomingEdge));
2342
2343            // If StartVal = j.start - j.stride, we can use StartVal as the
2344            // starting value of the addrec evolution.
2345            if (StartVal == getMinusSCEV(AddRec->getOperand(0),
2346                                         AddRec->getOperand(1))) {
2347              const SCEV* PHISCEV =
2348                 getAddRecExpr(StartVal, AddRec->getOperand(1), L);
2349
2350              // Okay, for the entire analysis of this edge we assumed the PHI
2351              // to be symbolic.  We now need to go back and update all of the
2352              // entries for the scalars that use the PHI (except for the PHI
2353              // itself) to use the new analyzed value instead of the "symbolic"
2354              // value.
2355              ReplaceSymbolicValueWithConcrete(PN, SymbolicName, PHISCEV);
2356              return PHISCEV;
2357            }
2358          }
2359        }
2360
2361        return SymbolicName;
2362      }
2363
2364  // If it's not a loop phi, we can't handle it yet.
2365  return getUnknown(PN);
2366}
2367
2368/// createNodeForGEP - Expand GEP instructions into add and multiply
2369/// operations. This allows them to be analyzed by regular SCEV code.
2370///
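    /// For example, a GEP such as
    ///   getelementptr %struct.S* %p, i64 %i, i32 1
    /// (illustrative names) is modeled as
    ///   getSCEV(%p) + %i * sizeof(%struct.S) + (byte offset of field 1),
    /// with all arithmetic done in the pointer-sized integer type.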
2371const SCEV* ScalarEvolution::createNodeForGEP(User *GEP) {
2372
2373  const Type *IntPtrTy = TD->getIntPtrType();
2374  Value *Base = GEP->getOperand(0);
2375  // Don't attempt to analyze GEPs over unsized objects.
2376  if (!cast<PointerType>(Base->getType())->getElementType()->isSized())
2377    return getUnknown(GEP);
2378  const SCEV* TotalOffset = getIntegerSCEV(0, IntPtrTy);
2379  gep_type_iterator GTI = gep_type_begin(GEP);
2380  for (GetElementPtrInst::op_iterator I = next(GEP->op_begin()),
2381                                      E = GEP->op_end();
2382       I != E; ++I) {
2383    Value *Index = *I;
2384    // Compute the (potentially symbolic) offset in bytes for this index.
2385    if (const StructType *STy = dyn_cast<StructType>(*GTI++)) {
2386      // For a struct, add the member offset.
2387      const StructLayout &SL = *TD->getStructLayout(STy);
2388      unsigned FieldNo = cast<ConstantInt>(Index)->getZExtValue();
2389      uint64_t Offset = SL.getElementOffset(FieldNo);
2390      TotalOffset = getAddExpr(TotalOffset,
2391                               getIntegerSCEV(Offset, IntPtrTy));
2392    } else {
2393      // For an array, add the element offset, explicitly scaled.
2394      const SCEV* LocalOffset = getSCEV(Index);
2395      if (!isa<PointerType>(LocalOffset->getType()))
2396        // Getelementptr indices are signed.
2397        LocalOffset = getTruncateOrSignExtend(LocalOffset,
2398                                              IntPtrTy);
2399      LocalOffset =
2400        getMulExpr(LocalOffset,
2401                   getIntegerSCEV(TD->getTypeAllocSize(*GTI),
2402                                  IntPtrTy));
2403      TotalOffset = getAddExpr(TotalOffset, LocalOffset);
2404    }
2405  }
2406  return getAddExpr(getSCEV(Base), TotalOffset);
2407}
2408
2409/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is
2410/// guaranteed to end in (at every loop iteration).  It is, at the same time,
2411/// the minimum number of times S is divisible by 2.  For example, given {4,+,8}
2412/// it returns 2.  If S is guaranteed to be 0, it returns the bitwidth of S.
2413uint32_t
2414ScalarEvolution::GetMinTrailingZeros(const SCEV* S) {
2415  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2416    return C->getValue()->getValue().countTrailingZeros();
2417
2418  if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
2419    return std::min(GetMinTrailingZeros(T->getOperand()),
2420                    (uint32_t)getTypeSizeInBits(T->getType()));
2421
2422  if (const SCEVZeroExtendExpr *E = dyn_cast<SCEVZeroExtendExpr>(S)) {
2423    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
2424    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
2425             getTypeSizeInBits(E->getType()) : OpRes;
2426  }
2427
2428  if (const SCEVSignExtendExpr *E = dyn_cast<SCEVSignExtendExpr>(S)) {
2429    uint32_t OpRes = GetMinTrailingZeros(E->getOperand());
2430    return OpRes == getTypeSizeInBits(E->getOperand()->getType()) ?
2431             getTypeSizeInBits(E->getType()) : OpRes;
2432  }
2433
2434  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
2435    // The result is the min of all operands' results.
2436    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
2437    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2438      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
2439    return MinOpRes;
2440  }
2441
2442  if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S)) {
2443    // The result is the sum of all operands' results.
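        // For example, a multiple of 4 times a multiple of 2 is a multiple
        // of 8: 2 + 1 == 3 guaranteed trailing zero bits.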
2444    uint32_t SumOpRes = GetMinTrailingZeros(M->getOperand(0));
2445    uint32_t BitWidth = getTypeSizeInBits(M->getType());
2446    for (unsigned i = 1, e = M->getNumOperands();
2447         SumOpRes != BitWidth && i != e; ++i)
2448      SumOpRes = std::min(SumOpRes + GetMinTrailingZeros(M->getOperand(i)),
2449                          BitWidth);
2450    return SumOpRes;
2451  }
2452
2453  if (const SCEVAddRecExpr *A = dyn_cast<SCEVAddRecExpr>(S)) {
2454    // The result is the min of all operands' results.
2455    uint32_t MinOpRes = GetMinTrailingZeros(A->getOperand(0));
2456    for (unsigned i = 1, e = A->getNumOperands(); MinOpRes && i != e; ++i)
2457      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(A->getOperand(i)));
2458    return MinOpRes;
2459  }
2460
2461  if (const SCEVSMaxExpr *M = dyn_cast<SCEVSMaxExpr>(S)) {
2462    // The result is the min of all operands' results.
2463    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
2464    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2465      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
2466    return MinOpRes;
2467  }
2468
2469  if (const SCEVUMaxExpr *M = dyn_cast<SCEVUMaxExpr>(S)) {
2470    // The result is the min of all operands' results.
2471    uint32_t MinOpRes = GetMinTrailingZeros(M->getOperand(0));
2472    for (unsigned i = 1, e = M->getNumOperands(); MinOpRes && i != e; ++i)
2473      MinOpRes = std::min(MinOpRes, GetMinTrailingZeros(M->getOperand(i)));
2474    return MinOpRes;
2475  }
2476
2477  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
2478    // For a SCEVUnknown, ask ValueTracking.
2479    unsigned BitWidth = getTypeSizeInBits(U->getType());
2480    APInt Mask = APInt::getAllOnesValue(BitWidth);
2481    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
2482    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones);
2483    return Zeros.countTrailingOnes();
2484  }
2485
2486  // SCEVUDivExpr
2487  return 0;
2488}
2489
2490uint32_t
2491ScalarEvolution::GetMinLeadingZeros(const SCEV* S) {
2492  // TODO: Handle other SCEV expression types here.
2493
2494  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
2495    return C->getValue()->getValue().countLeadingZeros();
2496
2497  if (const SCEVZeroExtendExpr *C = dyn_cast<SCEVZeroExtendExpr>(S)) {
2498    // A zero-extension cast adds zero bits.
2499    return GetMinLeadingZeros(C->getOperand()) +
2500           (getTypeSizeInBits(C->getType()) -
2501            getTypeSizeInBits(C->getOperand()->getType()));
2502  }
2503
2504  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
2505    // For a SCEVUnknown, ask ValueTracking.
2506    unsigned BitWidth = getTypeSizeInBits(U->getType());
2507    APInt Mask = APInt::getAllOnesValue(BitWidth);
2508    APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
2509    ComputeMaskedBits(U->getValue(), Mask, Zeros, Ones, TD);
2510    return Zeros.countLeadingOnes();
2511  }
2512
2513  return 1;
2514}
2515
2516uint32_t
2517ScalarEvolution::GetMinSignBits(const SCEV* S) {
2518  // TODO: Handle other SCEV expression types here.
2519
2520  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
2521    const APInt &A = C->getValue()->getValue();
2522    return A.isNegative() ? A.countLeadingOnes() :
2523                            A.countLeadingZeros();
2524  }
2525
2526  if (const SCEVSignExtendExpr *C = dyn_cast<SCEVSignExtendExpr>(S)) {
2527    // A sign-extension cast adds sign bits.
2528    return GetMinSignBits(C->getOperand()) +
2529           (getTypeSizeInBits(C->getType()) -
2530            getTypeSizeInBits(C->getOperand()->getType()));
2531  }
2532
2533  if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(S)) {
2534    unsigned BitWidth = getTypeSizeInBits(A->getType());
2535
2536    // Special case decrementing a value (ADD X, -1):
2537    if (const SCEVConstant *CRHS = dyn_cast<SCEVConstant>(A->getOperand(0)))
2538      if (CRHS->isAllOnesValue()) {
2539        SmallVector<const SCEV *, 4> OtherOps(A->op_begin() + 1, A->op_end());
2540        const SCEV *OtherOpsAdd = getAddExpr(OtherOps);
2541        unsigned LZ = GetMinLeadingZeros(OtherOpsAdd);
2542
2543        // If the input is known to be 0 or 1, the output is 0/-1, which is all
2544        // sign bits set.
2545        if (LZ == BitWidth - 1)
2546          return BitWidth;
2547
2548        // If we are subtracting one from a positive number, there is no carry
2549        // out of the result.
2550        if (LZ > 0)
2551          return GetMinSignBits(OtherOpsAdd);
2552      }
2553
2554    // Add can have at most one carry bit.  Thus we know that the output
2555    // is, at worst, one more bit than the inputs.
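        // For example, adding two i8 values that each have at least 3 sign
        // bits yields a sum with at least 2 sign bits.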
2556    unsigned Min = BitWidth;
2557    for (unsigned i = 0, e = A->getNumOperands(); i != e; ++i) {
2558      unsigned N = GetMinSignBits(A->getOperand(i));
2559      Min = std::min(Min, N) - 1;
2560      if (Min == 0) return 1;
2561    }
2562    return Min;  // Min >= 1 here; the loop returns early when it hits 0.
2563  }
2564
2565  if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
2566    // For a SCEVUnknown, ask ValueTracking.
2567    return ComputeNumSignBits(U->getValue(), TD);
2568  }
2569
2570  return 1;
2571}
2572
2573/// createSCEV - We know that there is no SCEV for the specified value.
2574/// Analyze the expression.
2575///
2576const SCEV* ScalarEvolution::createSCEV(Value *V) {
2577  if (!isSCEVable(V->getType()))
2578    return getUnknown(V);
2579
2580  unsigned Opcode = Instruction::UserOp1;
2581  if (Instruction *I = dyn_cast<Instruction>(V))
2582    Opcode = I->getOpcode();
2583  else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V))
2584    Opcode = CE->getOpcode();
2585  else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
2586    return getConstant(CI);
2587  else if (isa<ConstantPointerNull>(V))
2588    return getIntegerSCEV(0, V->getType());
2589  else if (isa<UndefValue>(V))
2590    return getIntegerSCEV(0, V->getType());
2591  else
2592    return getUnknown(V);
2593
2594  User *U = cast<User>(V);
2595  switch (Opcode) {
2596  case Instruction::Add:
2597    return getAddExpr(getSCEV(U->getOperand(0)),
2598                      getSCEV(U->getOperand(1)));
2599  case Instruction::Mul:
2600    return getMulExpr(getSCEV(U->getOperand(0)),
2601                      getSCEV(U->getOperand(1)));
2602  case Instruction::UDiv:
2603    return getUDivExpr(getSCEV(U->getOperand(0)),
2604                       getSCEV(U->getOperand(1)));
2605  case Instruction::Sub:
2606    return getMinusSCEV(getSCEV(U->getOperand(0)),
2607                        getSCEV(U->getOperand(1)));
2608  case Instruction::And:
2609    // For an expression like x&255 that merely masks off the high bits,
2610    // use zext(trunc(x)) as the SCEV expression.
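        // For example, for an i32 value %x, %x & 255 is modeled as
        // (zext i8 (trunc i32 %x to i8) to i32).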
2611    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2612      if (CI->isNullValue())
2613        return getSCEV(U->getOperand(1));   // X & 0  -->  0
2614      if (CI->isAllOnesValue())
2615        return getSCEV(U->getOperand(0));   // X & -1  -->  X
2616      const APInt &A = CI->getValue();
2617
2618      // Instcombine's ShrinkDemandedConstant may strip bits out of
2619      // constants, obscuring what would otherwise be a low-bits mask.
2620      // Use ComputeMaskedBits to compute what ShrinkDemandedConstant
2621      // knew about to reconstruct a low-bits mask value.
2622      unsigned LZ = A.countLeadingZeros();
2623      unsigned BitWidth = A.getBitWidth();
2624      APInt AllOnes = APInt::getAllOnesValue(BitWidth);
2625      APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
2626      ComputeMaskedBits(U->getOperand(0), AllOnes, KnownZero, KnownOne, TD);
2627
2628      APInt EffectiveMask = APInt::getLowBitsSet(BitWidth, BitWidth - LZ);
2629
2630      if (LZ != 0 && !((~A & ~KnownZero) & EffectiveMask))
2631        return
2632          getZeroExtendExpr(getTruncateExpr(getSCEV(U->getOperand(0)),
2633                                            IntegerType::get(BitWidth - LZ)),
2634                            U->getType());
2635    }
2636    break;
2637
2638  case Instruction::Or:
2639    // If the RHS of the Or is a constant, we may have something like:
2640    // X*4+1 which got turned into X*4|1.  Handle this as an Add so loop
2641    // optimizations will transparently handle this case.
2642    //
2643    // In order for this transformation to be safe, the LHS must be of the
2644    // form X*(2^n) and the Or constant must be less than 2^n.
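    // For example, "or i32 %t, 1" where %t is known to have at least two
    // trailing zero bits (a multiple of 4) is treated as %t + 1.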
2645    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2646      const SCEV* LHS = getSCEV(U->getOperand(0));
2647      const APInt &CIVal = CI->getValue();
2648      if (GetMinTrailingZeros(LHS) >=
2649          (CIVal.getBitWidth() - CIVal.countLeadingZeros()))
2650        return getAddExpr(LHS, getSCEV(U->getOperand(1)));
2651    }
2652    break;
2653  case Instruction::Xor:
2654    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) {
2655      // If the RHS of the xor is a signbit, then this is just an add.
2656      // Instcombine turns add of signbit into xor as a strength reduction step.
2657      if (CI->getValue().isSignBit())
2658        return getAddExpr(getSCEV(U->getOperand(0)),
2659                          getSCEV(U->getOperand(1)));
2660
2661      // If the RHS of xor is -1, then this is a not operation.
2662      if (CI->isAllOnesValue())
2663        return getNotSCEV(getSCEV(U->getOperand(0)));
2664
2665      // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask.
2666      // This is a variant of the check for xor with -1, and it handles
2667      // the case where instcombine has trimmed non-demanded bits out
2668      // of an xor with -1.
2669      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0)))
2670        if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1)))
2671          if (BO->getOpcode() == Instruction::And &&
2672              LCI->getValue() == CI->getValue())
2673            if (const SCEVZeroExtendExpr *Z =
2674                  dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) {
2675              const Type *UTy = U->getType();
2676              const SCEV* Z0 = Z->getOperand();
2677              const Type *Z0Ty = Z0->getType();
2678              unsigned Z0TySize = getTypeSizeInBits(Z0Ty);
2679
2680              // If C is a low-bits mask, the zero extend is serving to
2681              // mask off the high bits. Complement the operand and
2682              // re-apply the zext.
2683              if (APIntOps::isMask(Z0TySize, CI->getValue()))
2684                return getZeroExtendExpr(getNotSCEV(Z0), UTy);
2685
2686              // If C is a single bit, it may be in the sign-bit position
2687              // before the zero-extend. In this case, represent the xor
2688              // using an add, which is equivalent, and re-apply the zext.
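              // For example, for an i8 %z zero-extended to i32, an xor
              // with 128 flips what was the sign bit of %z, and in i8
              // adding 128 is equivalent, so the result is zext(%z + 128).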
2689              APInt Trunc = APInt(CI->getValue()).trunc(Z0TySize);
2690              if (APInt(Trunc).zext(getTypeSizeInBits(UTy)) == CI->getValue() &&
2691                  Trunc.isSignBit())
2692                return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)),
2693                                         UTy);
2694            }
2695    }
2696    break;
2697
2698  case Instruction::Shl:
2699    // Turn shift left of a constant amount into a multiply.
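    // For example, "shl i32 %x, 3" becomes the SCEV %x * 8.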
2700    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
2701      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
2702      Constant *X = ConstantInt::get(
2703        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
2704      return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X));
2705    }
2706    break;
2707
2708  case Instruction::LShr:
2709    // Turn logical shift right of a constant into an unsigned divide.
2710    if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) {
2711      uint32_t BitWidth = cast<IntegerType>(V->getType())->getBitWidth();
2712      Constant *X = ConstantInt::get(
2713        APInt(BitWidth, 1).shl(SA->getLimitedValue(BitWidth)));
2714      return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X));
2715    }
2716    break;
2717
2718  case Instruction::AShr:
2719    // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
2720    if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1)))
2721      if (Instruction *L = dyn_cast<Instruction>(U->getOperand(0)))
2722        if (L->getOpcode() == Instruction::Shl &&
2723            L->getOperand(1) == U->getOperand(1)) {
2724          unsigned BitWidth = getTypeSizeInBits(U->getType());
2725          uint64_t Amt = BitWidth - CI->getZExtValue();
2726          if (Amt == BitWidth)
2727            return getSCEV(L->getOperand(0));       // shift by zero --> noop
2728          if (Amt > BitWidth)
2729            return getIntegerSCEV(0, U->getType()); // value is undefined
2730          return
2731            getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)),
2732                                                      IntegerType::get(Amt)),
2733                                 U->getType());
2734        }
2735    break;
2736
2737  case Instruction::Trunc:
2738    return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType());
2739
2740  case Instruction::ZExt:
2741    return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
2742
2743  case Instruction::SExt:
2744    return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
2745
2746  case Instruction::BitCast:
2747    // BitCasts are no-op casts so we just eliminate the cast.
2748    if (isSCEVable(U->getType()) && isSCEVable(U->getOperand(0)->getType()))
2749      return getSCEV(U->getOperand(0));
2750    break;
2751
2752  case Instruction::IntToPtr:
2753    if (!TD) break; // Without TD we can't analyze pointers.
2754    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
2755                                   TD->getIntPtrType());
2756
2757  case Instruction::PtrToInt:
2758    if (!TD) break; // Without TD we can't analyze pointers.
2759    return getTruncateOrZeroExtend(getSCEV(U->getOperand(0)),
2760                                   U->getType());
2761
2762  case Instruction::GetElementPtr:
2763    if (!TD) break; // Without TD we can't analyze pointers.
2764    return createNodeForGEP(U);
2765
2766  case Instruction::PHI:
2767    return createNodeForPHI(cast<PHINode>(U));
2768
2769  case Instruction::Select:
2770    // This could be a smax or umax that was lowered earlier.
2771    // Try to recover it.
2772    if (ICmpInst *ICI = dyn_cast<ICmpInst>(U->getOperand(0))) {
2773      Value *LHS = ICI->getOperand(0);
2774      Value *RHS = ICI->getOperand(1);
2775      switch (ICI->getPredicate()) {
2776      case ICmpInst::ICMP_SLT:
2777      case ICmpInst::ICMP_SLE:
2778        std::swap(LHS, RHS);
2779        // fall through
2780      case ICmpInst::ICMP_SGT:
2781      case ICmpInst::ICMP_SGE:
2782        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2783          return getSMaxExpr(getSCEV(LHS), getSCEV(RHS));
2784        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2785          return getSMinExpr(getSCEV(LHS), getSCEV(RHS));
2786        break;
2787      case ICmpInst::ICMP_ULT:
2788      case ICmpInst::ICMP_ULE:
2789        std::swap(LHS, RHS);
2790        // fall through
2791      case ICmpInst::ICMP_UGT:
2792      case ICmpInst::ICMP_UGE:
2793        if (LHS == U->getOperand(1) && RHS == U->getOperand(2))
2794          return getUMaxExpr(getSCEV(LHS), getSCEV(RHS));
2795        else if (LHS == U->getOperand(2) && RHS == U->getOperand(1))
2796          return getUMinExpr(getSCEV(LHS), getSCEV(RHS));
2797        break;
2798      case ICmpInst::ICMP_NE:
2799        // n != 0 ? n : 1  ->  umax(n, 1)
2800        if (LHS == U->getOperand(1) &&
2801            isa<ConstantInt>(U->getOperand(2)) &&
2802            cast<ConstantInt>(U->getOperand(2))->isOne() &&
2803            isa<ConstantInt>(RHS) &&
2804            cast<ConstantInt>(RHS)->isZero())
2805          return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(2)));
2806        break;
2807      case ICmpInst::ICMP_EQ:
2808        // n == 0 ? 1 : n  ->  umax(n, 1)
2809        if (LHS == U->getOperand(2) &&
2810            isa<ConstantInt>(U->getOperand(1)) &&
2811            cast<ConstantInt>(U->getOperand(1))->isOne() &&
2812            isa<ConstantInt>(RHS) &&
2813            cast<ConstantInt>(RHS)->isZero())
2814          return getUMaxExpr(getSCEV(LHS), getSCEV(U->getOperand(1)));
2815        break;
2816      default:
2817        break;
2818      }
2819    }
2820
2821  default: // We cannot analyze this expression.
2822    break;
2823  }
2824
2825  return getUnknown(V);
2826}
2827
2828
2829
2830//===----------------------------------------------------------------------===//
2831//                   Iteration Count Computation Code
2832//
2833
2834/// getBackedgeTakenCount - If the specified loop has a predictable
2835/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute
2836/// object. The backedge-taken count is the number of times the loop header
2837/// will be branched to from within the loop. This is one less than the
2838/// trip count of the loop, since it doesn't count the first iteration,
2839/// when the header is branched to from outside the loop.
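/// For example, in "while (i != n) ++i" with i starting at zero, the
/// backedge-taken count is n and the trip count is n+1.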
2840///
2841/// Note that it is not valid to call this method on a loop without a
2842/// loop-invariant backedge-taken count (see
2843/// hasLoopInvariantBackedgeTakenCount).
2844///
2845const SCEV* ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
2846  return getBackedgeTakenInfo(L).Exact;
2847}
2848
2849/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except
2850/// return the least SCEV value that is known never to be less than the
2851/// actual backedge taken count.
2852const SCEV* ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {
2853  return getBackedgeTakenInfo(L).Max;
2854}
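
// A minimal usage sketch (disabled, illustrative only), assuming a
// ScalarEvolution &SE and a Loop *L obtained from LoopInfo:
#if 0
  if (SE.hasLoopInvariantBackedgeTakenCount(L)) {
    const SCEV *BEC = SE.getBackedgeTakenCount(L);    // exact count
    const SCEV *Max = SE.getMaxBackedgeTakenCount(L); // conservative bound
  }
#endif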
2855
2856const ScalarEvolution::BackedgeTakenInfo &
2857ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
2858  // Initially insert a CouldNotCompute for this loop. If the insertion
2859  // succeeds, proceed to actually compute a backedge-taken count and
2860  // update the value. The temporary CouldNotCompute value tells SCEV
2861  // code elsewhere that it shouldn't attempt to request a new
2862  // backedge-taken count, which could result in infinite recursion.
2863  std::pair<std::map<const Loop*, BackedgeTakenInfo>::iterator, bool> Pair =
2864    BackedgeTakenCounts.insert(std::make_pair(L, getCouldNotCompute()));
2865  if (Pair.second) {
2866    BackedgeTakenInfo ItCount = ComputeBackedgeTakenCount(L);
2867    if (ItCount.Exact != getCouldNotCompute()) {
2868      assert(ItCount.Exact->isLoopInvariant(L) &&
2869             ItCount.Max->isLoopInvariant(L) &&
2870             "Computed trip count isn't loop invariant for loop!");
2871      ++NumTripCountsComputed;
2872
2873      // Update the value in the map.
2874      Pair.first->second = ItCount;
2875    } else {
2876      if (ItCount.Max != getCouldNotCompute())
2877        // Update the value in the map.
2878        Pair.first->second = ItCount;
2879      if (isa<PHINode>(L->getHeader()->begin()))
2880        // Only count loops that have phi nodes as not being computable.
2881        ++NumTripCountsNotComputed;
2882    }
2883
2884    // Now that we know more about the trip count for this loop, forget any
2885    // existing SCEV values for PHI nodes in this loop since they are only
2886    // conservative estimates made without the benefit
2887    // of trip count information.
2888    if (ItCount.hasAnyInfo())
2889      forgetLoopPHIs(L);
2890  }
2891  return Pair.first->second;
2892}
2893
2894/// forgetLoopBackedgeTakenCount - This method should be called by the
2895/// client when it has changed a loop in a way that may affect
2896/// ScalarEvolution's ability to compute a trip count, or if the loop
2897/// is deleted.
2898void ScalarEvolution::forgetLoopBackedgeTakenCount(const Loop *L) {
2899  BackedgeTakenCounts.erase(L);
2900  forgetLoopPHIs(L);
2901}
2902
2903/// forgetLoopPHIs - Delete the memoized SCEVs associated with the
2904/// PHI nodes in the given loop. This is used when the trip count of
2905/// the loop may have changed.
2906void ScalarEvolution::forgetLoopPHIs(const Loop *L) {
2907  BasicBlock *Header = L->getHeader();
2908
2909  // Push all Loop-header PHIs onto the Worklist stack, except those
2910  // that are presently represented via a SCEVUnknown. SCEVUnknown for
2911  // a PHI either means that it has an unrecognized structure, or it's
2912  // a PHI that's in the process of being computed by createNodeForPHI.
2913  // In the former case, additional loop trip count information isn't
2914  // going to change anything. In the latter case, createNodeForPHI will
2915  // perform the necessary updates on its own when it gets to that point.
2916  SmallVector<Instruction *, 16> Worklist;
2917  for (BasicBlock::iterator I = Header->begin();
2918       PHINode *PN = dyn_cast<PHINode>(I); ++I) {
2919    std::map<SCEVCallbackVH, const SCEV*>::iterator It =
2920      Scalars.find((Value*)I);
2921    if (It != Scalars.end() && !isa<SCEVUnknown>(It->second))
2922      Worklist.push_back(PN);
2923  }
2924
2925  while (!Worklist.empty()) {
2926    Instruction *I = Worklist.pop_back_val();
2927    if (Scalars.erase(I))
2928      for (Value::use_iterator UI = I->use_begin(), UE = I->use_end();
2929           UI != UE; ++UI)
2930        Worklist.push_back(cast<Instruction>(UI));
2931  }
2932}
2933
2934/// ComputeBackedgeTakenCount - Compute the number of times the backedge
2935/// of the specified loop will execute.
2936ScalarEvolution::BackedgeTakenInfo
2937ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) {
2938  SmallVector<BasicBlock*, 8> ExitingBlocks;
2939  L->getExitingBlocks(ExitingBlocks);
2940
2941  // Examine all exits and pick the most conservative values.
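  // For example, if one exit's backedge-taken count is %n and another's is
  // 10, the loop leaves through whichever exit is reached first, so the
  // combined exact count is umin(%n, 10).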
2942  const SCEV* BECount = getCouldNotCompute();
2943  const SCEV* MaxBECount = getCouldNotCompute();
2944  bool CouldNotComputeBECount = false;
2945  for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
2946    BackedgeTakenInfo NewBTI =
2947      ComputeBackedgeTakenCountFromExit(L, ExitingBlocks[i]);
2948
2949    if (NewBTI.Exact == getCouldNotCompute()) {
2950      // We couldn't compute an exact value for this exit, so
2951      // we won't be able to compute an exact value for the loop.
2952      CouldNotComputeBECount = true;
2953      BECount = getCouldNotCompute();
2954    } else if (!CouldNotComputeBECount) {
2955      if (BECount == getCouldNotCompute())
2956        BECount = NewBTI.Exact;
2957      else
2958        BECount = getUMinFromMismatchedTypes(BECount, NewBTI.Exact);
2959    }
2960    if (MaxBECount == getCouldNotCompute())
2961      MaxBECount = NewBTI.Max;
2962    else if (NewBTI.Max != getCouldNotCompute())
2963      MaxBECount = getUMinFromMismatchedTypes(MaxBECount, NewBTI.Max);
2964  }
2965
2966  return BackedgeTakenInfo(BECount, MaxBECount);
2967}
2968
2969/// ComputeBackedgeTakenCountFromExit - Compute the number of times the backedge
2970/// of the specified loop will execute if it exits via the specified block.
2971ScalarEvolution::BackedgeTakenInfo
2972ScalarEvolution::ComputeBackedgeTakenCountFromExit(const Loop *L,
2973                                                   BasicBlock *ExitingBlock) {
2974
2975  // Okay, we've chosen an exiting block.  See what condition causes us to
2976  // exit at this block.
2977  //
2978  // FIXME: we should be able to handle switch instructions (with a single exit)
2979  BranchInst *ExitBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
2980  if (ExitBr == 0) return getCouldNotCompute();
2981  assert(ExitBr->isConditional() && "If unconditional, it can't be in loop!");
2982
2983  // At this point, we know we have a conditional branch that determines whether
2984  // the loop is exited.  However, we don't know if the branch is executed each
2985  // time through the loop.  If not, then the execution count of the branch will
2986  // not be equal to the trip count of the loop.
2987  //
2988  // Currently we check for this by checking to see if the Exit branch goes to
2989  // the loop header.  If so, we know it will always execute the same number of
2990  // times as the loop.  We also handle the case where the exit block *is* the
2991  // loop header.  This is common for un-rotated loops.
2992  //
2993  // If both of those tests fail, walk up the unique predecessor chain to the
2994  // header, stopping if there is an edge that doesn't exit the loop. If the
2995  // header is reached, the execution count of the branch will be equal to the
2996  // trip count of the loop.
2997  //
2998  //  More extensive analysis could be done to handle more cases here.
2999  //
3000  if (ExitBr->getSuccessor(0) != L->getHeader() &&
3001      ExitBr->getSuccessor(1) != L->getHeader() &&
3002      ExitBr->getParent() != L->getHeader()) {
3003    // The simple checks failed, try climbing the unique predecessor chain
3004    // up to the header.
3005    bool Ok = false;
3006    for (BasicBlock *BB = ExitBr->getParent(); BB; ) {
3007      BasicBlock *Pred = BB->getUniquePredecessor();
3008      if (!Pred)
3009        return getCouldNotCompute();
3010      TerminatorInst *PredTerm = Pred->getTerminator();
3011      for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i) {
3012        BasicBlock *PredSucc = PredTerm->getSuccessor(i);
3013        if (PredSucc == BB)
3014          continue;
3015        // If the predecessor has a successor that isn't BB and isn't
3016        // outside the loop, assume the worst.
3017        if (L->contains(PredSucc))
3018          return getCouldNotCompute();
3019      }
3020      if (Pred == L->getHeader()) {
3021        Ok = true;
3022        break;
3023      }
3024      BB = Pred;
3025    }
3026    if (!Ok)
3027      return getCouldNotCompute();
3028  }
3029
3030  // Proceed to the next level to examine the exit condition expression.
3031  return ComputeBackedgeTakenCountFromExitCond(L, ExitBr->getCondition(),
3032                                               ExitBr->getSuccessor(0),
3033                                               ExitBr->getSuccessor(1));
3034}
3035
3036/// ComputeBackedgeTakenCountFromExitCond - Compute the number of times the
3037/// backedge of the specified loop will execute if its exit condition
3038/// were a conditional branch of ExitCond, TBB, and FBB.
3039ScalarEvolution::BackedgeTakenInfo
3040ScalarEvolution::ComputeBackedgeTakenCountFromExitCond(const Loop *L,
3041                                                       Value *ExitCond,
3042                                                       BasicBlock *TBB,
3043                                                       BasicBlock *FBB) {
3044  // Check if the controlling expression for this loop is an And or Or.
3045  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
3046    if (BO->getOpcode() == Instruction::And) {
3047      // Recurse on the operands of the and.
3048      BackedgeTakenInfo BTI0 =
3049        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
3050      BackedgeTakenInfo BTI1 =
3051        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
3052      const SCEV* BECount = getCouldNotCompute();
3053      const SCEV* MaxBECount = getCouldNotCompute();
3054      if (L->contains(TBB)) {
3055        // Both conditions must be true for the loop to continue executing.
3056        // Choose the less conservative count.
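        // For example, in "while (a && b)", the loop exits as soon as
        // either condition fails, so the backedge-taken count is the umin
        // of the counts computed for the two conditions individually.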
3057        if (BTI0.Exact == getCouldNotCompute() ||
3058            BTI1.Exact == getCouldNotCompute())
3059          BECount = getCouldNotCompute();
3060        else
3061          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3062        if (BTI0.Max == getCouldNotCompute())
3063          MaxBECount = BTI1.Max;
3064        else if (BTI1.Max == getCouldNotCompute())
3065          MaxBECount = BTI0.Max;
3066        else
3067          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
3068      } else {
3069        // Both conditions must be true for the loop to exit.
3070        assert(L->contains(FBB) && "Loop block has no successor in loop!");
3071        if (BTI0.Exact != getCouldNotCompute() &&
3072            BTI1.Exact != getCouldNotCompute())
3073          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3074        if (BTI0.Max != getCouldNotCompute() &&
3075            BTI1.Max != getCouldNotCompute())
3076          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
3077      }
3078
3079      return BackedgeTakenInfo(BECount, MaxBECount);
3080    }
3081    if (BO->getOpcode() == Instruction::Or) {
3082      // Recurse on the operands of the or.
3083      BackedgeTakenInfo BTI0 =
3084        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(0), TBB, FBB);
3085      BackedgeTakenInfo BTI1 =
3086        ComputeBackedgeTakenCountFromExitCond(L, BO->getOperand(1), TBB, FBB);
3087      const SCEV* BECount = getCouldNotCompute();
3088      const SCEV* MaxBECount = getCouldNotCompute();
3089      if (L->contains(FBB)) {
3090        // Both conditions must be false for the loop to continue executing.
3091        // Choose the less conservative count.
3092        if (BTI0.Exact == getCouldNotCompute() ||
3093            BTI1.Exact == getCouldNotCompute())
3094          BECount = getCouldNotCompute();
3095        else
3096          BECount = getUMinFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3097        if (BTI0.Max == getCouldNotCompute())
3098          MaxBECount = BTI1.Max;
3099        else if (BTI1.Max == getCouldNotCompute())
3100          MaxBECount = BTI0.Max;
3101        else
3102          MaxBECount = getUMinFromMismatchedTypes(BTI0.Max, BTI1.Max);
3103      } else {
3104        // Both conditions must be false for the loop to exit.
3105        assert(L->contains(TBB) && "Loop block has no successor in loop!");
3106        if (BTI0.Exact != getCouldNotCompute() &&
3107            BTI1.Exact != getCouldNotCompute())
3108          BECount = getUMaxFromMismatchedTypes(BTI0.Exact, BTI1.Exact);
3109        if (BTI0.Max != getCouldNotCompute() &&
3110            BTI1.Max != getCouldNotCompute())
3111          MaxBECount = getUMaxFromMismatchedTypes(BTI0.Max, BTI1.Max);
3112      }
3113
3114      return BackedgeTakenInfo(BECount, MaxBECount);
3115    }
3116  }
3117
3118  // With an icmp, it may be feasible to compute an exact backedge-taken count.
3119  // Proceed to the next level to examine the icmp.
3120  if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond))
3121    return ComputeBackedgeTakenCountFromExitCondICmp(L, ExitCondICmp, TBB, FBB);
3122
3123  // If it's not an integer or pointer comparison then compute it the hard way.
3124  return ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
3125}
3126
3127/// ComputeBackedgeTakenCountFromExitCondICmp - Compute the number of times the
3128/// backedge of the specified loop will execute if its exit condition
3129/// were a conditional branch of the ICmpInst ExitCond, TBB, and FBB.
3130ScalarEvolution::BackedgeTakenInfo
3131ScalarEvolution::ComputeBackedgeTakenCountFromExitCondICmp(const Loop *L,
3132                                                           ICmpInst *ExitCond,
3133                                                           BasicBlock *TBB,
3134                                                           BasicBlock *FBB) {
3135
3136  // If the condition was exit on true, convert the condition to exit on false
3137  ICmpInst::Predicate Cond;
3138  if (!L->contains(FBB))
3139    Cond = ExitCond->getPredicate();
3140  else
3141    Cond = ExitCond->getInversePredicate();
3142
3143  // Handle common loops like: for (X = "string"; *X; ++X)
3144  if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
3145    if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
3146      const SCEV* ItCnt =
3147        ComputeLoadConstantCompareBackedgeTakenCount(LI, RHS, L, Cond);
3148      if (!isa<SCEVCouldNotCompute>(ItCnt)) {
3149        unsigned BitWidth = getTypeSizeInBits(ItCnt->getType());
3150        return BackedgeTakenInfo(ItCnt,
3151                                 isa<SCEVConstant>(ItCnt) ? ItCnt :
3152                                   getConstant(APInt::getMaxValue(BitWidth)-1));
3153      }
3154    }
3155
3156  const SCEV* LHS = getSCEV(ExitCond->getOperand(0));
3157  const SCEV* RHS = getSCEV(ExitCond->getOperand(1));
3158
3159  // Try to evaluate any dependencies out of the loop.
3160  LHS = getSCEVAtScope(LHS, L);
3161  RHS = getSCEVAtScope(RHS, L);
3162
3163  // At this point, we would like to compute for how many iterations of
3164  // the loop the predicate will return true with these inputs.
3165  if (LHS->isLoopInvariant(L) && !RHS->isLoopInvariant(L)) {
3166    // If there is a loop-invariant, force it into the RHS.
3167    std::swap(LHS, RHS);
3168    Cond = ICmpInst::getSwappedPredicate(Cond);
3169  }
3170
3171  // If we have a comparison of a chrec against a constant, try to use value
3172  // ranges to answer this query.
3173  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS))
3174    if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))
3175      if (AddRec->getLoop() == L) {
3176        // Form the constant range.
3177        ConstantRange CompRange(
3178            ICmpInst::makeConstantRange(Cond, RHSC->getValue()->getValue()));
3179
3180        const SCEV* Ret = AddRec->getNumIterationsInRange(CompRange, *this);
3181        if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
3182      }
3183
3184  switch (Cond) {
3185  case ICmpInst::ICMP_NE: {                     // while (X != Y)
3186    // Convert to: while (X-Y != 0)
3187    const SCEV* TC = HowFarToZero(getMinusSCEV(LHS, RHS), L);
3188    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
3189    break;
3190  }
3191  case ICmpInst::ICMP_EQ: {                     // while (X == Y)
3192    // Convert to: while (X-Y == 0)
3193    const SCEV* TC = HowFarToNonZero(getMinusSCEV(LHS, RHS), L);
3194    if (!isa<SCEVCouldNotCompute>(TC)) return TC;
3195    break;
3196  }
3197  case ICmpInst::ICMP_SLT: {
3198    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, true);
3199    if (BTI.hasAnyInfo()) return BTI;
3200    break;
3201  }
3202  case ICmpInst::ICMP_SGT: {
3203    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
3204                                             getNotSCEV(RHS), L, true);
3205    if (BTI.hasAnyInfo()) return BTI;
3206    break;
3207  }
3208  case ICmpInst::ICMP_ULT: {
3209    BackedgeTakenInfo BTI = HowManyLessThans(LHS, RHS, L, false);
3210    if (BTI.hasAnyInfo()) return BTI;
3211    break;
3212  }
3213  case ICmpInst::ICMP_UGT: {
3214    BackedgeTakenInfo BTI = HowManyLessThans(getNotSCEV(LHS),
3215                                             getNotSCEV(RHS), L, false);
3216    if (BTI.hasAnyInfo()) return BTI;
3217    break;
3218  }
3219  default:
3220#if 0
3221    errs() << "ComputeBackedgeTakenCount ";
3222    if (ExitCond->getOperand(0)->getType()->isUnsigned())
3223      errs() << "[unsigned] ";
3224    errs() << *LHS << "   "
3225         << Instruction::getOpcodeName(Instruction::ICmp)
3226         << "   " << *RHS << "\n";
3227#endif
3228    break;
3229  }
3230  return
3231    ComputeBackedgeTakenCountExhaustively(L, ExitCond, !L->contains(TBB));
3232}
3233
3234static ConstantInt *
3235EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C,
3236                                ScalarEvolution &SE) {
3237  const SCEV* InVal = SE.getConstant(C);
3238  const SCEV* Val = AddRec->evaluateAtIteration(InVal, SE);
3239  assert(isa<SCEVConstant>(Val) &&
3240         "Evaluation of SCEV at constant didn't fold correctly?");
3241  return cast<SCEVConstant>(Val)->getValue();
3242}
3243
3244/// GetAddressedElementFromGlobal - Given a global variable with an initializer
3245/// and a GEP expression (missing the pointer index) indexing into it, return
3246/// the addressed element of the initializer or null if the index expression is
3247/// invalid.
3248static Constant *
3249GetAddressedElementFromGlobal(GlobalVariable *GV,
3250                              const std::vector<ConstantInt*> &Indices) {
3251  Constant *Init = GV->getInitializer();
3252  for (unsigned i = 0, e = Indices.size(); i != e; ++i) {
3253    uint64_t Idx = Indices[i]->getZExtValue();
3254    if (ConstantStruct *CS = dyn_cast<ConstantStruct>(Init)) {
3255      assert(Idx < CS->getNumOperands() && "Bad struct index!");
3256      Init = cast<Constant>(CS->getOperand(Idx));
3257    } else if (ConstantArray *CA = dyn_cast<ConstantArray>(Init)) {
3258      if (Idx >= CA->getNumOperands()) return 0;  // Bogus program
3259      Init = cast<Constant>(CA->getOperand(Idx));
3260    } else if (isa<ConstantAggregateZero>(Init)) {
3261      if (const StructType *STy = dyn_cast<StructType>(Init->getType())) {
3262        assert(Idx < STy->getNumElements() && "Bad struct index!");
3263        Init = Constant::getNullValue(STy->getElementType(Idx));
3264      } else if (const ArrayType *ATy = dyn_cast<ArrayType>(Init->getType())) {
3265        if (Idx >= ATy->getNumElements()) return 0;  // Bogus program
3266        Init = Constant::getNullValue(ATy->getElementType());
3267      } else {
3268        assert(0 && "Unknown constant aggregate type!");
3269        return 0;
3270      }
3271    } else {
3272      return 0; // Unknown initializer type
3273    }
3274  }
3275  return Init;
3276}
3277
3278/// ComputeLoadConstantCompareBackedgeTakenCount - Given an exit condition of
3279/// 'icmp op load X, cst', try to see if we can compute the backedge
3280/// execution count.
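/// For example, with @s = constant [6 x i8] c"hello\00", the loop
/// "for (X = @s; *X; ++X)" compares successive elements against zero,
/// giving a backedge-taken count of 5.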
3281const SCEV *
3282ScalarEvolution::ComputeLoadConstantCompareBackedgeTakenCount(
3283                                                LoadInst *LI,
3284                                                Constant *RHS,
3285                                                const Loop *L,
3286                                                ICmpInst::Predicate predicate) {
3287  if (LI->isVolatile()) return getCouldNotCompute();
3288
3289  // Check to see if the loaded pointer is a getelementptr of a global.
3290  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0));
3291  if (!GEP) return getCouldNotCompute();
3292
3293  // Make sure that it is really a constant global we are gepping, with an
3294  // initializer, and make sure the first IDX is really 0.
3295  GlobalVariable *GV = dyn_cast<GlobalVariable>(GEP->getOperand(0));
3296  if (!GV || !GV->isConstant() || !GV->hasInitializer() ||
3297      GEP->getNumOperands() < 3 || !isa<Constant>(GEP->getOperand(1)) ||
3298      !cast<Constant>(GEP->getOperand(1))->isNullValue())
3299    return getCouldNotCompute();
3300
3301  // Okay, we allow one non-constant index into the GEP instruction.
3302  Value *VarIdx = 0;
3303  std::vector<ConstantInt*> Indexes;
3304  unsigned VarIdxNum = 0;
3305  for (unsigned i = 2, e = GEP->getNumOperands(); i != e; ++i)
3306    if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
3307      Indexes.push_back(CI);
3308    } else if (!isa<ConstantInt>(GEP->getOperand(i))) {
3309      if (VarIdx) return getCouldNotCompute();  // Multiple non-constant idx's.
3310      VarIdx = GEP->getOperand(i);
3311      VarIdxNum = i-2;
3312      Indexes.push_back(0);
3313    }
3314
3315  // Okay, we know we have a (load (gep GV, 0, X)) comparison with a constant.
3316  // Check to see if X is a loop variant variable value now.
3317  const SCEV* Idx = getSCEV(VarIdx);
3318  Idx = getSCEVAtScope(Idx, L);
3319
3320  // We can only recognize very limited forms of loop index expressions, in
3321  // particular, only affine AddRec's like {C1,+,C2}.
3322  const SCEVAddRecExpr *IdxExpr = dyn_cast<SCEVAddRecExpr>(Idx);
3323  if (!IdxExpr || !IdxExpr->isAffine() || IdxExpr->isLoopInvariant(L) ||
3324      !isa<SCEVConstant>(IdxExpr->getOperand(0)) ||
3325      !isa<SCEVConstant>(IdxExpr->getOperand(1)))
3326    return getCouldNotCompute();
3327
3328  unsigned MaxSteps = MaxBruteForceIterations;
3329  for (unsigned IterationNum = 0; IterationNum != MaxSteps; ++IterationNum) {
3330    ConstantInt *ItCst =
3331      ConstantInt::get(cast<IntegerType>(IdxExpr->getType()), IterationNum);
3332    ConstantInt *Val = EvaluateConstantChrecAtConstant(IdxExpr, ItCst, *this);
3333
3334    // Form the GEP offset.
3335    Indexes[VarIdxNum] = Val;
3336
3337    Constant *Result = GetAddressedElementFromGlobal(GV, Indexes);
3338    if (Result == 0) break;  // Cannot compute!
3339
3340    // Evaluate the condition for this iteration.
3341    Result = ConstantExpr::getICmp(predicate, Result, RHS);
3342    if (!isa<ConstantInt>(Result)) break;  // Couldn't decide for sure
3343    if (cast<ConstantInt>(Result)->getValue().isMinValue()) {
3344#if 0
3345      errs() << "\n***\n*** Computed loop count " << *ItCst
3346             << "\n*** From global " << *GV << "*** BB: " << *L->getHeader()
3347             << "***\n";
3348#endif
3349      ++NumArrayLenItCounts;
3350      return getConstant(ItCst);   // Found terminating iteration!
3351    }
3352  }
3353  return getCouldNotCompute();
3354}
3355
3356
3357/// CanConstantFold - Return true if we can constant fold an instruction of the
3358/// specified type, assuming that all operands were constants.
3359static bool CanConstantFold(const Instruction *I) {
3360  if (isa<BinaryOperator>(I) || isa<CmpInst>(I) ||
3361      isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I))
3362    return true;
3363
3364  if (const CallInst *CI = dyn_cast<CallInst>(I))
3365    if (const Function *F = CI->getCalledFunction())
3366      return canConstantFoldCallTo(F);
3367  return false;
3368}
3369
3370/// getConstantEvolvingPHI - Given an LLVM value and a loop, return a PHI node
3371/// in the loop that V is derived from.  We allow arbitrary operations along the
3372/// way, but the operands of an operation must either be constants or a value
3373/// derived from a constant PHI.  If this expression does not fit with these
3374/// constraints, return null.
3375static PHINode *getConstantEvolvingPHI(Value *V, const Loop *L) {
3376  // If this is not an instruction, or if this is an instruction outside of the
3377  // loop, it can't be derived from a loop PHI.
3378  Instruction *I = dyn_cast<Instruction>(V);
3379  if (I == 0 || !L->contains(I->getParent())) return 0;
3380
3381  if (PHINode *PN = dyn_cast<PHINode>(I)) {
3382    if (L->getHeader() == I->getParent())
3383      return PN;
3384    else
3385      // We don't currently keep track of the control flow needed to evaluate
3386      // PHIs, so we cannot handle PHIs inside of loops.
3387      return 0;
3388  }
3389
3390  // If we won't be able to constant fold this expression even if the operands
3391  // are constants, return early.
3392  if (!CanConstantFold(I)) return 0;
3393
3394  // Otherwise, we can evaluate this instruction if all of its operands are
3395  // constant or derived from a PHI node themselves.
3396  PHINode *PHI = 0;
3397  for (unsigned Op = 0, e = I->getNumOperands(); Op != e; ++Op)
3398    if (!(isa<Constant>(I->getOperand(Op)) ||
3399          isa<GlobalValue>(I->getOperand(Op)))) {
3400      PHINode *P = getConstantEvolvingPHI(I->getOperand(Op), L);
3401      if (P == 0) return 0;  // Not evolving from PHI
3402      if (PHI == 0)
3403        PHI = P;
3404      else if (PHI != P)
3405        return 0;  // Evolving from multiple different PHIs.
3406    }
3407
3408  // This is an expression evolving from a constant PHI!
3409  return PHI;
3410}
3411
3412/// EvaluateExpression - Given an expression that passes the
3413/// getConstantEvolvingPHI predicate, evaluate its value assuming the PHI node
3414/// in the loop has the value PHIVal.  If we can't fold this expression for some
3415/// reason, return null.
3416static Constant *EvaluateExpression(Value *V, Constant *PHIVal) {
3417  if (isa<PHINode>(V)) return PHIVal;
3418  if (Constant *C = dyn_cast<Constant>(V)) return C;
3419  if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) return GV;
3420  Instruction *I = cast<Instruction>(V);
3421
3422  std::vector<Constant*> Operands;
3423  Operands.resize(I->getNumOperands());
3424
3425  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
3426    Operands[i] = EvaluateExpression(I->getOperand(i), PHIVal);
3427    if (Operands[i] == 0) return 0;
3428  }
3429
3430  if (const CmpInst *CI = dyn_cast<CmpInst>(I))
3431    return ConstantFoldCompareInstOperands(CI->getPredicate(),
3432                                           &Operands[0], Operands.size());
3433  else
3434    return ConstantFoldInstOperands(I->getOpcode(), I->getType(),
3435                                    &Operands[0], Operands.size());
3436}
3437
3438/// getConstantEvolutionLoopExitValue - If we know that the specified PHI is
3439/// in the header of its containing loop, that the loop executes a
3440/// constant number of times, and that the PHI node is just a recurrence
3441/// involving constants, fold it and return the exit value.
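/// For example, a PHI starting at 1 whose backedge value is "phi * 2", in
/// a loop whose backedge executes exactly 5 times, folds to the constant
/// 32, even though no AddRec can represent the doubling recurrence.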
3442Constant *
3443ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
3444                                                   const APInt& BEs,
3445                                                   const Loop *L) {
3446  std::map<PHINode*, Constant*>::iterator I =
3447    ConstantEvolutionLoopExitValue.find(PN);
3448  if (I != ConstantEvolutionLoopExitValue.end())
3449    return I->second;
3450
3451  if (BEs.ugt(APInt(BEs.getBitWidth(),MaxBruteForceIterations)))
3452    return ConstantEvolutionLoopExitValue[PN] = 0;  // Not going to evaluate it.
3453
3454  Constant *&RetVal = ConstantEvolutionLoopExitValue[PN];
3455
3456  // Since the loop is canonicalized, the PHI node must have two entries.  One
3457  // entry must be a constant (coming in from outside of the loop), and the
3458  // second must be derived from the same PHI.
3459  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
3460  Constant *StartCST =
3461    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
3462  if (StartCST == 0)
3463    return RetVal = 0;  // Must be a constant.
3464
3465  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
3466  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
3467  if (PN2 != PN)
3468    return RetVal = 0;  // Not derived from same PHI.
3469
3470  // Execute the loop symbolically to determine the exit value.
3471  if (BEs.getActiveBits() >= 32)
3472    return RetVal = 0; // More than 2^32-1 iterations?? Not doing it!
3473
3474  unsigned NumIterations = BEs.getZExtValue(); // must be in range
3475  unsigned IterationNum = 0;
3476  for (Constant *PHIVal = StartCST; ; ++IterationNum) {
3477    if (IterationNum == NumIterations)
3478      return RetVal = PHIVal;  // Got exit value!
3479
3480    // Compute the value of the PHI node for the next iteration.
3481    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
3482    if (NextPHI == PHIVal)
3483      return RetVal = NextPHI;  // Stopped evolving!
3484    if (NextPHI == 0)
3485      return 0;        // Couldn't evaluate!
3486    PHIVal = NextPHI;
3487  }
3488}
3489
3490/// ComputeBackedgeTakenCountExhaustively - If the loop is known to execute a
3491/// constant number of times (the condition evolves only from constants),
3492/// try to evaluate a few iterations of the loop until the exit
3493/// condition gets a value of ExitWhen (true or false).  If we cannot
3494/// evaluate the trip count of the loop, return getCouldNotCompute().
3495const SCEV *
3496ScalarEvolution::ComputeBackedgeTakenCountExhaustively(const Loop *L,
3497                                                       Value *Cond,
3498                                                       bool ExitWhen) {
3499  PHINode *PN = getConstantEvolvingPHI(Cond, L);
3500  if (PN == 0) return getCouldNotCompute();
3501
3502  // Since the loop is canonicalized, the PHI node must have two entries.  One
3503  // entry must be a constant (coming in from outside of the loop), and the
3504  // second must be derived from the same PHI.
3505  bool SecondIsBackedge = L->contains(PN->getIncomingBlock(1));
3506  Constant *StartCST =
3507    dyn_cast<Constant>(PN->getIncomingValue(!SecondIsBackedge));
3508  if (StartCST == 0) return getCouldNotCompute();  // Must be a constant.
3509
3510  Value *BEValue = PN->getIncomingValue(SecondIsBackedge);
3511  PHINode *PN2 = getConstantEvolvingPHI(BEValue, L);
3512  if (PN2 != PN) return getCouldNotCompute();  // Not derived from same PHI.
3513
3514  // Okay, we found a PHI node that defines the trip count of this loop.  Execute
3515  // the loop symbolically to determine when the condition gets a value of
3516  // "ExitWhen".
3517  unsigned IterationNum = 0;
3518  unsigned MaxIterations = MaxBruteForceIterations;   // Limit analysis.
3519  for (Constant *PHIVal = StartCST;
3520       IterationNum != MaxIterations; ++IterationNum) {
3521    ConstantInt *CondVal =
3522      dyn_cast_or_null<ConstantInt>(EvaluateExpression(Cond, PHIVal));
3523
3524    // Couldn't symbolically evaluate.
3525    if (!CondVal) return getCouldNotCompute();
3526
3527    if (CondVal->getValue() == uint64_t(ExitWhen)) {
3528      ++NumBruteForceTripCountsComputed;
3529      return getConstant(Type::Int32Ty, IterationNum);
3530    }
3531
3532    // Compute the value of the PHI node for the next iteration.
3533    Constant *NextPHI = EvaluateExpression(BEValue, PHIVal);
3534    if (NextPHI == 0 || NextPHI == PHIVal)
3535      return getCouldNotCompute();// Couldn't evaluate or not making progress...
3536    PHIVal = NextPHI;
3537  }
3538
3539  // Too many iterations were needed to evaluate.
3540  return getCouldNotCompute();
3541}
3542
3543/// getSCEVAtScope - Return a SCEV expression handle for the specified value
3544/// at the specified scope in the program.  The L value specifies the loop
3545/// nest in which to evaluate the expression: null means the top level, and
3546/// a non-null loop means the point immediately inside that loop.
3547///
3548/// This method can be used to compute the exit value for a variable defined
3549/// in a loop by querying what the value will hold in the parent loop.
3550///
3551/// In the case that a relevant loop exit value cannot be computed, the
3552/// original value V is returned.
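/// For example, if i is the recurrence {0,+,1}<InnerLoop> and InnerLoop's
/// backedge-taken count is 10, then evaluating i at the scope of the
/// parent loop yields the constant 10, the value i has when InnerLoop exits.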
3553const SCEV* ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
3554  // FIXME: this should be turned into a virtual method on SCEV!
3555
3556  if (isa<SCEVConstant>(V)) return V;
3557
3558  // If this instruction is evolved from a constant-evolving PHI, compute the
3559  // exit value from the loop without using SCEVs.
3560  if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
3561    if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
3562      const Loop *LI = (*this->LI)[I->getParent()];
3563      if (LI && LI->getParentLoop() == L)  // Looking for loop exit value.
3564        if (PHINode *PN = dyn_cast<PHINode>(I))
3565          if (PN->getParent() == LI->getHeader()) {
3566            // Okay, there is no closed form solution for the PHI node.  Check
3567            // to see if the loop that contains it has a known backedge-taken
3568            // count.  If so, we may be able to force computation of the exit
3569            // value.
3570            const SCEV* BackedgeTakenCount = getBackedgeTakenCount(LI);
3571            if (const SCEVConstant *BTCC =
3572                  dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
3573              // Okay, we know how many times the containing loop executes.  If
3574              // this is a constant evolving PHI node, get the final value at
3575              // the specified iteration number.
3576              Constant *RV = getConstantEvolutionLoopExitValue(PN,
3577                                                   BTCC->getValue()->getValue(),
3578                                                               LI);
3579              if (RV) return getSCEV(RV);
3580            }
3581          }
3582
3583      // Okay, this is an expression that we cannot symbolically evaluate
3584      // into a SCEV.  Check to see if it's possible to symbolically evaluate
3585      // the arguments into constants, and if so, try to constant propagate the
3586      // result.  This is particularly useful for computing loop exit values.
3587      if (CanConstantFold(I)) {
3588        // Check to see if we've folded this instruction at this loop before.
3589        std::map<const Loop *, Constant *> &Values = ValuesAtScopes[I];
3590        std::pair<std::map<const Loop *, Constant *>::iterator, bool> Pair =
3591          Values.insert(std::make_pair(L, static_cast<Constant *>(0)));
3592        if (!Pair.second)
3593          return Pair.first->second ? getSCEV(Pair.first->second) : V;
3594
3595        std::vector<Constant*> Operands;
3596        Operands.reserve(I->getNumOperands());
3597        for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) {
3598          Value *Op = I->getOperand(i);
3599          if (Constant *C = dyn_cast<Constant>(Op)) {
3600            Operands.push_back(C);
3601          } else {
3602            // If any of the operands is a non-constant value that is
3603            // not SCEVable (neither integer nor pointer), don't even
3604            // try to analyze it with SCEV techniques.
3605            if (!isSCEVable(Op->getType()))
3606              return V;
3607
3608            const SCEV* OpV = getSCEVAtScope(getSCEV(Op), L);
3609            if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(OpV)) {
3610              Constant *C = SC->getValue();
3611              if (C->getType() != Op->getType())
3612                C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
3613                                                                  Op->getType(),
3614                                                                  false),
3615                                          C, Op->getType());
3616              Operands.push_back(C);
3617            } else if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(OpV)) {
3618              if (Constant *C = dyn_cast<Constant>(SU->getValue())) {
3619                if (C->getType() != Op->getType())
3620                  C =
3621                    ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
3622                                                                  Op->getType(),
3623                                                                  false),
3624                                          C, Op->getType());
3625                Operands.push_back(C);
3626              } else
3627                return V;
3628            } else {
3629              return V;
3630            }
3631          }
3632        }
3633
3634        Constant *C;
3635        if (const CmpInst *CI = dyn_cast<CmpInst>(I))
3636          C = ConstantFoldCompareInstOperands(CI->getPredicate(),
3637                                              &Operands[0], Operands.size());
3638        else
3639          C = ConstantFoldInstOperands(I->getOpcode(), I->getType(),
3640                                       &Operands[0], Operands.size());
3641        Pair.first->second = C;
3642        return getSCEV(C);
3643      }
3644    }
3645
3646    // This is some other type of SCEVUnknown, just return it.
3647    return V;
3648  }
3649
3650  if (const SCEVCommutativeExpr *Comm = dyn_cast<SCEVCommutativeExpr>(V)) {
3651    // Avoid performing the look-up in the common case where the specified
3652    // expression has no loop-variant portions.
3653    for (unsigned i = 0, e = Comm->getNumOperands(); i != e; ++i) {
3654      const SCEV* OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
3655      if (OpAtScope != Comm->getOperand(i)) {
3656        // Okay, at least one of these operands is loop variant but might be
3657        // foldable.  Build a new instance of the folded commutative expression.
3658        SmallVector<const SCEV *, 8> NewOps(Comm->op_begin(),
3659                                            Comm->op_begin()+i);
3660        NewOps.push_back(OpAtScope);
3661
3662        for (++i; i != e; ++i) {
3663          OpAtScope = getSCEVAtScope(Comm->getOperand(i), L);
3664          NewOps.push_back(OpAtScope);
3665        }
3666        if (isa<SCEVAddExpr>(Comm))
3667          return getAddExpr(NewOps);
3668        if (isa<SCEVMulExpr>(Comm))
3669          return getMulExpr(NewOps);
3670        if (isa<SCEVSMaxExpr>(Comm))
3671          return getSMaxExpr(NewOps);
3672        if (isa<SCEVUMaxExpr>(Comm))
3673          return getUMaxExpr(NewOps);
3674        assert(0 && "Unknown commutative SCEV type!");
3675      }
3676    }
3677    // If we got here, all operands are loop invariant.
3678    return Comm;
3679  }
3680
3681  if (const SCEVUDivExpr *Div = dyn_cast<SCEVUDivExpr>(V)) {
3682    const SCEV* LHS = getSCEVAtScope(Div->getLHS(), L);
3683    const SCEV* RHS = getSCEVAtScope(Div->getRHS(), L);
3684    if (LHS == Div->getLHS() && RHS == Div->getRHS())
3685      return Div;   // must be loop invariant
3686    return getUDivExpr(LHS, RHS);
3687  }
3688
3689  // If this is a loop recurrence for a loop that does not contain L, then we
3690  // are dealing with the final value computed by the loop.
3691  if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V)) {
3692    if (!L || !AddRec->getLoop()->contains(L->getHeader())) {
3693      // To evaluate this recurrence, we need to know how many times the AddRec
3694      // loop iterates.  Compute this now.
3695      const SCEV* BackedgeTakenCount = getBackedgeTakenCount(AddRec->getLoop());
3696      if (BackedgeTakenCount == getCouldNotCompute()) return AddRec;
3697
3698      // Then, evaluate the AddRec.
3699      return AddRec->evaluateAtIteration(BackedgeTakenCount, *this);
3700    }
3701    return AddRec;
3702  }
3703
3704  if (const SCEVZeroExtendExpr *Cast = dyn_cast<SCEVZeroExtendExpr>(V)) {
3705    const SCEV* Op = getSCEVAtScope(Cast->getOperand(), L);
3706    if (Op == Cast->getOperand())
3707      return Cast;  // must be loop invariant
3708    return getZeroExtendExpr(Op, Cast->getType());
3709  }
3710
3711  if (const SCEVSignExtendExpr *Cast = dyn_cast<SCEVSignExtendExpr>(V)) {
3712    const SCEV* Op = getSCEVAtScope(Cast->getOperand(), L);
3713    if (Op == Cast->getOperand())
3714      return Cast;  // must be loop invariant
3715    return getSignExtendExpr(Op, Cast->getType());
3716  }
3717
3718  if (const SCEVTruncateExpr *Cast = dyn_cast<SCEVTruncateExpr>(V)) {
3719    const SCEV* Op = getSCEVAtScope(Cast->getOperand(), L);
3720    if (Op == Cast->getOperand())
3721      return Cast;  // must be loop invariant
3722    return getTruncateExpr(Op, Cast->getType());
3723  }
3724
3725  assert(0 && "Unknown SCEV type!");
3726  return 0;
3727}
3728
3729/// getSCEVAtScope - This is a convenience function which does
3730/// getSCEVAtScope(getSCEV(V), L).
3731const SCEV* ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
3732  return getSCEVAtScope(getSCEV(V), L);
3733}
3734
3735/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the
3736/// following equation:
3737///
3738///     A * X = B (mod N)
3739///
3740/// where N = 2^BW and BW is the common bit width of A and B. The signedness of
3741/// A and B isn't important.
3742///
3743/// If the equation does not have a solution, SCEVCouldNotCompute is returned.
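/// For example, with BW = 8 (N = 256), A = 6, and B = 14: D = gcd(6, 256)
/// = 2, which divides 14, so a solution exists. A/D = 3, N/D = 128, the
/// multiplicative inverse of 3 mod 128 is 43, and the minimum root is
/// X = 43 * (14/2) mod 128 = 45; indeed 6 * 45 = 270 = 14 (mod 256).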
3744static const SCEV* SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
3745                                               ScalarEvolution &SE) {
3746  uint32_t BW = A.getBitWidth();
3747  assert(BW == B.getBitWidth() && "Bit widths must be the same.");
3748  assert(A != 0 && "A must be non-zero.");
3749
3750  // 1. D = gcd(A, N)
3751  //
3752  // The gcd of A and N may have only one prime factor: 2. The number of
3753  // trailing zeros in A is its multiplicity.
3754  uint32_t Mult2 = A.countTrailingZeros();
3755  // D = 2^Mult2
3756
3757  // 2. Check if B is divisible by D.
3758  //
3759  // B is divisible by D if and only if the multiplicity of prime factor 2 for B
3760  // is not less than multiplicity of this prime factor for D.
3761  if (B.countTrailingZeros() < Mult2)
3762    return SE.getCouldNotCompute();
3763
3764  // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
3765  // modulo (N / D).
3766  //
3767  // (N / D) may need BW+1 bits in its representation.  Hence, we'll use this
3768  // bit width during computations.
3769  APInt AD = A.lshr(Mult2).zext(BW + 1);  // AD = A / D
3770  APInt Mod(BW + 1, 0);
3771  Mod.set(BW - Mult2);  // Mod = N / D
3772  APInt I = AD.multiplicativeInverse(Mod);
3773
3774  // 4. Compute the minimum unsigned root of the equation:
3775  // I * (B / D) mod (N / D)
3776  APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
3777
3778  // The result is guaranteed to be less than 2^BW so we may truncate it to BW
3779  // bits.
3780  return SE.getConstant(Result.trunc(BW));
3781}
3782
3783/// SolveQuadraticEquation - Find the roots of the quadratic equation for the
3784/// given quadratic chrec {L,+,M,+,N}.  This returns either the two roots (which
3785/// might be the same) or two SCEVCouldNotCompute objects.
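/// A chrec {L,+,M,+,N} evaluated at iteration x equals L + M*x + N*x*(x-1)/2,
/// so the roots computed below are those of the standard-form polynomial
/// (N/2)*x^2 + (M - N/2)*x + L = 0.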
3786///
3787static std::pair<const SCEV*,const SCEV*>
3788SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
3789  assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
3790  const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
3791  const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
3792  const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
3793
3794  // We currently can only solve this if the coefficients are constants.
3795  if (!LC || !MC || !NC) {
3796    const SCEV *CNC = SE.getCouldNotCompute();
3797    return std::make_pair(CNC, CNC);
3798  }
3799
3800  uint32_t BitWidth = LC->getValue()->getValue().getBitWidth();
3801  const APInt &L = LC->getValue()->getValue();
3802  const APInt &M = MC->getValue()->getValue();
3803  const APInt &N = NC->getValue()->getValue();
3804  APInt Two(BitWidth, 2);
3805  APInt Four(BitWidth, 4);
3806
3807  {
3808    using namespace APIntOps;
3809    const APInt& C = L;
3810    // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
3811    // The B coefficient is M-N/2
3812    APInt B(M);
3813    B -= sdiv(N,Two);
3814
3815    // The A coefficient is N/2
3816    APInt A(N.sdiv(Two));
3817
3818    // Compute the B^2-4ac term.
3819    APInt SqrtTerm(B);
3820    SqrtTerm *= B;
3821    SqrtTerm -= Four * (A * C);
3822
3823    // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
3824    // integer value or else APInt::sqrt() will assert.
3825    APInt SqrtVal(SqrtTerm.sqrt());
3826
3827    // Compute the two solutions for the quadratic formula.
3828    // The divisions must be performed as signed divisions.
3829    APInt NegB(-B);
3830    APInt TwoA( A << 1 );
3831    if (TwoA.isMinValue()) {
3832      const SCEV *CNC = SE.getCouldNotCompute();
3833      return std::make_pair(CNC, CNC);
3834    }
3835
3836    ConstantInt *Solution1 = ConstantInt::get((NegB + SqrtVal).sdiv(TwoA));
3837    ConstantInt *Solution2 = ConstantInt::get((NegB - SqrtVal).sdiv(TwoA));
3838
3839    return std::make_pair(SE.getConstant(Solution1),
3840                          SE.getConstant(Solution2));
3841  }  // end scope of "using namespace APIntOps"
3842}
3843
/// HowFarToZero - Return the number of times a backedge comparing the specified
/// value to zero will execute.  If not computable, return CouldNotCompute.
const SCEV* ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L) {
  // If the value is a constant
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    // If the value is already zero, the branch will execute zero times.
    if (C->getValue()->isZero()) return C;
    return getCouldNotCompute();  // Otherwise it will loop infinitely.
  }

  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
  if (!AddRec || AddRec->getLoop() != L)
    return getCouldNotCompute();

  if (AddRec->isAffine()) {
    // If this is an affine expression, the execution count of this branch is
    // the minimum unsigned root of the following equation:
    //
    //     Start + Step*N = 0 (mod 2^BW)
    //
    // equivalent to:
    //
    //             Step*N = -Start (mod 2^BW)
    //
    // where BW is the common bit width of Start and Step.
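    //
    // For example, {10,+,-1} gives -1*N = -10 (mod 2^BW), whose minimum
    // unsigned root is N = 10: the backedge executes ten times before the
    // value reaches zero.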

    // Get the initial value for the loop.
    const SCEV *Start = getSCEVAtScope(AddRec->getStart(),
                                       L->getParentLoop());
    const SCEV *Step = getSCEVAtScope(AddRec->getOperand(1),
                                      L->getParentLoop());

    if (const SCEVConstant *StepC = dyn_cast<SCEVConstant>(Step)) {
      // For now we handle only constant steps.

      // First, handle unitary steps.
      if (StepC->getValue()->equalsInt(1))      // 1*N = -Start (mod 2^BW), so:
        return getNegativeSCEV(Start);       //   N = -Start (as unsigned)
      if (StepC->getValue()->isAllOnesValue())  // -1*N = -Start (mod 2^BW), so:
        return Start;                           //    N = Start (as unsigned)

      // Then, try to solve the above equation provided that Start is constant.
      if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start))
        return SolveLinEquationWithOverflow(StepC->getValue()->getValue(),
                                            -StartC->getValue()->getValue(),
                                            *this);
    }
  } else if (AddRec->isQuadratic() && AddRec->getType()->isInteger()) {
    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
    // the quadratic equation to solve it.
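    //
    // For example, {-4,+,1,+,2} evaluates at iteration x to
    // -4 + x + 2*(x*(x-1)/2) == x*x - 4, which is zero at x == 2, so the
    // smallest positive root found below is 2.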
    std::pair<const SCEV*,const SCEV*> Roots = SolveQuadraticEquation(AddRec,
                                                                    *this);
    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
    if (R1) {
#if 0
      errs() << "HFTZ: " << *V << " - sol#1: " << *R1
             << "  sol#2: " << *R2 << "\n";
#endif
      // Pick the smallest positive root value.
      if (ConstantInt *CB =
          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
                                   R1->getValue(), R2->getValue()))) {
        if (!CB->getZExtValue())
          std::swap(R1, R2);   // R1 is the minimum root now.

        // We can only use this value if the chrec ends up with an exact zero
        // value at this index.  When solving for "X*X != 5", for example, we
        // should not accept a root of 2.
        const SCEV* Val = AddRec->evaluateAtIteration(R1, *this);
        if (Val->isZero())
          return R1;  // We found a quadratic root!
      }
    }
  }

  return getCouldNotCompute();
}

/// HowFarToNonZero - Return the number of times a backedge checking the
/// specified value for nonzero will execute.  If not computable, return
/// CouldNotCompute.
const SCEV* ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) {
  // Loops that look like: while (X == 0) are very strange indeed.  We don't
  // handle them yet except for the trivial case.  This could be expanded in the
  // future as needed.

  // If the value is a constant, check to see if it is known to be non-zero
  // already.  If so, the backedge will execute zero times.
  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
    if (!C->getValue()->isNullValue())
      return getIntegerSCEV(0, C->getType());
    return getCouldNotCompute();  // Otherwise it will loop infinitely.
  }

  // We could implement others, but I really doubt anyone writes loops like
  // this, and if they did, they would already be constant folded.
  return getCouldNotCompute();
}

/// getLoopPredecessor - If the given loop's header has exactly one unique
/// predecessor outside the loop, return it. Otherwise return null.
///
BasicBlock *ScalarEvolution::getLoopPredecessor(const Loop *L) {
  BasicBlock *Header = L->getHeader();
  BasicBlock *Pred = 0;
  for (pred_iterator PI = pred_begin(Header), E = pred_end(Header);
       PI != E; ++PI)
    if (!L->contains(*PI)) {
      if (Pred && Pred != *PI) return 0; // Multiple predecessors.
      Pred = *PI;
    }
  return Pred;
}

/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB
/// (which may not be an immediate predecessor) which has exactly one
/// successor from which BB is reachable, or null if no such block is
/// found.
///
BasicBlock *
ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
  // If the block has a unique predecessor, then there is no path from the
  // predecessor to the block that does not go through the direct edge
  // from the predecessor to the block.
  if (BasicBlock *Pred = BB->getSinglePredecessor())
    return Pred;

  // A loop's header is defined to be a block that dominates the loop.
  // If the header has a unique predecessor outside the loop, it must be
  // a block that has exactly one successor that can reach the loop.
  if (Loop *L = LI->getLoopFor(BB))
    return getLoopPredecessor(L);

  return 0;
}

/// HasSameValue - SCEV structural equivalence is usually sufficient for
/// testing whether two expressions are equal; however, for the purposes of
/// looking for a condition guarding a loop, it can be useful to be a little
/// more general, since a front-end may have replicated the controlling
/// expression.
///
static bool HasSameValue(const SCEV* A, const SCEV* B) {
  // Quick check to see if they are the same SCEV.
  if (A == B) return true;

  // Otherwise, if they're both SCEVUnknown, it's possible that they hold
  // two different instructions with the same value. Check for this case.
  if (const SCEVUnknown *AU = dyn_cast<SCEVUnknown>(A))
    if (const SCEVUnknown *BU = dyn_cast<SCEVUnknown>(B))
      if (const Instruction *AI = dyn_cast<Instruction>(AU->getValue()))
        if (const Instruction *BI = dyn_cast<Instruction>(BU->getValue()))
          if (AI->isIdenticalTo(BI))
            return true;

  // Otherwise assume they may have a different value.
  return false;
}

/// isLoopGuardedByCond - Test whether entry to the loop is protected by
/// a conditional between LHS and RHS.  This is used to help avoid max
/// expressions in loop trip counts.
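///
/// For example, HowManyLessThans below only needs a umax/smax expression
/// for the loop bound when entry is not known to be guarded by the
/// corresponding compare, so a guard such as (n > 0) around a loop counting
/// up to n can make the trip count expression simpler.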
bool ScalarEvolution::isLoopGuardedByCond(const Loop *L,
                                          ICmpInst::Predicate Pred,
                                          const SCEV *LHS, const SCEV *RHS) {
  // Interpret a null L as meaning no loop, in which case there is obviously
  // no guard (interprocedural conditions notwithstanding).
  if (!L) return false;

  BasicBlock *Predecessor = getLoopPredecessor(L);
  BasicBlock *PredecessorDest = L->getHeader();

  // Starting at the loop predecessor, climb up the predecessor chain as long
  // as we can find predecessors that have a unique successor leading to the
  // original header.
  for (; Predecessor;
       PredecessorDest = Predecessor,
       Predecessor = getPredecessorWithUniqueSuccessorForBB(Predecessor)) {

    BranchInst *LoopEntryPredicate =
      dyn_cast<BranchInst>(Predecessor->getTerminator());
    if (!LoopEntryPredicate ||
        LoopEntryPredicate->isUnconditional())
      continue;

    if (isNecessaryCond(LoopEntryPredicate->getCondition(), Pred, LHS, RHS,
                        LoopEntryPredicate->getSuccessor(0) != PredecessorDest))
      return true;
  }

  return false;
}

/// isNecessaryCond - Test whether the given CondValue is a condition which
/// is at least as strict as the one described by Pred, LHS, and RHS.
bool ScalarEvolution::isNecessaryCond(Value *CondValue,
                                      ICmpInst::Predicate Pred,
                                      const SCEV *LHS, const SCEV *RHS,
                                      bool Inverse) {
  // Recursively handle And and Or conditions.
  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(CondValue)) {
    if (BO->getOpcode() == Instruction::And) {
      if (!Inverse)
        return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
               isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
    } else if (BO->getOpcode() == Instruction::Or) {
      if (Inverse)
        return isNecessaryCond(BO->getOperand(0), Pred, LHS, RHS, Inverse) ||
               isNecessaryCond(BO->getOperand(1), Pred, LHS, RHS, Inverse);
    }
  }

  ICmpInst *ICI = dyn_cast<ICmpInst>(CondValue);
  if (!ICI) return false;

  // Now that we found a conditional branch that dominates the loop, check to
  // see if it is the comparison we are looking for.
  Value *PreCondLHS = ICI->getOperand(0);
  Value *PreCondRHS = ICI->getOperand(1);
  ICmpInst::Predicate Cond;
  if (Inverse)
    Cond = ICI->getInversePredicate();
  else
    Cond = ICI->getPredicate();

  if (Cond == Pred)
    ; // An exact match.
  else if (!ICmpInst::isTrueWhenEqual(Cond) && Pred == ICmpInst::ICMP_NE)
    ; // The actual condition is more than sufficient.
  else
    // Check a few special cases.
    switch (Cond) {
    case ICmpInst::ICMP_UGT:
      if (Pred == ICmpInst::ICMP_ULT) {
        std::swap(PreCondLHS, PreCondRHS);
        Cond = ICmpInst::ICMP_ULT;
        break;
      }
      return false;
    case ICmpInst::ICMP_SGT:
      if (Pred == ICmpInst::ICMP_SLT) {
        std::swap(PreCondLHS, PreCondRHS);
        Cond = ICmpInst::ICMP_SLT;
        break;
      }
      return false;
    case ICmpInst::ICMP_NE:
      // Expressions like (x >u 0) are often canonicalized to (x != 0),
      // so check for this case by checking if the NE is comparing against
      // a minimum or maximum constant.
      if (!ICmpInst::isTrueWhenEqual(Pred))
        if (ConstantInt *CI = dyn_cast<ConstantInt>(PreCondRHS)) {
          const APInt &A = CI->getValue();
          switch (Pred) {
          case ICmpInst::ICMP_SLT:
            if (A.isMaxSignedValue()) break;
            return false;
          case ICmpInst::ICMP_SGT:
            if (A.isMinSignedValue()) break;
            return false;
          case ICmpInst::ICMP_ULT:
            if (A.isMaxValue()) break;
            return false;
          case ICmpInst::ICMP_UGT:
            if (A.isMinValue()) break;
            return false;
          default:
            return false;
          }
          Cond = ICmpInst::ICMP_NE;
          // NE is symmetric but the original comparison may not be. Swap
          // the operands if necessary so that they match below.
          if (isa<SCEVConstant>(LHS))
            std::swap(PreCondLHS, PreCondRHS);
          break;
        }
      return false;
    default:
      // We weren't able to reconcile the condition.
      return false;
    }

  if (!PreCondLHS->getType()->isInteger()) return false;

  const SCEV *PreCondLHSSCEV = getSCEV(PreCondLHS);
  const SCEV *PreCondRHSSCEV = getSCEV(PreCondRHS);
  return (HasSameValue(LHS, PreCondLHSSCEV) &&
          HasSameValue(RHS, PreCondRHSSCEV)) ||
         (HasSameValue(LHS, getNotSCEV(PreCondRHSSCEV)) &&
          HasSameValue(RHS, getNotSCEV(PreCondLHSSCEV)));
}

/// getBECount - Subtract the end and start values and divide by the step,
/// rounding up, to get the number of times the backedge is executed. Return
/// CouldNotCompute if an intermediate computation overflows.
const SCEV* ScalarEvolution::getBECount(const SCEV* Start,
                                       const SCEV* End,
                                       const SCEV* Step) {
  const Type *Ty = Start->getType();
  const SCEV* NegOne = getIntegerSCEV(-1, Ty);
  const SCEV* Diff = getMinusSCEV(End, Start);
  const SCEV* RoundUp = getAddExpr(Step, NegOne);

  // Add an adjustment to the difference between End and Start so that
  // the division will effectively round up.
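  //
  // For example, with Start = 0, End = 10 and Step = 3 this computes
  // (10 - 0 + 2) /u 3 == 4, the number of iterations for which
  // {0,+,3} < 10 holds (namely at 0, 3, 6 and 9).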
  const SCEV* Add = getAddExpr(Diff, RoundUp);

  // Check Add for unsigned overflow.
  // TODO: More sophisticated things could be done here.
  const Type *WideTy = IntegerType::get(getTypeSizeInBits(Ty) + 1);
  const SCEV* OperandExtendedAdd =
    getAddExpr(getZeroExtendExpr(Diff, WideTy),
               getZeroExtendExpr(RoundUp, WideTy));
  if (getZeroExtendExpr(Add, WideTy) != OperandExtendedAdd)
    return getCouldNotCompute();

  return getUDivExpr(Add, Step);
}

/// HowManyLessThans - Return the number of times a backedge containing the
/// specified less-than comparison will execute.  If not computable, return
/// CouldNotCompute.
ScalarEvolution::BackedgeTakenInfo
ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                  const Loop *L, bool isSigned) {
  // Only handle:  "ADDREC < LoopInvariant".
  if (!RHS->isLoopInvariant(L)) return getCouldNotCompute();

  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS);
  if (!AddRec || AddRec->getLoop() != L)
    return getCouldNotCompute();

  if (AddRec->isAffine()) {
    // FORNOW: We only support unit strides.
    unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
    const SCEV* Step = AddRec->getStepRecurrence(*this);

    // TODO: handle non-constant strides.
    const SCEVConstant *CStep = dyn_cast<SCEVConstant>(Step);
    if (!CStep || CStep->isZero())
      return getCouldNotCompute();
    if (CStep->isOne()) {
      // With unit stride, the iteration never steps past the limit value.
    } else if (CStep->getValue()->getValue().isStrictlyPositive()) {
      if (const SCEVConstant *CLimit = dyn_cast<SCEVConstant>(RHS)) {
        // Test whether a positive stride can step past the limit value and
        // past the maximum value for its type in a single step.
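        //
        // For example, with 8-bit unsigned values, a limit of 250 and a
        // stride of 10, the iteration can reach 249 (< 250) and then wrap
        // to 3 on the next step; since 255 - 10 == 245 is less than 250,
        // the check below conservatively gives up.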
        if (isSigned) {
          APInt Max = APInt::getSignedMaxValue(BitWidth);
          if ((Max - CStep->getValue()->getValue())
                .slt(CLimit->getValue()->getValue()))
            return getCouldNotCompute();
        } else {
          APInt Max = APInt::getMaxValue(BitWidth);
          if ((Max - CStep->getValue()->getValue())
                .ult(CLimit->getValue()->getValue()))
            return getCouldNotCompute();
        }
      } else
        // TODO: handle non-constant limit values below.
        return getCouldNotCompute();
    } else
      // TODO: handle negative strides below.
      return getCouldNotCompute();

    // We know the LHS is of the form {n,+,s} and the RHS is some loop-invariant
    // m.  So, we count the number of iterations in which {n,+,s} < m is true.
    // Note that we cannot simply return max(m-n,0)/s because it's not safe to
    // treat m-n as either signed or unsigned due to overflow possibility.
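    // For example, with 8-bit signed values, n = -120 and m = 120, the true
    // difference m-n is 240, which wraps to -16 in eight bits; naively
    // computing max(m-n,0)/s on the wrapped value would give zero iterations.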

    // First, we get the value of the LHS in the first iteration: n
    const SCEV* Start = AddRec->getOperand(0);

    // Determine the minimum constant start value.
    const SCEV *MinStart = isa<SCEVConstant>(Start) ? Start :
      getConstant(isSigned ? APInt::getSignedMinValue(BitWidth) :
                             APInt::getMinValue(BitWidth));

    // If we know that the condition is true in order to enter the loop,
    // then we know that it will run exactly (m-n)/s times. Otherwise, we
    // only know that it will execute (max(m,n)-n)/s times. In both cases,
    // the division must round up.
    const SCEV* End = RHS;
    if (!isLoopGuardedByCond(L,
                             isSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT,
                             getMinusSCEV(Start, Step), RHS))
      End = isSigned ? getSMaxExpr(RHS, Start)
                     : getUMaxExpr(RHS, Start);

    // Determine the maximum constant end value.
    const SCEV* MaxEnd =
      isa<SCEVConstant>(End) ? End :
      getConstant(isSigned ? APInt::getSignedMaxValue(BitWidth)
                               .ashr(GetMinSignBits(End) - 1) :
                             APInt::getMaxValue(BitWidth)
                               .lshr(GetMinLeadingZeros(End)));

    // Finally, we subtract these two values and divide, rounding up, to get
    // the number of times the backedge is executed.
    const SCEV* BECount = getBECount(Start, End, Step);

    // The maximum backedge count is similar, except using the minimum start
    // value and the maximum end value.
    const SCEV* MaxBECount = getBECount(MinStart, MaxEnd, Step);

    return BackedgeTakenInfo(BECount, MaxBECount);
  }

  return getCouldNotCompute();
}

/// getNumIterationsInRange - Return the number of iterations of this loop that
/// produce values in the specified constant range.  Another way of looking at
/// this is that it returns the first iteration number where the value is not in
/// the range, thus computing the exit count. If the iteration count can't
/// be computed, an instance of SCEVCouldNotCompute is returned.
const SCEV* SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range,
                                                    ScalarEvolution &SE) const {
  if (Range.isFullSet())  // Infinite loop.
    return SE.getCouldNotCompute();

  // If the start is a non-zero constant, shift the range to simplify things.
  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
    if (!SC->getValue()->isZero()) {
      SmallVector<const SCEV*, 4> Operands(op_begin(), op_end());
      Operands[0] = SE.getIntegerSCEV(0, SC->getType());
      const SCEV* Shifted = SE.getAddRecExpr(Operands, getLoop());
      if (const SCEVAddRecExpr *ShiftedAddRec =
            dyn_cast<SCEVAddRecExpr>(Shifted))
        return ShiftedAddRec->getNumIterationsInRange(
                           Range.subtract(SC->getValue()->getValue()), SE);
      // This is strange and shouldn't happen.
      return SE.getCouldNotCompute();
    }

  // The only time we can solve this is when we have all constant indices.
  // Otherwise, we cannot determine the overflow conditions.
  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
    if (!isa<SCEVConstant>(getOperand(i)))
      return SE.getCouldNotCompute();

  // Okay, at this point we know that all elements of the chrec are constants
  // and that the start element is zero.

  // First check to see if the range contains zero.  If not, the first
  // iteration exits.
  unsigned BitWidth = SE.getTypeSizeInBits(getType());
  if (!Range.contains(APInt(BitWidth, 0)))
    return SE.getIntegerSCEV(0, getType());

  if (isAffine()) {
    // If this is an affine expression then we have this situation:
    //   Solve {0,+,A} in Range  ===  Ax in Range

    // We know that zero is in the range.  If A is positive then we know that
    // the upper value of the range must be the first possible exit value.
    // If A is negative then the lower of the range is the last possible loop
    // value.  Also note that we already checked for a full range.
    APInt One(BitWidth,1);
    APInt A = cast<SCEVConstant>(getOperand(1))->getValue()->getValue();
    APInt End = A.sge(One) ? (Range.getUpper() - One) : Range.getLower();

    // The exit value should be (End+A)/A.
    APInt ExitVal = (End + A).udiv(A);
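    //
    // For example, {0,+,3} in the range [0,10) gives A = 3 and End = 9, so
    // ExitVal = (9 + 3) /u 3 == 4: iteration 4 produces 12, the first value
    // outside the range, while iteration 3 still produces 9.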
    ConstantInt *ExitValue = ConstantInt::get(ExitVal);

    // Evaluate at the exit value.  If we really did fall out of the valid
    // range, then we computed our trip count, otherwise wrap around or other
    // things must have happened.
    ConstantInt *Val = EvaluateConstantChrecAtConstant(this, ExitValue, SE);
    if (Range.contains(Val->getValue()))
      return SE.getCouldNotCompute();  // Something strange happened

    // Ensure that the previous value is in the range.  This is a sanity check.
    assert(Range.contains(
           EvaluateConstantChrecAtConstant(this,
           ConstantInt::get(ExitVal - One), SE)->getValue()) &&
           "Linear scev computation is off in a bad way!");
    return SE.getConstant(ExitValue);
  } else if (isQuadratic()) {
    // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
    // quadratic equation to solve it.  To do this, we must frame our problem in
    // terms of figuring out when zero is crossed, instead of when
    // Range.getUpper() is crossed.
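    //
    // Since the start was already shifted to zero above, replacing it with
    // -Range.getUpper() yields a chrec that hits zero exactly when the
    // original one hits Range.getUpper().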
    SmallVector<const SCEV*, 4> NewOps(op_begin(), op_end());
    NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
    const SCEV* NewAddRec = SE.getAddRecExpr(NewOps, getLoop());

    // Next, solve the constructed addrec
    std::pair<const SCEV*,const SCEV*> Roots =
      SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE);
    const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first);
    const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second);
    if (R1) {
      // Pick the smallest positive root value.
      if (ConstantInt *CB =
          dyn_cast<ConstantInt>(ConstantExpr::getICmp(ICmpInst::ICMP_ULT,
                                   R1->getValue(), R2->getValue()))) {
        if (!CB->getZExtValue())
          std::swap(R1, R2);   // R1 is the minimum root now.

        // Make sure the root is not off by one.  The returned iteration should
        // not be in the range, but the previous one should be.  When solving
        // for "X*X < 5", for example, we should not return a root of 2.
        ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this,
                                                             R1->getValue(),
                                                             SE);
        if (Range.contains(R1Val->getValue())) {
          // The next iteration must be out of the range...
          ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()+1);

          R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
          if (!Range.contains(R1Val->getValue()))
            return SE.getConstant(NextVal);
          return SE.getCouldNotCompute();  // Something strange happened
        }

        // If R1 was not in the range, then it is a good return value.  Make
        // sure that R1-1 WAS in the range though, just in case.
        ConstantInt *NextVal = ConstantInt::get(R1->getValue()->getValue()-1);
        R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
        if (Range.contains(R1Val->getValue()))
          return R1;
        return SE.getCouldNotCompute();  // Something strange happened
      }
    }
  }

  return SE.getCouldNotCompute();
}


//===----------------------------------------------------------------------===//
//                   SCEVCallbackVH Class Implementation
//===----------------------------------------------------------------------===//

void ScalarEvolution::SCEVCallbackVH::deleted() {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");
  if (PHINode *PN = dyn_cast<PHINode>(getValPtr()))
    SE->ConstantEvolutionLoopExitValue.erase(PN);
  if (Instruction *I = dyn_cast<Instruction>(getValPtr()))
    SE->ValuesAtScopes.erase(I);
  SE->Scalars.erase(getValPtr());
  // this now dangles!
}

void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *) {
  assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!");

  // Forget all the expressions associated with users of the old value,
  // so that future queries will recompute the expressions using the new
  // value.
  SmallVector<User *, 16> Worklist;
  Value *Old = getValPtr();
  bool DeleteOld = false;
  for (Value::use_iterator UI = Old->use_begin(), UE = Old->use_end();
       UI != UE; ++UI)
    Worklist.push_back(*UI);
  while (!Worklist.empty()) {
    User *U = Worklist.pop_back_val();
    // Deleting the Old value will cause this to dangle. Postpone
    // that until everything else is done.
    if (U == Old) {
      DeleteOld = true;
      continue;
    }
    if (PHINode *PN = dyn_cast<PHINode>(U))
      SE->ConstantEvolutionLoopExitValue.erase(PN);
    if (Instruction *I = dyn_cast<Instruction>(U))
      SE->ValuesAtScopes.erase(I);
    if (SE->Scalars.erase(U))
      for (Value::use_iterator UI = U->use_begin(), UE = U->use_end();
           UI != UE; ++UI)
        Worklist.push_back(*UI);
  }
  if (DeleteOld) {
    if (PHINode *PN = dyn_cast<PHINode>(Old))
      SE->ConstantEvolutionLoopExitValue.erase(PN);
    if (Instruction *I = dyn_cast<Instruction>(Old))
      SE->ValuesAtScopes.erase(I);
    SE->Scalars.erase(Old);
    // this now dangles!
  }
  // this may dangle!
}

ScalarEvolution::SCEVCallbackVH::SCEVCallbackVH(Value *V, ScalarEvolution *se)
  : CallbackVH(V), SE(se) {}

//===----------------------------------------------------------------------===//
//                   ScalarEvolution Class Implementation
//===----------------------------------------------------------------------===//

ScalarEvolution::ScalarEvolution()
  : FunctionPass(&ID) {
}

bool ScalarEvolution::runOnFunction(Function &F) {
  this->F = &F;
  LI = &getAnalysis<LoopInfo>();
  TD = getAnalysisIfAvailable<TargetData>();
  return false;
}

void ScalarEvolution::releaseMemory() {
  Scalars.clear();
  BackedgeTakenCounts.clear();
  ConstantEvolutionLoopExitValue.clear();
  ValuesAtScopes.clear();
  UniqueSCEVs.clear();
  SCEVAllocator.Reset();
}

void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequiredTransitive<LoopInfo>();
}

bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) {
  return !isa<SCEVCouldNotCompute>(getBackedgeTakenCount(L));
}

static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,
                          const Loop *L) {
  // Print all inner loops first
  for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I)
    PrintLoopInfo(OS, SE, *I);

  OS << "Loop " << L->getHeader()->getName() << ": ";

  SmallVector<BasicBlock*, 8> ExitBlocks;
  L->getExitBlocks(ExitBlocks);
  if (ExitBlocks.size() != 1)
    OS << "<multiple exits> ";

  if (SE->hasLoopInvariantBackedgeTakenCount(L)) {
    OS << "backedge-taken count is " << *SE->getBackedgeTakenCount(L);
  } else {
    OS << "Unpredictable backedge-taken count. ";
  }

  OS << "\n";
  OS << "Loop " << L->getHeader()->getName() << ": ";

  if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {
    OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L);
  } else {
    OS << "Unpredictable max backedge-taken count. ";
  }

  OS << "\n";
}

void ScalarEvolution::print(raw_ostream &OS, const Module* ) const {
  // ScalarEvolution's implementation of the print method is to print
  // out SCEV values of all instructions that are interesting. Doing
  // this potentially causes it to create new SCEV objects though,
  // which technically conflicts with the const qualifier. This isn't
  // observable from outside the class though (the hasSCEV function
  // notwithstanding), so casting away the const isn't dangerous.
  ScalarEvolution &SE = *const_cast<ScalarEvolution*>(this);

  OS << "Classifying expressions for: " << F->getName() << "\n";
  for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
    if (isSCEVable(I->getType())) {
      OS << *I;
      OS << "  -->  ";
      const SCEV* SV = SE.getSCEV(&*I);
      SV->print(OS);

      const Loop *L = LI->getLoopFor((*I).getParent());

      const SCEV* AtUse = SE.getSCEVAtScope(SV, L);
      if (AtUse != SV) {
        OS << "  -->  ";
        AtUse->print(OS);
      }

      if (L) {
        OS << "\t\t" "Exits: ";
        const SCEV* ExitValue = SE.getSCEVAtScope(SV, L->getParentLoop());
        if (!ExitValue->isLoopInvariant(L)) {
          OS << "<<Unknown>>";
        } else {
          OS << *ExitValue;
        }
      }

      OS << "\n";
    }

  OS << "Determining loop execution counts for: " << F->getName() << "\n";
  for (LoopInfo::iterator I = LI->begin(), E = LI->end(); I != E; ++I)
    PrintLoopInfo(OS, &SE, *I);
}

void ScalarEvolution::print(std::ostream &o, const Module *M) const {
  raw_os_ostream OS(o);
  print(OS, M);
}
