1//===- HexagonLoopIdiomRecognition.cpp ------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "HexagonLoopIdiomRecognition.h"
10#include "llvm/ADT/APInt.h"
11#include "llvm/ADT/DenseMap.h"
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallPtrSet.h"
14#include "llvm/ADT/SmallSet.h"
15#include "llvm/ADT/SmallVector.h"
16#include "llvm/ADT/StringRef.h"
17#include "llvm/Analysis/AliasAnalysis.h"
18#include "llvm/Analysis/InstructionSimplify.h"
19#include "llvm/Analysis/LoopAnalysisManager.h"
20#include "llvm/Analysis/LoopInfo.h"
21#include "llvm/Analysis/LoopPass.h"
22#include "llvm/Analysis/MemoryLocation.h"
23#include "llvm/Analysis/ScalarEvolution.h"
24#include "llvm/Analysis/ScalarEvolutionExpressions.h"
25#include "llvm/Analysis/TargetLibraryInfo.h"
26#include "llvm/Analysis/ValueTracking.h"
27#include "llvm/IR/Attributes.h"
28#include "llvm/IR/BasicBlock.h"
29#include "llvm/IR/Constant.h"
30#include "llvm/IR/Constants.h"
31#include "llvm/IR/DataLayout.h"
32#include "llvm/IR/DebugLoc.h"
33#include "llvm/IR/DerivedTypes.h"
34#include "llvm/IR/Dominators.h"
35#include "llvm/IR/Function.h"
36#include "llvm/IR/IRBuilder.h"
37#include "llvm/IR/InstrTypes.h"
38#include "llvm/IR/Instruction.h"
39#include "llvm/IR/Instructions.h"
40#include "llvm/IR/IntrinsicInst.h"
41#include "llvm/IR/Intrinsics.h"
42#include "llvm/IR/IntrinsicsHexagon.h"
43#include "llvm/IR/Module.h"
44#include "llvm/IR/PassManager.h"
45#include "llvm/IR/PatternMatch.h"
46#include "llvm/IR/Type.h"
47#include "llvm/IR/User.h"
48#include "llvm/IR/Value.h"
49#include "llvm/InitializePasses.h"
50#include "llvm/Pass.h"
51#include "llvm/Support/Casting.h"
52#include "llvm/Support/CommandLine.h"
53#include "llvm/Support/Compiler.h"
54#include "llvm/Support/Debug.h"
55#include "llvm/Support/ErrorHandling.h"
56#include "llvm/Support/KnownBits.h"
57#include "llvm/Support/raw_ostream.h"
58#include "llvm/TargetParser/Triple.h"
59#include "llvm/Transforms/Scalar.h"
60#include "llvm/Transforms/Utils.h"
61#include "llvm/Transforms/Utils/Local.h"
62#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
63#include <algorithm>
64#include <array>
65#include <cassert>
66#include <cstdint>
67#include <cstdlib>
68#include <deque>
69#include <functional>
70#include <iterator>
71#include <map>
72#include <set>
73#include <utility>
74#include <vector>
75
76#define DEBUG_TYPE "hexagon-lir"
77
78using namespace llvm;
79
80static cl::opt<bool> DisableMemcpyIdiom("disable-memcpy-idiom",
81  cl::Hidden, cl::init(false),
82  cl::desc("Disable generation of memcpy in loop idiom recognition"));
83
84static cl::opt<bool> DisableMemmoveIdiom("disable-memmove-idiom",
85  cl::Hidden, cl::init(false),
86  cl::desc("Disable generation of memmove in loop idiom recognition"));
87
88static cl::opt<unsigned> RuntimeMemSizeThreshold("runtime-mem-idiom-threshold",
89  cl::Hidden, cl::init(0), cl::desc("Threshold (in bytes) for the runtime "
90  "check guarding the memmove."));
91
92static cl::opt<unsigned> CompileTimeMemSizeThreshold(
93  "compile-time-mem-idiom-threshold", cl::Hidden, cl::init(64),
94  cl::desc("Threshold (in bytes) to perform the transformation, if the "
95    "runtime loop count (mem transfer size) is known at compile-time."));
96
97static cl::opt<bool> OnlyNonNestedMemmove("only-nonnested-memmove-idiom",
98  cl::Hidden, cl::init(true),
99  cl::desc("Only enable generating memmove in non-nested loops"));
100
101static cl::opt<bool> HexagonVolatileMemcpy(
102    "disable-hexagon-volatile-memcpy", cl::Hidden, cl::init(false),
103    cl::desc("Enable Hexagon-specific memcpy for volatile destination."));
104
105static cl::opt<unsigned> SimplifyLimit("hlir-simplify-limit", cl::init(10000),
106  cl::Hidden, cl::desc("Maximum number of simplification steps in HLIR"));
107
// Symbol name associated with the Hexagon-specific volatile memcpy
// (presumably the runtime routine that gets called; its use is outside this
// part of the file). Both the characters and the pointer itself are const, so
// the name cannot be redirected by accident.
static const char *const HexagonVolatileMemcpyName =
    "hexagon_memcpy_forward_vp4cp4n2";
110
111
namespace llvm {

// Declarations related to the legacy pass defined in this file: the
// initializer that the pass constructor invokes with the global PassRegistry,
// and a function returning a Pass* (presumably the factory for the pass; its
// definition is not in this part of the file).
void initializeHexagonLoopIdiomRecognizeLegacyPassPass(PassRegistry &);
Pass *createHexagonLoopIdiomPass();

} // end namespace llvm
118
119namespace {
120
121class HexagonLoopIdiomRecognize {
122public:
123  explicit HexagonLoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT,
124                                     LoopInfo *LF, const TargetLibraryInfo *TLI,
125                                     ScalarEvolution *SE)
126      : AA(AA), DT(DT), LF(LF), TLI(TLI), SE(SE) {}
127
128  bool run(Loop *L);
129
130private:
131  int getSCEVStride(const SCEVAddRecExpr *StoreEv);
132  bool isLegalStore(Loop *CurLoop, StoreInst *SI);
133  void collectStores(Loop *CurLoop, BasicBlock *BB,
134                     SmallVectorImpl<StoreInst *> &Stores);
135  bool processCopyingStore(Loop *CurLoop, StoreInst *SI, const SCEV *BECount);
136  bool coverLoop(Loop *L, SmallVectorImpl<Instruction *> &Insts) const;
137  bool runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, const SCEV *BECount,
138                      SmallVectorImpl<BasicBlock *> &ExitBlocks);
139  bool runOnCountableLoop(Loop *L);
140
141  AliasAnalysis *AA;
142  const DataLayout *DL;
143  DominatorTree *DT;
144  LoopInfo *LF;
145  const TargetLibraryInfo *TLI;
146  ScalarEvolution *SE;
147  bool HasMemcpy, HasMemmove;
148};
149
150class HexagonLoopIdiomRecognizeLegacyPass : public LoopPass {
151public:
152  static char ID;
153
154  explicit HexagonLoopIdiomRecognizeLegacyPass() : LoopPass(ID) {
155    initializeHexagonLoopIdiomRecognizeLegacyPassPass(
156        *PassRegistry::getPassRegistry());
157  }
158
159  StringRef getPassName() const override {
160    return "Recognize Hexagon-specific loop idioms";
161  }
162
163  void getAnalysisUsage(AnalysisUsage &AU) const override {
164    AU.addRequired<LoopInfoWrapperPass>();
165    AU.addRequiredID(LoopSimplifyID);
166    AU.addRequiredID(LCSSAID);
167    AU.addRequired<AAResultsWrapperPass>();
168    AU.addRequired<ScalarEvolutionWrapperPass>();
169    AU.addRequired<DominatorTreeWrapperPass>();
170    AU.addRequired<TargetLibraryInfoWrapperPass>();
171    AU.addPreserved<TargetLibraryInfoWrapperPass>();
172  }
173
174  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
175};
176
177struct Simplifier {
178  struct Rule {
179    using FuncType = std::function<Value *(Instruction *, LLVMContext &)>;
180    Rule(StringRef N, FuncType F) : Name(N), Fn(F) {}
181    StringRef Name; // For debugging.
182    FuncType Fn;
183  };
184
185  void addRule(StringRef N, const Rule::FuncType &F) {
186    Rules.push_back(Rule(N, F));
187  }
188
189private:
190  struct WorkListType {
191    WorkListType() = default;
192
193    void push_back(Value *V) {
194      // Do not push back duplicates.
195      if (S.insert(V).second)
196        Q.push_back(V);
197    }
198
199    Value *pop_front_val() {
200      Value *V = Q.front();
201      Q.pop_front();
202      S.erase(V);
203      return V;
204    }
205
206    bool empty() const { return Q.empty(); }
207
208  private:
209    std::deque<Value *> Q;
210    std::set<Value *> S;
211  };
212
213  using ValueSetType = std::set<Value *>;
214
215  std::vector<Rule> Rules;
216
217public:
218  struct Context {
219    using ValueMapType = DenseMap<Value *, Value *>;
220
221    Value *Root;
222    ValueSetType Used;   // The set of all cloned values used by Root.
223    ValueSetType Clones; // The set of all cloned values.
224    LLVMContext &Ctx;
225
226    Context(Instruction *Exp)
227        : Ctx(Exp->getParent()->getParent()->getContext()) {
228      initialize(Exp);
229    }
230
231    ~Context() { cleanup(); }
232
233    void print(raw_ostream &OS, const Value *V) const;
234    Value *materialize(BasicBlock *B, BasicBlock::iterator At);
235
236  private:
237    friend struct Simplifier;
238
239    void initialize(Instruction *Exp);
240    void cleanup();
241
242    template <typename FuncT> void traverse(Value *V, FuncT F);
243    void record(Value *V);
244    void use(Value *V);
245    void unuse(Value *V);
246
247    bool equal(const Instruction *I, const Instruction *J) const;
248    Value *find(Value *Tree, Value *Sub) const;
249    Value *subst(Value *Tree, Value *OldV, Value *NewV);
250    void replace(Value *OldV, Value *NewV);
251    void link(Instruction *I, BasicBlock *B, BasicBlock::iterator At);
252  };
253
254  Value *simplify(Context &C);
255};
256
257  struct PE {
258    PE(const Simplifier::Context &c, Value *v = nullptr) : C(c), V(v) {}
259
260    const Simplifier::Context &C;
261    const Value *V;
262  };
263
264  LLVM_ATTRIBUTE_USED
265  raw_ostream &operator<<(raw_ostream &OS, const PE &P) {
266    P.C.print(OS, P.V ? P.V : P.C.Root);
267    return OS;
268  }
269
270} // end anonymous namespace
271
// Pass identification object; it is handed to the LoopPass base-class
// constructor by HexagonLoopIdiomRecognizeLegacyPass.
char HexagonLoopIdiomRecognizeLegacyPass::ID = 0;

// Registration of the legacy pass under the name "hexagon-loop-idiom",
// together with the analyses and passes listed as its dependencies.
INITIALIZE_PASS_BEGIN(HexagonLoopIdiomRecognizeLegacyPass, "hexagon-loop-idiom",
                      "Recognize Hexagon-specific loop idioms", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(HexagonLoopIdiomRecognizeLegacyPass, "hexagon-loop-idiom",
                    "Recognize Hexagon-specific loop idioms", false, false)
285
286template <typename FuncT>
287void Simplifier::Context::traverse(Value *V, FuncT F) {
288  WorkListType Q;
289  Q.push_back(V);
290
291  while (!Q.empty()) {
292    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
293    if (!U || U->getParent())
294      continue;
295    if (!F(U))
296      continue;
297    for (Value *Op : U->operands())
298      Q.push_back(Op);
299  }
300}
301
302void Simplifier::Context::print(raw_ostream &OS, const Value *V) const {
303  const auto *U = dyn_cast<const Instruction>(V);
304  if (!U) {
305    OS << V << '(' << *V << ')';
306    return;
307  }
308
309  if (U->getParent()) {
310    OS << U << '(';
311    U->printAsOperand(OS, true);
312    OS << ')';
313    return;
314  }
315
316  unsigned N = U->getNumOperands();
317  if (N != 0)
318    OS << U << '(';
319  OS << U->getOpcodeName();
320  for (const Value *Op : U->operands()) {
321    OS << ' ';
322    print(OS, Op);
323  }
324  if (N != 0)
325    OS << ')';
326}
327
328void Simplifier::Context::initialize(Instruction *Exp) {
329  // Perform a deep clone of the expression, set Root to the root
330  // of the clone, and build a map from the cloned values to the
331  // original ones.
332  ValueMapType M;
333  BasicBlock *Block = Exp->getParent();
334  WorkListType Q;
335  Q.push_back(Exp);
336
337  while (!Q.empty()) {
338    Value *V = Q.pop_front_val();
339    if (M.contains(V))
340      continue;
341    if (Instruction *U = dyn_cast<Instruction>(V)) {
342      if (isa<PHINode>(U) || U->getParent() != Block)
343        continue;
344      for (Value *Op : U->operands())
345        Q.push_back(Op);
346      M.insert({U, U->clone()});
347    }
348  }
349
350  for (std::pair<Value*,Value*> P : M) {
351    Instruction *U = cast<Instruction>(P.second);
352    for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) {
353      auto F = M.find(U->getOperand(i));
354      if (F != M.end())
355        U->setOperand(i, F->second);
356    }
357  }
358
359  auto R = M.find(Exp);
360  assert(R != M.end());
361  Root = R->second;
362
363  record(Root);
364  use(Root);
365}
366
367void Simplifier::Context::record(Value *V) {
368  auto Record = [this](Instruction *U) -> bool {
369    Clones.insert(U);
370    return true;
371  };
372  traverse(V, Record);
373}
374
375void Simplifier::Context::use(Value *V) {
376  auto Use = [this](Instruction *U) -> bool {
377    Used.insert(U);
378    return true;
379  };
380  traverse(V, Use);
381}
382
383void Simplifier::Context::unuse(Value *V) {
384  if (!isa<Instruction>(V) || cast<Instruction>(V)->getParent() != nullptr)
385    return;
386
387  auto Unuse = [this](Instruction *U) -> bool {
388    if (!U->use_empty())
389      return false;
390    Used.erase(U);
391    return true;
392  };
393  traverse(V, Unuse);
394}
395
396Value *Simplifier::Context::subst(Value *Tree, Value *OldV, Value *NewV) {
397  if (Tree == OldV)
398    return NewV;
399  if (OldV == NewV)
400    return Tree;
401
402  WorkListType Q;
403  Q.push_back(Tree);
404  while (!Q.empty()) {
405    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
406    // If U is not an instruction, or it's not a clone, skip it.
407    if (!U || U->getParent())
408      continue;
409    for (unsigned i = 0, n = U->getNumOperands(); i != n; ++i) {
410      Value *Op = U->getOperand(i);
411      if (Op == OldV) {
412        U->setOperand(i, NewV);
413        unuse(OldV);
414      } else {
415        Q.push_back(Op);
416      }
417    }
418  }
419  return Tree;
420}
421
422void Simplifier::Context::replace(Value *OldV, Value *NewV) {
423  if (Root == OldV) {
424    Root = NewV;
425    use(Root);
426    return;
427  }
428
429  // NewV may be a complex tree that has just been created by one of the
430  // transformation rules. We need to make sure that it is commoned with
431  // the existing Root to the maximum extent possible.
432  // Identify all subtrees of NewV (including NewV itself) that have
433  // equivalent counterparts in Root, and replace those subtrees with
434  // these counterparts.
435  WorkListType Q;
436  Q.push_back(NewV);
437  while (!Q.empty()) {
438    Value *V = Q.pop_front_val();
439    Instruction *U = dyn_cast<Instruction>(V);
440    if (!U || U->getParent())
441      continue;
442    if (Value *DupV = find(Root, V)) {
443      if (DupV != V)
444        NewV = subst(NewV, V, DupV);
445    } else {
446      for (Value *Op : U->operands())
447        Q.push_back(Op);
448    }
449  }
450
451  // Now, simply replace OldV with NewV in Root.
452  Root = subst(Root, OldV, NewV);
453  use(Root);
454}
455
456void Simplifier::Context::cleanup() {
457  for (Value *V : Clones) {
458    Instruction *U = cast<Instruction>(V);
459    if (!U->getParent())
460      U->dropAllReferences();
461  }
462
463  for (Value *V : Clones) {
464    Instruction *U = cast<Instruction>(V);
465    if (!U->getParent())
466      U->deleteValue();
467  }
468}
469
470bool Simplifier::Context::equal(const Instruction *I,
471                                const Instruction *J) const {
472  if (I == J)
473    return true;
474  if (!I->isSameOperationAs(J))
475    return false;
476  if (isa<PHINode>(I))
477    return I->isIdenticalTo(J);
478
479  for (unsigned i = 0, n = I->getNumOperands(); i != n; ++i) {
480    Value *OpI = I->getOperand(i), *OpJ = J->getOperand(i);
481    if (OpI == OpJ)
482      continue;
483    auto *InI = dyn_cast<const Instruction>(OpI);
484    auto *InJ = dyn_cast<const Instruction>(OpJ);
485    if (InI && InJ) {
486      if (!equal(InI, InJ))
487        return false;
488    } else if (InI != InJ || !InI)
489      return false;
490  }
491  return true;
492}
493
494Value *Simplifier::Context::find(Value *Tree, Value *Sub) const {
495  Instruction *SubI = dyn_cast<Instruction>(Sub);
496  WorkListType Q;
497  Q.push_back(Tree);
498
499  while (!Q.empty()) {
500    Value *V = Q.pop_front_val();
501    if (V == Sub)
502      return V;
503    Instruction *U = dyn_cast<Instruction>(V);
504    if (!U || U->getParent())
505      continue;
506    if (SubI && equal(SubI, U))
507      return U;
508    assert(!isa<PHINode>(U));
509    for (Value *Op : U->operands())
510      Q.push_back(Op);
511  }
512  return nullptr;
513}
514
515void Simplifier::Context::link(Instruction *I, BasicBlock *B,
516      BasicBlock::iterator At) {
517  if (I->getParent())
518    return;
519
520  for (Value *Op : I->operands()) {
521    if (Instruction *OpI = dyn_cast<Instruction>(Op))
522      link(OpI, B, At);
523  }
524
525  I->insertInto(B, At);
526}
527
528Value *Simplifier::Context::materialize(BasicBlock *B,
529      BasicBlock::iterator At) {
530  if (Instruction *RootI = dyn_cast<Instruction>(Root))
531    link(RootI, B, At);
532  return Root;
533}
534
535Value *Simplifier::simplify(Context &C) {
536  WorkListType Q;
537  Q.push_back(C.Root);
538  unsigned Count = 0;
539  const unsigned Limit = SimplifyLimit;
540
541  while (!Q.empty()) {
542    if (Count++ >= Limit)
543      break;
544    Instruction *U = dyn_cast<Instruction>(Q.pop_front_val());
545    if (!U || U->getParent() || !C.Used.count(U))
546      continue;
547    bool Changed = false;
548    for (Rule &R : Rules) {
549      Value *W = R.Fn(U, C.Ctx);
550      if (!W)
551        continue;
552      Changed = true;
553      C.record(W);
554      C.replace(U, W);
555      Q.push_back(C.Root);
556      break;
557    }
558    if (!Changed) {
559      for (Value *Op : U->operands())
560        Q.push_back(Op);
561    }
562  }
563  return Count < Limit ? C.Root : nullptr;
564}
565
566//===----------------------------------------------------------------------===//
567//
568//          Implementation of PolynomialMultiplyRecognize
569//
570//===----------------------------------------------------------------------===//
571
namespace {

  /// Recognizer for loops computing a polynomial multiply (see the idiom
  /// descriptions in scanSelect). It keeps references to the loop and the
  /// analyses passed to the constructor; recognize() is the entry point.
  class PolynomialMultiplyRecognize {
  public:
    explicit PolynomialMultiplyRecognize(Loop *loop, const DataLayout &dl,
        const DominatorTree &dt, const TargetLibraryInfo &tli,
        ScalarEvolution &se)
      : CurLoop(loop), DL(dl), DT(dt), TLI(tli), SE(se) {}

    bool recognize();

  private:
    using ValueSeq = SetVector<Value *>;

    /// Returns the 32-bit integer type in the LLVMContext of the function
    /// that contains the loop header.
    IntegerType *getPmpyType() const {
      LLVMContext &Ctx = CurLoop->getHeader()->getParent()->getContext();
      return IntegerType::get(Ctx, 32);
    }

    bool isPromotableTo(Value *V, IntegerType *Ty);
    void promoteTo(Instruction *In, IntegerType *DestTy, BasicBlock *LoopB);
    bool promoteTypes(BasicBlock *LoopB, BasicBlock *ExitB);

    Value *getCountIV(BasicBlock *BB);
    bool findCycle(Value *Out, Value *In, ValueSeq &Cycle);
    void classifyCycle(Instruction *DivI, ValueSeq &Cycle, ValueSeq &Early,
          ValueSeq &Late);
    bool classifyInst(Instruction *UseI, ValueSeq &Early, ValueSeq &Late);
    bool commutesWithShift(Instruction *I);
    bool highBitsAreZero(Value *V, unsigned IterCount);
    bool keepsHighBitsZero(Value *V, unsigned IterCount);
    bool isOperandShifted(Instruction *I, Value *Op);
    bool convertShiftsToLeft(BasicBlock *LoopB, BasicBlock *ExitB,
          unsigned IterCount);
    void cleanupLoopBody(BasicBlock *LoopB);

    /// Values extracted while matching a select against the idiom patterns.
    /// The letters follow the pattern comments in matchLeftShift,
    /// matchRightShift and scanSelect.
    struct ParsedValues {
      ParsedValues() = default;

      // Loop-invariant xor operand recorded by scanSelect for the inverse
      // pattern when X is "R ^ invariant".
      Value *M = nullptr;
      // The input polynomial, as determined by scanSelect.
      Value *P = nullptr;
      // The Q, R and X values of the matched pattern.
      Value *Q = nullptr;
      Value *R = nullptr;
      Value *X = nullptr;
      // The select that scanSelect found feeding back into the R PHI.
      Instruction *Res = nullptr;
      // Iteration count (TODO confirm: it is not assigned by the matching
      // routines shown in this part of the file).
      unsigned IterCount = 0;
      // True for a left-shift match, false for a right-shift match.
      bool Left = false;
      // True when scanSelect classified the match as the inverse pattern.
      bool Inv = false;
    };

    bool matchLeftShift(SelectInst *SelI, Value *CIV, ParsedValues &PV);
    bool matchRightShift(SelectInst *SelI, ParsedValues &PV);
    bool scanSelect(SelectInst *SI, BasicBlock *LoopB, BasicBlock *PrehB,
          Value *CIV, ParsedValues &PV, bool PreScan);
    unsigned getInverseMxN(unsigned QP);
    Value *generate(BasicBlock::iterator At, ParsedValues &PV);

    void setupPreSimplifier(Simplifier &S);
    void setupPostSimplifier(Simplifier &S);

    Loop *CurLoop;
    const DataLayout &DL;
    const DominatorTree &DT;
    const TargetLibraryInfo &TLI;
    ScalarEvolution &SE;
  };

} // end anonymous namespace
640
641Value *PolynomialMultiplyRecognize::getCountIV(BasicBlock *BB) {
642  pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
643  if (std::distance(PI, PE) != 2)
644    return nullptr;
645  BasicBlock *PB = (*PI == BB) ? *std::next(PI) : *PI;
646
647  for (auto I = BB->begin(), E = BB->end(); I != E && isa<PHINode>(I); ++I) {
648    auto *PN = cast<PHINode>(I);
649    Value *InitV = PN->getIncomingValueForBlock(PB);
650    if (!isa<ConstantInt>(InitV) || !cast<ConstantInt>(InitV)->isZero())
651      continue;
652    Value *IterV = PN->getIncomingValueForBlock(BB);
653    auto *BO = dyn_cast<BinaryOperator>(IterV);
654    if (!BO)
655      continue;
656    if (BO->getOpcode() != Instruction::Add)
657      continue;
658    Value *IncV = nullptr;
659    if (BO->getOperand(0) == PN)
660      IncV = BO->getOperand(1);
661    else if (BO->getOperand(1) == PN)
662      IncV = BO->getOperand(0);
663    if (IncV == nullptr)
664      continue;
665
666    if (auto *T = dyn_cast<ConstantInt>(IncV))
667      if (T->isOne())
668        return PN;
669  }
670  return nullptr;
671}
672
673static void replaceAllUsesOfWithIn(Value *I, Value *J, BasicBlock *BB) {
674  for (auto UI = I->user_begin(), UE = I->user_end(); UI != UE;) {
675    Use &TheUse = UI.getUse();
676    ++UI;
677    if (auto *II = dyn_cast<Instruction>(TheUse.getUser()))
678      if (BB == II->getParent())
679        II->replaceUsesOfWith(I, J);
680  }
681}
682
/// Try to match SelI against the left-shifting form of the idiom, with CIV
/// being the value that must appear as the shift amount "i" (going by the
/// name, the counting induction variable supplied by the caller).
///
/// On success, stores the matched X, Q and R in PV, sets PV.Left to true and
/// returns true. On failure returns false; PV is only written on success.
bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst *SelI,
      Value *CIV, ParsedValues &PV) {
  // Match the following:
  //   select (X & (1 << i)) != 0 ? R ^ (Q << i) : R
  //   select (X & (1 << i)) == 0 ? R : R ^ (Q << i)
  // The condition may also check for equality with the masked value, i.e.
  //   select (X & (1 << i)) == (1 << i) ? R ^ (Q << i) : R
  //   select (X & (1 << i)) != (1 << i) ? R : R ^ (Q << i);

  Value *CondV = SelI->getCondition();
  Value *TrueV = SelI->getTrueValue();
  Value *FalseV = SelI->getFalseValue();

  using namespace PatternMatch;

  CmpInst::Predicate P;
  Value *A = nullptr, *B = nullptr, *C = nullptr;

  // The "and" may be on either side of the comparison.
  if (!match(CondV, m_ICmp(P, m_And(m_Value(A), m_Value(B)), m_Value(C))) &&
      !match(CondV, m_ICmp(P, m_Value(C), m_And(m_Value(A), m_Value(B)))))
    return false;
  if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
    return false;
  // Matched: select (A & B) == C ? ... : ...
  //          select (A & B) != C ? ... : ...

  Value *X = nullptr, *Sh1 = nullptr;
  // Check (A & B) for (X & (1 << i)):
  if (match(A, m_Shl(m_One(), m_Specific(CIV)))) {
    Sh1 = A;
    X = B;
  } else if (match(B, m_Shl(m_One(), m_Specific(CIV)))) {
    Sh1 = B;
    X = A;
  } else {
    // TODO: Could also check for an induction variable containing single
    // bit shifted left by 1 in each iteration.
    return false;
  }

  // Whether the select's condition holds when the tested bit of X is zero.
  bool TrueIfZero;

  // Check C against the possible values for comparison: 0 and (1 << i):
  if (match(C, m_Zero()))
    TrueIfZero = (P == CmpInst::ICMP_EQ);
  else if (C == Sh1)
    TrueIfZero = (P == CmpInst::ICMP_NE);
  else
    return false;

  // So far, matched:
  //   select (X & (1 << i)) ? ... : ...
  // including variations of the check against zero/non-zero value.

  // ShouldSameV is the arm selected when the bit is clear (expected to be the
  // unchanged R), ShouldXoredV the arm selected when the bit is set.
  Value *ShouldSameV = nullptr, *ShouldXoredV = nullptr;
  if (TrueIfZero) {
    ShouldSameV = TrueV;
    ShouldXoredV = FalseV;
  } else {
    ShouldSameV = FalseV;
    ShouldXoredV = TrueV;
  }

  Value *Q = nullptr, *R = nullptr, *Y = nullptr, *Z = nullptr;
  Value *T = nullptr;
  if (match(ShouldXoredV, m_Xor(m_Value(Y), m_Value(Z)))) {
    // Matched: select +++ ? ... : Y ^ Z
    //          select +++ ? Y ^ Z : ...
    // where +++ denotes previously checked matches.
    if (ShouldSameV == Y)
      T = Z;
    else if (ShouldSameV == Z)
      T = Y;
    else
      return false;
    R = ShouldSameV;
    // Matched: select +++ ? R : R ^ T
    //          select +++ ? R ^ T : R
    // depending on TrueIfZero.

  } else if (match(ShouldSameV, m_Zero())) {
    // Matched: select +++ ? 0 : ...
    //          select +++ ? ... : 0
    // The xor with R is then expected in the single user of the select.
    if (!SelI->hasOneUse())
      return false;
    T = ShouldXoredV;
    // Matched: select +++ ? 0 : T
    //          select +++ ? T : 0

    Value *U = *SelI->user_begin();
    if (!match(U, m_Xor(m_Specific(SelI), m_Value(R))) &&
        !match(U, m_Xor(m_Value(R), m_Specific(SelI))))
      return false;
    // Matched: xor (select +++ ? 0 : T), R
    //          xor (select +++ ? T : 0), R
  } else
    return false;

  // The xor input value T is isolated into its own match so that it could
  // be checked against an induction variable containing a shifted bit
  // (todo).
  // For now, check against (Q << i), also accepting the form where both Q
  // and i are zero-extended.
  if (!match(T, m_Shl(m_Value(Q), m_Specific(CIV))) &&
      !match(T, m_Shl(m_ZExt(m_Value(Q)), m_ZExt(m_Specific(CIV)))))
    return false;
  // Matched: select +++ ? R : R ^ (Q << i)
  //          select +++ ? R ^ (Q << i) : R

  PV.X = X;
  PV.Q = Q;
  PV.R = R;
  PV.Left = true;
  return true;
}
797
/// Try to match SelI against the right-shifting form of the idiom.
///
/// On success, stores the matched X, Q and R in PV, sets PV.Left to false and
/// returns true. On failure returns false; PV is only written on success.
/// No other field of PV is modified here.
bool PolynomialMultiplyRecognize::matchRightShift(SelectInst *SelI,
      ParsedValues &PV) {
  // Match the following:
  //   select (X & 1) != 0 ? (R >> 1) ^ Q : (R >> 1)
  //   select (X & 1) == 0 ? (R >> 1) : (R >> 1) ^ Q
  // The condition may also check for equality with the masked value, i.e.
  //   select (X & 1) == 1 ? (R >> 1) ^ Q : (R >> 1)
  //   select (X & 1) != 1 ? (R >> 1) : (R >> 1) ^ Q

  Value *CondV = SelI->getCondition();
  Value *TrueV = SelI->getTrueValue();
  Value *FalseV = SelI->getFalseValue();

  using namespace PatternMatch;

  Value *C = nullptr;
  CmpInst::Predicate P;
  // Whether the select's condition holds when the tested bit of X is zero.
  bool TrueIfZero;

  // The constant may be on either side of the comparison.
  if (match(CondV, m_ICmp(P, m_Value(C), m_Zero())) ||
      match(CondV, m_ICmp(P, m_Zero(), m_Value(C)))) {
    if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
      return false;
    // Matched: select C == 0 ? ... : ...
    //          select C != 0 ? ... : ...
    TrueIfZero = (P == CmpInst::ICMP_EQ);
  } else if (match(CondV, m_ICmp(P, m_Value(C), m_One())) ||
             match(CondV, m_ICmp(P, m_One(), m_Value(C)))) {
    if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
      return false;
    // Matched: select C == 1 ? ... : ...
    //          select C != 1 ? ... : ...
    TrueIfZero = (P == CmpInst::ICMP_NE);
  } else
    return false;

  Value *X = nullptr;
  if (!match(C, m_And(m_Value(X), m_One())) &&
      !match(C, m_And(m_One(), m_Value(X))))
    return false;
  // Matched: select (X & 1) == +++ ? ... : ...
  //          select (X & 1) != +++ ? ... : ...

  Value *R = nullptr, *Q = nullptr;
  if (TrueIfZero) {
    // The select's condition is true if the tested bit is 0.
    // TrueV must be the shift, FalseV must be the xor.
    if (!match(TrueV, m_LShr(m_Value(R), m_One())))
      return false;
    // Matched: select +++ ? (R >> 1) : ...
    // The xor must use the very same shift value (TrueV), not just an
    // equivalent one.
    if (!match(FalseV, m_Xor(m_Specific(TrueV), m_Value(Q))) &&
        !match(FalseV, m_Xor(m_Value(Q), m_Specific(TrueV))))
      return false;
    // Matched: select +++ ? (R >> 1) : (R >> 1) ^ Q
    // with commuting ^.
  } else {
    // The select's condition is true if the tested bit is 1.
    // TrueV must be the xor, FalseV must be the shift.
    if (!match(FalseV, m_LShr(m_Value(R), m_One())))
      return false;
    // Matched: select +++ ? ... : (R >> 1)
    // The xor must use the very same shift value (FalseV), not just an
    // equivalent one.
    if (!match(TrueV, m_Xor(m_Specific(FalseV), m_Value(Q))) &&
        !match(TrueV, m_Xor(m_Value(Q), m_Specific(FalseV))))
      return false;
    // Matched: select +++ ? (R >> 1) ^ Q : (R >> 1)
    // with commuting ^.
  }

  PV.X = X;
  PV.Q = Q;
  PV.R = R;
  PV.Left = false;
  return true;
}
872
/// Check whether SelI is the select of one of the polynomial-multiply idioms
/// described below, and fill in PV accordingly.
///
/// \param SelI    The select to examine.
/// \param LoopB   The loop block: the R PHI must receive SelI from it, and
///                it is used to tell loop-variant xor operands from
///                invariant ones.
/// \param PrehB   The block from which the R PHI receives its entry value.
/// \param CIV     The value expected as the shift amount in the left-shift
///                form (passed on to matchLeftShift).
/// \param PV      Receives the matched values.
/// \param PreScan When true, a successful pattern match alone is sufficient
///                and no further checks are made.
/// \returns true if the select was accepted. A right-shift match is only
///          accepted when PreScan is true.
bool PolynomialMultiplyRecognize::scanSelect(SelectInst *SelI,
      BasicBlock *LoopB, BasicBlock *PrehB, Value *CIV, ParsedValues &PV,
      bool PreScan) {
  using namespace PatternMatch;

  // The basic pattern for R = P.Q is:
  // for i = 0..31
  //   R = phi (0, R')
  //   if (P & (1 << i))        ; test-bit(P, i)
  //     R' = R ^ (Q << i)
  //
  // Similarly, the basic pattern for R = (P/Q).Q - P
  // for i = 0..31
  //   R = phi(P, R')
  //   if (R & (1 << i))
  //     R' = R ^ (Q << i)

  // There exist idioms, where instead of Q being shifted left, P is shifted
  // right. This produces a result that is shifted right by 32 bits (the
  // non-shifted result is 64-bit).
  //
  // For R = P.Q, this would be:
  // for i = 0..31
  //   R = phi (0, R')
  //   if ((P >> i) & 1)
  //     R' = (R >> 1) ^ Q      ; R is cycled through the loop, so it must
  //   else                     ; be shifted by 1, not i.
  //     R' = R >> 1
  //
  // And for the inverse:
  // for i = 0..31
  //   R = phi (P, R')
  //   if (R & 1)
  //     R' = (R >> 1) ^ Q
  //   else
  //     R' = R >> 1

  // The left-shifting idioms share the same pattern:
  //   select (X & (1 << i)) ? R ^ (Q << i) : R
  // Similarly for right-shifting idioms:
  //   select (X & 1) ? (R >> 1) ^ Q : (R >> 1)

  if (matchLeftShift(SelI, CIV, PV)) {
    // If this is a pre-scan, getting this far is sufficient.
    if (PreScan)
      return true;

    // Need to make sure that the SelI goes back into R.
    auto *RPhi = dyn_cast<PHINode>(PV.R);
    if (!RPhi)
      return false;
    if (SelI != RPhi->getIncomingValueForBlock(LoopB))
      return false;
    PV.Res = SelI;

    // If X is loop invariant, it must be the input polynomial, and the
    // idiom is the basic polynomial multiply.
    if (CurLoop->isLoopInvariant(PV.X)) {
      PV.P = PV.X;
      PV.Inv = false;
    } else {
      // X is not loop invariant. If X == R, this is the inverse pmpy.
      // Otherwise, check for an xor with an invariant value. If the
      // variable argument to the xor is R, then this is still a valid
      // inverse pmpy.
      PV.Inv = true;
      if (PV.X != PV.R) {
        Value *Var = nullptr, *Inv = nullptr, *X1 = nullptr, *X2 = nullptr;
        if (!match(PV.X, m_Xor(m_Value(X1), m_Value(X2))))
          return false;
        // An operand counts as invariant here when it is not an instruction
        // located in LoopB.
        auto *I1 = dyn_cast<Instruction>(X1);
        auto *I2 = dyn_cast<Instruction>(X2);
        if (!I1 || I1->getParent() != LoopB) {
          Var = X2;
          Inv = X1;
        } else if (!I2 || I2->getParent() != LoopB) {
          Var = X1;
          Inv = X2;
        } else
          return false;
        if (Var != PV.R)
          return false;
        PV.M = Inv;
      }
      // The input polynomial P still needs to be determined. It will be
      // the entry value of R.
      Value *EntryP = RPhi->getIncomingValueForBlock(PrehB);
      PV.P = EntryP;
    }

    return true;
  }

  if (matchRightShift(SelI, PV)) {
    // If this is an inverse pattern, the Q polynomial must be known at
    // compile time.
    // NOTE(review): matchRightShift does not assign PV.Inv, so this test
    // sees whatever value the caller left in PV -- confirm that is intended.
    if (PV.Inv && !isa<ConstantInt>(PV.Q))
      return false;
    if (PreScan)
      return true;
    // There is no exact matching of right-shift pmpy.
    return false;
  }

  return false;
}
979
980bool PolynomialMultiplyRecognize::isPromotableTo(Value *Val,
981      IntegerType *DestTy) {
982  IntegerType *T = dyn_cast<IntegerType>(Val->getType());
983  if (!T || T->getBitWidth() > DestTy->getBitWidth())
984    return false;
985  if (T->getBitWidth() == DestTy->getBitWidth())
986    return true;
987  // Non-instructions are promotable. The reason why an instruction may not
988  // be promotable is that it may produce a different result if its operands
989  // and the result are promoted, for example, it may produce more non-zero
990  // bits. While it would still be possible to represent the proper result
991  // in a wider type, it may require adding additional instructions (which
992  // we don't want to do).
993  Instruction *In = dyn_cast<Instruction>(Val);
994  if (!In)
995    return true;
996  // The bitwidth of the source type is smaller than the destination.
997  // Check if the individual operation can be promoted.
998  switch (In->getOpcode()) {
999    case Instruction::PHI:
1000    case Instruction::ZExt:
1001    case Instruction::And:
1002    case Instruction::Or:
1003    case Instruction::Xor:
1004    case Instruction::LShr: // Shift right is ok.
1005    case Instruction::Select:
1006    case Instruction::Trunc:
1007      return true;
1008    case Instruction::ICmp:
1009      if (CmpInst *CI = cast<CmpInst>(In))
1010        return CI->isEquality() || CI->isUnsigned();
1011      llvm_unreachable("Cast failed unexpectedly");
1012    case Instruction::Add:
1013      return In->hasNoSignedWrap() && In->hasNoUnsignedWrap();
1014  }
1015  return false;
1016}
1017
1018void PolynomialMultiplyRecognize::promoteTo(Instruction *In,
1019      IntegerType *DestTy, BasicBlock *LoopB) {
1020  Type *OrigTy = In->getType();
1021  assert(!OrigTy->isVoidTy() && "Invalid instruction to promote");
1022
1023  // Leave boolean values alone.
1024  if (!In->getType()->isIntegerTy(1))
1025    In->mutateType(DestTy);
1026  unsigned DestBW = DestTy->getBitWidth();
1027
1028  // Handle PHIs.
1029  if (PHINode *P = dyn_cast<PHINode>(In)) {
1030    unsigned N = P->getNumIncomingValues();
1031    for (unsigned i = 0; i != N; ++i) {
1032      BasicBlock *InB = P->getIncomingBlock(i);
1033      if (InB == LoopB)
1034        continue;
1035      Value *InV = P->getIncomingValue(i);
1036      IntegerType *Ty = cast<IntegerType>(InV->getType());
1037      // Do not promote values in PHI nodes of type i1.
1038      if (Ty != P->getType()) {
1039        // If the value type does not match the PHI type, the PHI type
1040        // must have been promoted.
1041        assert(Ty->getBitWidth() < DestBW);
1042        InV = IRBuilder<>(InB->getTerminator()).CreateZExt(InV, DestTy);
1043        P->setIncomingValue(i, InV);
1044      }
1045    }
1046  } else if (ZExtInst *Z = dyn_cast<ZExtInst>(In)) {
1047    Value *Op = Z->getOperand(0);
1048    if (Op->getType() == Z->getType())
1049      Z->replaceAllUsesWith(Op);
1050    Z->eraseFromParent();
1051    return;
1052  }
1053  if (TruncInst *T = dyn_cast<TruncInst>(In)) {
1054    IntegerType *TruncTy = cast<IntegerType>(OrigTy);
1055    Value *Mask = ConstantInt::get(DestTy, (1u << TruncTy->getBitWidth()) - 1);
1056    Value *And = IRBuilder<>(In).CreateAnd(T->getOperand(0), Mask);
1057    T->replaceAllUsesWith(And);
1058    T->eraseFromParent();
1059    return;
1060  }
1061
1062  // Promote immediates.
1063  for (unsigned i = 0, n = In->getNumOperands(); i != n; ++i) {
1064    if (ConstantInt *CI = dyn_cast<ConstantInt>(In->getOperand(i)))
1065      if (CI->getBitWidth() < DestBW)
1066        In->setOperand(i, ConstantInt::get(DestTy, CI->getZExtValue()));
1067  }
1068}
1069
1070bool PolynomialMultiplyRecognize::promoteTypes(BasicBlock *LoopB,
1071      BasicBlock *ExitB) {
1072  assert(LoopB);
1073  // Skip loops where the exit block has more than one predecessor. The values
1074  // coming from the loop block will be promoted to another type, and so the
1075  // values coming into the exit block from other predecessors would also have
1076  // to be promoted.
1077  if (!ExitB || (ExitB->getSinglePredecessor() != LoopB))
1078    return false;
1079  IntegerType *DestTy = getPmpyType();
1080  // Check if the exit values have types that are no wider than the type
1081  // that we want to promote to.
1082  unsigned DestBW = DestTy->getBitWidth();
1083  for (PHINode &P : ExitB->phis()) {
1084    if (P.getNumIncomingValues() != 1)
1085      return false;
1086    assert(P.getIncomingBlock(0) == LoopB);
1087    IntegerType *T = dyn_cast<IntegerType>(P.getType());
1088    if (!T || T->getBitWidth() > DestBW)
1089      return false;
1090  }
1091
1092  // Check all instructions in the loop.
1093  for (Instruction &In : *LoopB)
1094    if (!In.isTerminator() && !isPromotableTo(&In, DestTy))
1095      return false;
1096
1097  // Perform the promotion.
1098  std::vector<Instruction*> LoopIns;
1099  std::transform(LoopB->begin(), LoopB->end(), std::back_inserter(LoopIns),
1100                 [](Instruction &In) { return &In; });
1101  for (Instruction *In : LoopIns)
1102    if (!In->isTerminator())
1103      promoteTo(In, DestTy, LoopB);
1104
1105  // Fix up the PHI nodes in the exit block.
1106  Instruction *EndI = ExitB->getFirstNonPHI();
1107  BasicBlock::iterator End = EndI ? EndI->getIterator() : ExitB->end();
1108  for (auto I = ExitB->begin(); I != End; ++I) {
1109    PHINode *P = dyn_cast<PHINode>(I);
1110    if (!P)
1111      break;
1112    Type *Ty0 = P->getIncomingValue(0)->getType();
1113    Type *PTy = P->getType();
1114    if (PTy != Ty0) {
1115      assert(Ty0 == DestTy);
1116      // In order to create the trunc, P must have the promoted type.
1117      P->mutateType(Ty0);
1118      Value *T = IRBuilder<>(ExitB, End).CreateTrunc(P, PTy);
1119      // In order for the RAUW to work, the types of P and T must match.
1120      P->mutateType(PTy);
1121      P->replaceAllUsesWith(T);
1122      // Final update of the P's type.
1123      P->mutateType(Ty0);
1124      cast<Instruction>(T)->setOperand(0, P);
1125    }
1126  }
1127
1128  return true;
1129}
1130
1131bool PolynomialMultiplyRecognize::findCycle(Value *Out, Value *In,
1132      ValueSeq &Cycle) {
1133  // Out = ..., In, ...
1134  if (Out == In)
1135    return true;
1136
1137  auto *BB = cast<Instruction>(Out)->getParent();
1138  bool HadPhi = false;
1139
1140  for (auto *U : Out->users()) {
1141    auto *I = dyn_cast<Instruction>(&*U);
1142    if (I == nullptr || I->getParent() != BB)
1143      continue;
1144    // Make sure that there are no multi-iteration cycles, e.g.
1145    //   p1 = phi(p2)
1146    //   p2 = phi(p1)
1147    // The cycle p1->p2->p1 would span two loop iterations.
1148    // Check that there is only one phi in the cycle.
1149    bool IsPhi = isa<PHINode>(I);
1150    if (IsPhi && HadPhi)
1151      return false;
1152    HadPhi |= IsPhi;
1153    if (!Cycle.insert(I))
1154      return false;
1155    if (findCycle(I, In, Cycle))
1156      break;
1157    Cycle.remove(I);
1158  }
1159  return !Cycle.empty();
1160}
1161
1162void PolynomialMultiplyRecognize::classifyCycle(Instruction *DivI,
1163      ValueSeq &Cycle, ValueSeq &Early, ValueSeq &Late) {
1164  // All the values in the cycle that are between the phi node and the
1165  // divider instruction will be classified as "early", all other values
1166  // will be "late".
1167
1168  bool IsE = true;
1169  unsigned I, N = Cycle.size();
1170  for (I = 0; I < N; ++I) {
1171    Value *V = Cycle[I];
1172    if (DivI == V)
1173      IsE = false;
1174    else if (!isa<PHINode>(V))
1175      continue;
1176    // Stop if found either.
1177    break;
1178  }
1179  // "I" is the index of either DivI or the phi node, whichever was first.
1180  // "E" is "false" or "true" respectively.
1181  ValueSeq &First = !IsE ? Early : Late;
1182  for (unsigned J = 0; J < I; ++J)
1183    First.insert(Cycle[J]);
1184
1185  ValueSeq &Second = IsE ? Early : Late;
1186  Second.insert(Cycle[I]);
1187  for (++I; I < N; ++I) {
1188    Value *V = Cycle[I];
1189    if (DivI == V || isa<PHINode>(V))
1190      break;
1191    Second.insert(V);
1192  }
1193
1194  for (; I < N; ++I)
1195    First.insert(Cycle[I]);
1196}
1197
1198bool PolynomialMultiplyRecognize::classifyInst(Instruction *UseI,
1199      ValueSeq &Early, ValueSeq &Late) {
1200  // Select is an exception, since the condition value does not have to be
1201  // classified in the same way as the true/false values. The true/false
1202  // values do have to be both early or both late.
1203  if (UseI->getOpcode() == Instruction::Select) {
1204    Value *TV = UseI->getOperand(1), *FV = UseI->getOperand(2);
1205    if (Early.count(TV) || Early.count(FV)) {
1206      if (Late.count(TV) || Late.count(FV))
1207        return false;
1208      Early.insert(UseI);
1209    } else if (Late.count(TV) || Late.count(FV)) {
1210      if (Early.count(TV) || Early.count(FV))
1211        return false;
1212      Late.insert(UseI);
1213    }
1214    return true;
1215  }
1216
1217  // Not sure what would be the example of this, but the code below relies
1218  // on having at least one operand.
1219  if (UseI->getNumOperands() == 0)
1220    return true;
1221
1222  bool AE = true, AL = true;
1223  for (auto &I : UseI->operands()) {
1224    if (Early.count(&*I))
1225      AL = false;
1226    else if (Late.count(&*I))
1227      AE = false;
1228  }
1229  // If the operands appear "all early" and "all late" at the same time,
1230  // then it means that none of them are actually classified as either.
1231  // This is harmless.
1232  if (AE && AL)
1233    return true;
1234  // Conversely, if they are neither "all early" nor "all late", then
1235  // we have a mixture of early and late operands that is not a known
1236  // exception.
1237  if (!AE && !AL)
1238    return false;
1239
1240  // Check that we have covered the two special cases.
1241  assert(AE != AL);
1242
1243  if (AE)
1244    Early.insert(UseI);
1245  else
1246    Late.insert(UseI);
1247  return true;
1248}
1249
1250bool PolynomialMultiplyRecognize::commutesWithShift(Instruction *I) {
1251  switch (I->getOpcode()) {
1252    case Instruction::And:
1253    case Instruction::Or:
1254    case Instruction::Xor:
1255    case Instruction::LShr:
1256    case Instruction::Shl:
1257    case Instruction::Select:
1258    case Instruction::ICmp:
1259    case Instruction::PHI:
1260      break;
1261    default:
1262      return false;
1263  }
1264  return true;
1265}
1266
1267bool PolynomialMultiplyRecognize::highBitsAreZero(Value *V,
1268      unsigned IterCount) {
1269  auto *T = dyn_cast<IntegerType>(V->getType());
1270  if (!T)
1271    return false;
1272
1273  KnownBits Known(T->getBitWidth());
1274  computeKnownBits(V, Known, DL);
1275  return Known.countMinLeadingZeros() >= IterCount;
1276}
1277
1278bool PolynomialMultiplyRecognize::keepsHighBitsZero(Value *V,
1279      unsigned IterCount) {
1280  // Assume that all inputs to the value have the high bits zero.
1281  // Check if the value itself preserves the zeros in the high bits.
1282  if (auto *C = dyn_cast<ConstantInt>(V))
1283    return C->getValue().countl_zero() >= IterCount;
1284
1285  if (auto *I = dyn_cast<Instruction>(V)) {
1286    switch (I->getOpcode()) {
1287      case Instruction::And:
1288      case Instruction::Or:
1289      case Instruction::Xor:
1290      case Instruction::LShr:
1291      case Instruction::Select:
1292      case Instruction::ICmp:
1293      case Instruction::PHI:
1294      case Instruction::ZExt:
1295        return true;
1296    }
1297  }
1298
1299  return false;
1300}
1301
1302bool PolynomialMultiplyRecognize::isOperandShifted(Instruction *I, Value *Op) {
1303  unsigned Opc = I->getOpcode();
1304  if (Opc == Instruction::Shl || Opc == Instruction::LShr)
1305    return Op != I->getOperand(1);
1306  return true;
1307}
1308
// Rewrite the body of LoopB so that value cycles going through a logical
// right shift by 1 no longer shift right; the values feeding those cycles
// are shifted left by the count IV instead, and the results leaving the
// loop through ExitB's PHI nodes are shifted right by IterCount.
//
// The function works in four phases:
//   1. Collect the "lshr x, 1" instructions that are part of a cycle (see
//      findCycle) and classify the cycle members as early or late.
//   2. Collect, transitively, all in-loop users of the cycled values and
//      classify them; give up if any of them is not an integer, does not
//      commute with shifts, or mixes early and late operands.
//   3. Check that the high IterCount bits stay zero throughout.
//   4. Perform the rewrite.
//
// Returns true if the loop was rewritten. Every "return false" below is
// located before phase 4, i.e. before any IR is modified.
bool PolynomialMultiplyRecognize::convertShiftsToLeft(BasicBlock *LoopB,
      BasicBlock *ExitB, unsigned IterCount) {
  // The count IV supplies the left-shift amount; it must exist and be an
  // integer.
  Value *CIV = getCountIV(LoopB);
  if (CIV == nullptr)
    return false;
  auto *CIVTy = dyn_cast<IntegerType>(CIV->getType());
  if (CIVTy == nullptr)
    return false;

  ValueSeq RShifts;
  ValueSeq Early, Late, Cycled;

  // Find all value cycles that contain logical right shifts by 1.
  for (Instruction &I : *LoopB) {
    using namespace PatternMatch;

    Value *V = nullptr;
    if (!match(&I, m_LShr(m_Value(V), m_One())))
      continue;
    ValueSeq C;
    if (!findCycle(&I, V, C))
      continue;

    // Found a cycle.
    C.insert(&I);
    classifyCycle(&I, C, Early, Late);
    Cycled.insert(C.begin(), C.end());
    RShifts.insert(&I);
  }

  // Find the set of all values affected by the shift cycles, i.e. all
  // cycled values, and (recursively) all their users.
  // Users doubles as the worklist: it grows while it is being indexed.
  ValueSeq Users(Cycled.begin(), Cycled.end());
  for (unsigned i = 0; i < Users.size(); ++i) {
    Value *V = Users[i];
    if (!isa<IntegerType>(V->getType()))
      return false;
    auto *R = cast<Instruction>(V);
    // If the instruction does not commute with shifts, the loop cannot
    // be unshifted.
    if (!commutesWithShift(R))
      return false;
    for (User *U : R->users()) {
      auto *T = cast<Instruction>(U);
      // Skip users from outside of the loop. They will be handled later.
      // Also, skip the right-shifts and phi nodes, since they mix early
      // and late values.
      if (T->getParent() != LoopB || RShifts.count(T) || isa<PHINode>(T))
        continue;

      Users.insert(T);
      if (!classifyInst(T, Early, Late))
        return false;
    }
  }

  if (Users.empty())
    return false;

  // Verify that high bits remain zero.
  // Operands that are instructions outside of LoopB become Inputs; all
  // other operands (in-loop instructions as well as non-instructions) are
  // added to Internal and are themselves expanded by this loop.
  ValueSeq Internal(Users.begin(), Users.end());
  ValueSeq Inputs;
  for (unsigned i = 0; i < Internal.size(); ++i) {
    auto *R = dyn_cast<Instruction>(Internal[i]);
    if (!R)
      continue;
    for (Value *Op : R->operands()) {
      auto *T = dyn_cast<Instruction>(Op);
      if (T && T->getParent() != LoopB)
        Inputs.insert(Op);
      else
        Internal.insert(Op);
    }
  }
  // Inputs must be known to have the high bits clear; internal values must
  // be of a kind that preserves them.
  for (Value *V : Inputs)
    if (!highBitsAreZero(V, IterCount))
      return false;
  for (Value *V : Internal)
    if (!keepsHighBitsZero(V, IterCount))
      return false;

  // Finally, the work can be done. Unshift each user.
  IRBuilder<> IRB(LoopB);
  // Memoizes the left-shifted replacement created for an operand.
  std::map<Value*,Value*> ShiftMap;

  using CastMapType = std::map<std::pair<Value *, Type *>, Value *>;

  // Memoizes the integer casts created by 'upcast', keyed by (value, type).
  CastMapType CastMap;

  // Return V cast to Ty as an unsigned value, reusing an earlier cast of
  // the same value to the same type if there is one.
  auto upcast = [] (CastMapType &CM, IRBuilder<> &IRB, Value *V,
        IntegerType *Ty) -> Value* {
    auto H = CM.find(std::make_pair(V, Ty));
    if (H != CM.end())
      return H->second;
    Value *CV = IRB.CreateIntCast(V, Ty, false);
    CM.insert(std::make_pair(std::make_pair(V, Ty), CV));
    return CV;
  };

  for (auto I = LoopB->begin(), E = LoopB->end(); I != E; ++I) {
    using namespace PatternMatch;

    // Only non-PHI instructions collected in Users are rewritten.
    if (isa<PHINode>(I) || !Users.count(&*I))
      continue;

    // Match lshr x, 1.
    Value *V = nullptr;
    if (match(&*I, m_LShr(m_Value(V), m_One()))) {
      // NOTE(review): replaceAllUsesOfWithIn is defined outside this chunk;
      // presumably it redirects the uses of the shift located in LoopB to
      // the unshifted value V - confirm.
      replaceAllUsesOfWithIn(&*I, V, LoopB);
      continue;
    }
    // For each non-cycled operand, replace it with the corresponding
    // value shifted left.
    for (auto &J : I->operands()) {
      Value *Op = J.get();
      // The shift amount of a shl/lshr is not a shifted operand.
      if (!isOperandShifted(&*I, Op))
        continue;
      if (Users.count(Op))
        continue;
      // Skip shifting zeros.
      if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero())
        continue;
      // Check if we have already generated a shift for this value.
      auto F = ShiftMap.find(Op);
      Value *W = (F != ShiftMap.end()) ? F->second : nullptr;
      if (W == nullptr) {
        IRB.SetInsertPoint(&*I);
        // First, the shift amount will be CIV or CIV+1, depending on
        // whether the value is early or late. Instead of creating CIV+1,
        // do a single shift of the value.
        Value *ShAmt = CIV, *ShVal = Op;
        auto *VTy = cast<IntegerType>(ShVal->getType());
        auto *ATy = cast<IntegerType>(ShAmt->getType());
        if (Late.count(&*I))
          ShVal = IRB.CreateShl(Op, ConstantInt::get(VTy, 1));
        // Second, the types of the shifted value and the shift amount
        // must match.
        if (VTy != ATy) {
          // Widen whichever of the two is narrower.
          if (VTy->getBitWidth() < ATy->getBitWidth())
            ShVal = upcast(CastMap, IRB, ShVal, ATy);
          else
            ShAmt = upcast(CastMap, IRB, ShAmt, VTy);
        }
        // Ready to generate the shift and memoize it.
        W = IRB.CreateShl(ShVal, ShAmt);
        ShiftMap.insert(std::make_pair(Op, W));
      }
      I->replaceUsesOfWith(Op, W);
    }
  }

  // Update the users outside of the loop to account for having left
  // shifts. They would normally be shifted right in the loop, so shift
  // them right after the loop exit.
  // Take advantage of the loop-closed SSA form, which has all the post-
  // loop values in phi nodes.
  IRB.SetInsertPoint(ExitB, ExitB->getFirstInsertionPt());
  for (auto P = ExitB->begin(), Q = ExitB->end(); P != Q; ++P) {
    if (!isa<PHINode>(P))
      break;
    auto *PN = cast<PHINode>(P);
    Value *U = PN->getIncomingValueForBlock(LoopB);
    // Only exit values produced by rewritten instructions need the
    // compensating right shift.
    if (!Users.count(U))
      continue;
    Value *S = IRB.CreateLShr(PN, ConstantInt::get(PN->getType(), IterCount));
    PN->replaceAllUsesWith(S);
    // The above RAUW will create
    //   S = lshr S, IterCount
    // so we need to fix it back into
    //   S = lshr PN, IterCount
    cast<User>(S)->replaceUsesOfWith(S, PN);
  }

  return true;
}
1484
1485void PolynomialMultiplyRecognize::cleanupLoopBody(BasicBlock *LoopB) {
1486  for (auto &I : *LoopB)
1487    if (Value *SV = simplifyInstruction(&I, {DL, &TLI, &DT}))
1488      I.replaceAllUsesWith(SV);
1489
1490  for (Instruction &I : llvm::make_early_inc_range(*LoopB))
1491    RecursivelyDeleteTriviallyDeadInstructions(&I, &TLI);
1492}
1493
1494unsigned PolynomialMultiplyRecognize::getInverseMxN(unsigned QP) {
1495  // Arrays of coefficients of Q and the inverse, C.
1496  // Q[i] = coefficient at x^i.
1497  std::array<char,32> Q, C;
1498
1499  for (unsigned i = 0; i < 32; ++i) {
1500    Q[i] = QP & 1;
1501    QP >>= 1;
1502  }
1503  assert(Q[0] == 1);
1504
1505  // Find C, such that
1506  // (Q[n]*x^n + ... + Q[1]*x + Q[0]) * (C[n]*x^n + ... + C[1]*x + C[0]) = 1
1507  //
1508  // For it to have a solution, Q[0] must be 1. Since this is Z2[x], the
1509  // operations * and + are & and ^ respectively.
1510  //
1511  // Find C[i] recursively, by comparing i-th coefficient in the product
1512  // with 0 (or 1 for i=0).
1513  //
1514  // C[0] = 1, since C[0] = Q[0], and Q[0] = 1.
1515  C[0] = 1;
1516  for (unsigned i = 1; i < 32; ++i) {
1517    // Solve for C[i] in:
1518    //   C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i]Q[0] = 0
1519    // This is equivalent to
1520    //   C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i] = 0
1521    // which is
1522    //   C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] = C[i]
1523    unsigned T = 0;
1524    for (unsigned j = 0; j < i; ++j)
1525      T = T ^ (C[j] & Q[i-j]);
1526    C[i] = T;
1527  }
1528
1529  unsigned QV = 0;
1530  for (unsigned i = 0; i < 32; ++i)
1531    if (C[i])
1532      QV |= (1 << i);
1533
1534  return QV;
1535}
1536
1537Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At,
1538      ParsedValues &PV) {
1539  IRBuilder<> B(&*At);
1540  Module *M = At->getParent()->getParent()->getParent();
1541  Function *PMF = Intrinsic::getDeclaration(M, Intrinsic::hexagon_M4_pmpyw);
1542
1543  Value *P = PV.P, *Q = PV.Q, *P0 = P;
1544  unsigned IC = PV.IterCount;
1545
1546  if (PV.M != nullptr)
1547    P0 = P = B.CreateXor(P, PV.M);
1548
1549  // Create a bit mask to clear the high bits beyond IterCount.
1550  auto *BMI = ConstantInt::get(P->getType(), APInt::getLowBitsSet(32, IC));
1551
1552  if (PV.IterCount != 32)
1553    P = B.CreateAnd(P, BMI);
1554
1555  if (PV.Inv) {
1556    auto *QI = dyn_cast<ConstantInt>(PV.Q);
1557    assert(QI && QI->getBitWidth() <= 32);
1558
1559    // Again, clearing bits beyond IterCount.
1560    unsigned M = (1 << PV.IterCount) - 1;
1561    unsigned Tmp = (QI->getZExtValue() | 1) & M;
1562    unsigned QV = getInverseMxN(Tmp) & M;
1563    auto *QVI = ConstantInt::get(QI->getType(), QV);
1564    P = B.CreateCall(PMF, {P, QVI});
1565    P = B.CreateTrunc(P, QI->getType());
1566    if (IC != 32)
1567      P = B.CreateAnd(P, BMI);
1568  }
1569
1570  Value *R = B.CreateCall(PMF, {P, Q});
1571
1572  if (PV.M != nullptr)
1573    R = B.CreateXor(R, B.CreateIntCast(P0, R->getType(), false));
1574
1575  return R;
1576}
1577
1578static bool hasZeroSignBit(const Value *V) {
1579  if (const auto *CI = dyn_cast<const ConstantInt>(V))
1580    return CI->getValue().isNonNegative();
1581  const Instruction *I = dyn_cast<const Instruction>(V);
1582  if (!I)
1583    return false;
1584  switch (I->getOpcode()) {
1585    case Instruction::LShr:
1586      if (const auto SI = dyn_cast<const ConstantInt>(I->getOperand(1)))
1587        return SI->getZExtValue() > 0;
1588      return false;
1589    case Instruction::Or:
1590    case Instruction::Xor:
1591      return hasZeroSignBit(I->getOperand(0)) &&
1592             hasZeroSignBit(I->getOperand(1));
1593    case Instruction::And:
1594      return hasZeroSignBit(I->getOperand(0)) ||
1595             hasZeroSignBit(I->getOperand(1));
1596  }
1597  return false;
1598}
1599
1600void PolynomialMultiplyRecognize::setupPreSimplifier(Simplifier &S) {
1601  S.addRule("sink-zext",
1602    // Sink zext past bitwise operations.
1603    [](Instruction *I, LLVMContext &Ctx) -> Value* {
1604      if (I->getOpcode() != Instruction::ZExt)
1605        return nullptr;
1606      Instruction *T = dyn_cast<Instruction>(I->getOperand(0));
1607      if (!T)
1608        return nullptr;
1609      switch (T->getOpcode()) {
1610        case Instruction::And:
1611        case Instruction::Or:
1612        case Instruction::Xor:
1613          break;
1614        default:
1615          return nullptr;
1616      }
1617      IRBuilder<> B(Ctx);
1618      return B.CreateBinOp(cast<BinaryOperator>(T)->getOpcode(),
1619                           B.CreateZExt(T->getOperand(0), I->getType()),
1620                           B.CreateZExt(T->getOperand(1), I->getType()));
1621    });
1622  S.addRule("xor/and -> and/xor",
1623    // (xor (and x a) (and y a)) -> (and (xor x y) a)
1624    [](Instruction *I, LLVMContext &Ctx) -> Value* {
1625      if (I->getOpcode() != Instruction::Xor)
1626        return nullptr;
1627      Instruction *And0 = dyn_cast<Instruction>(I->getOperand(0));
1628      Instruction *And1 = dyn_cast<Instruction>(I->getOperand(1));
1629      if (!And0 || !And1)
1630        return nullptr;
1631      if (And0->getOpcode() != Instruction::And ||
1632          And1->getOpcode() != Instruction::And)
1633        return nullptr;
1634      if (And0->getOperand(1) != And1->getOperand(1))
1635        return nullptr;
1636      IRBuilder<> B(Ctx);
1637      return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1->getOperand(0)),
1638                         And0->getOperand(1));
1639    });
1640  S.addRule("sink binop into select",
1641    // (Op (select c x y) z) -> (select c (Op x z) (Op y z))
1642    // (Op x (select c y z)) -> (select c (Op x y) (Op x z))
1643    [](Instruction *I, LLVMContext &Ctx) -> Value* {
1644      BinaryOperator *BO = dyn_cast<BinaryOperator>(I);
1645      if (!BO)
1646        return nullptr;
1647      Instruction::BinaryOps Op = BO->getOpcode();
1648      if (SelectInst *Sel = dyn_cast<SelectInst>(BO->getOperand(0))) {
1649        IRBuilder<> B(Ctx);
1650        Value *X = Sel->getTrueValue(), *Y = Sel->getFalseValue();
1651        Value *Z = BO->getOperand(1);
1652        return B.CreateSelect(Sel->getCondition(),
1653                              B.CreateBinOp(Op, X, Z),
1654                              B.CreateBinOp(Op, Y, Z));
1655      }
1656      if (SelectInst *Sel = dyn_cast<SelectInst>(BO->getOperand(1))) {
1657        IRBuilder<> B(Ctx);
1658        Value *X = BO->getOperand(0);
1659        Value *Y = Sel->getTrueValue(), *Z = Sel->getFalseValue();
1660        return B.CreateSelect(Sel->getCondition(),
1661                              B.CreateBinOp(Op, X, Y),
1662                              B.CreateBinOp(Op, X, Z));
1663      }
1664      return nullptr;
1665    });
1666  S.addRule("fold select-select",
1667    // (select c (select c x y) z) -> (select c x z)
1668    // (select c x (select c y z)) -> (select c x z)
1669    [](Instruction *I, LLVMContext &Ctx) -> Value* {
1670      SelectInst *Sel = dyn_cast<SelectInst>(I);
1671      if (!Sel)
1672        return nullptr;
1673      IRBuilder<> B(Ctx);
1674      Value *C = Sel->getCondition();
1675      if (SelectInst *Sel0 = dyn_cast<SelectInst>(Sel->getTrueValue())) {
1676        if (Sel0->getCondition() == C)
1677          return B.CreateSelect(C, Sel0->getTrueValue(), Sel->getFalseValue());
1678      }
1679      if (SelectInst *Sel1 = dyn_cast<SelectInst>(Sel->getFalseValue())) {
1680        if (Sel1->getCondition() == C)
1681          return B.CreateSelect(C, Sel->getTrueValue(), Sel1->getFalseValue());
1682      }
1683      return nullptr;
1684    });
1685  S.addRule("or-signbit -> xor-signbit",
1686    // (or (lshr x 1) 0x800.0) -> (xor (lshr x 1) 0x800.0)
1687    [](Instruction *I, LLVMContext &Ctx) -> Value* {
1688      if (I->getOpcode() != Instruction::Or)
1689        return nullptr;
1690      ConstantInt *Msb = dyn_cast<ConstantInt>(I->getOperand(1));
1691      if (!Msb || !Msb->getValue().isSignMask())
1692        return nullptr;
1693      if (!hasZeroSignBit(I->getOperand(0)))
1694        return nullptr;
1695      return IRBuilder<>(Ctx).CreateXor(I->getOperand(0), Msb);
1696    });
1697  S.addRule("sink lshr into binop",
1698    // (lshr (BitOp x y) c) -> (BitOp (lshr x c) (lshr y c))
1699    [](Instruction *I, LLVMContext &Ctx) -> Value* {
1700      if (I->getOpcode() != Instruction::LShr)
1701        return nullptr;
1702      BinaryOperator *BitOp = dyn_cast<BinaryOperator>(I->getOperand(0));
1703      if (!BitOp)
1704        return nullptr;
1705      switch (BitOp->getOpcode()) {
1706        case Instruction::And:
1707        case Instruction::Or:
1708        case Instruction::Xor:
1709          break;
1710        default:
1711          return nullptr;
1712      }
1713      IRBuilder<> B(Ctx);
1714      Value *S = I->getOperand(1);
1715      return B.CreateBinOp(BitOp->getOpcode(),
1716                B.CreateLShr(BitOp->getOperand(0), S),
1717                B.CreateLShr(BitOp->getOperand(1), S));
1718    });
1719  S.addRule("expose bitop-const",
1720    // (BitOp1 (BitOp2 x a) b) -> (BitOp2 x (BitOp1 a b))
1721    [](Instruction *I, LLVMContext &Ctx) -> Value* {
1722      auto IsBitOp = [](unsigned Op) -> bool {
1723        switch (Op) {
1724          case Instruction::And:
1725          case Instruction::Or:
1726          case Instruction::Xor:
1727            return true;
1728        }
1729        return false;
1730      };
1731      BinaryOperator *BitOp1 = dyn_cast<BinaryOperator>(I);
1732      if (!BitOp1 || !IsBitOp(BitOp1->getOpcode()))
1733        return nullptr;
1734      BinaryOperator *BitOp2 = dyn_cast<BinaryOperator>(BitOp1->getOperand(0));
1735      if (!BitOp2 || !IsBitOp(BitOp2->getOpcode()))
1736        return nullptr;
1737      ConstantInt *CA = dyn_cast<ConstantInt>(BitOp2->getOperand(1));
1738      ConstantInt *CB = dyn_cast<ConstantInt>(BitOp1->getOperand(1));
1739      if (!CA || !CB)
1740        return nullptr;
1741      IRBuilder<> B(Ctx);
1742      Value *X = BitOp2->getOperand(0);
1743      return B.CreateBinOp(BitOp2->getOpcode(), X,
1744                B.CreateBinOp(BitOp1->getOpcode(), CA, CB));
1745    });
1746}
1747
1748void PolynomialMultiplyRecognize::setupPostSimplifier(Simplifier &S) {
1749  S.addRule("(and (xor (and x a) y) b) -> (and (xor x y) b), if b == b&a",
1750    [](Instruction *I, LLVMContext &Ctx) -> Value* {
1751      if (I->getOpcode() != Instruction::And)
1752        return nullptr;
1753      Instruction *Xor = dyn_cast<Instruction>(I->getOperand(0));
1754      ConstantInt *C0 = dyn_cast<ConstantInt>(I->getOperand(1));
1755      if (!Xor || !C0)
1756        return nullptr;
1757      if (Xor->getOpcode() != Instruction::Xor)
1758        return nullptr;
1759      Instruction *And0 = dyn_cast<Instruction>(Xor->getOperand(0));
1760      Instruction *And1 = dyn_cast<Instruction>(Xor->getOperand(1));
1761      // Pick the first non-null and.
1762      if (!And0 || And0->getOpcode() != Instruction::And)
1763        std::swap(And0, And1);
1764      ConstantInt *C1 = dyn_cast<ConstantInt>(And0->getOperand(1));
1765      if (!C1)
1766        return nullptr;
1767      uint32_t V0 = C0->getZExtValue();
1768      uint32_t V1 = C1->getZExtValue();
1769      if (V0 != (V0 & V1))
1770        return nullptr;
1771      IRBuilder<> B(Ctx);
1772      return B.CreateAnd(B.CreateXor(And0->getOperand(0), And1), C0);
1773    });
1774}
1775
1776bool PolynomialMultiplyRecognize::recognize() {
1777  LLVM_DEBUG(dbgs() << "Starting PolynomialMultiplyRecognize on loop\n"
1778                    << *CurLoop << '\n');
1779  // Restrictions:
1780  // - The loop must consist of a single block.
1781  // - The iteration count must be known at compile-time.
1782  // - The loop must have an induction variable starting from 0, and
1783  //   incremented in each iteration of the loop.
1784  BasicBlock *LoopB = CurLoop->getHeader();
1785  LLVM_DEBUG(dbgs() << "Loop header:\n" << *LoopB);
1786
1787  if (LoopB != CurLoop->getLoopLatch())
1788    return false;
1789  BasicBlock *ExitB = CurLoop->getExitBlock();
1790  if (ExitB == nullptr)
1791    return false;
1792  BasicBlock *EntryB = CurLoop->getLoopPreheader();
1793  if (EntryB == nullptr)
1794    return false;
1795
1796  unsigned IterCount = 0;
1797  const SCEV *CT = SE.getBackedgeTakenCount(CurLoop);
1798  if (isa<SCEVCouldNotCompute>(CT))
1799    return false;
1800  if (auto *CV = dyn_cast<SCEVConstant>(CT))
1801    IterCount = CV->getValue()->getZExtValue() + 1;
1802
1803  Value *CIV = getCountIV(LoopB);
1804  ParsedValues PV;
1805  Simplifier PreSimp;
1806  PV.IterCount = IterCount;
1807  LLVM_DEBUG(dbgs() << "Loop IV: " << *CIV << "\nIterCount: " << IterCount
1808                    << '\n');
1809
1810  setupPreSimplifier(PreSimp);
1811
1812  // Perform a preliminary scan of select instructions to see if any of them
1813  // looks like a generator of the polynomial multiply steps. Assume that a
1814  // loop can only contain a single transformable operation, so stop the
1815  // traversal after the first reasonable candidate was found.
1816  // XXX: Currently this approach can modify the loop before being 100% sure
1817  // that the transformation can be carried out.
1818  bool FoundPreScan = false;
1819  auto FeedsPHI = [LoopB](const Value *V) -> bool {
1820    for (const Value *U : V->users()) {
1821      if (const auto *P = dyn_cast<const PHINode>(U))
1822        if (P->getParent() == LoopB)
1823          return true;
1824    }
1825    return false;
1826  };
1827  for (Instruction &In : *LoopB) {
1828    SelectInst *SI = dyn_cast<SelectInst>(&In);
1829    if (!SI || !FeedsPHI(SI))
1830      continue;
1831
1832    Simplifier::Context C(SI);
1833    Value *T = PreSimp.simplify(C);
1834    SelectInst *SelI = (T && isa<SelectInst>(T)) ? cast<SelectInst>(T) : SI;
1835    LLVM_DEBUG(dbgs() << "scanSelect(pre-scan): " << PE(C, SelI) << '\n');
1836    if (scanSelect(SelI, LoopB, EntryB, CIV, PV, true)) {
1837      FoundPreScan = true;
1838      if (SelI != SI) {
1839        Value *NewSel = C.materialize(LoopB, SI->getIterator());
1840        SI->replaceAllUsesWith(NewSel);
1841        RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI);
1842      }
1843      break;
1844    }
1845  }
1846
1847  if (!FoundPreScan) {
1848    LLVM_DEBUG(dbgs() << "Have not found candidates for pmpy\n");
1849    return false;
1850  }
1851
1852  if (!PV.Left) {
1853    // The right shift version actually only returns the higher bits of
1854    // the result (each iteration discards the LSB). If we want to convert it
1855    // to a left-shifting loop, the working data type must be at least as
1856    // wide as the target's pmpy instruction.
1857    if (!promoteTypes(LoopB, ExitB))
1858      return false;
1859    // Run post-promotion simplifications.
1860    Simplifier PostSimp;
1861    setupPostSimplifier(PostSimp);
1862    for (Instruction &In : *LoopB) {
1863      SelectInst *SI = dyn_cast<SelectInst>(&In);
1864      if (!SI || !FeedsPHI(SI))
1865        continue;
1866      Simplifier::Context C(SI);
1867      Value *T = PostSimp.simplify(C);
1868      SelectInst *SelI = dyn_cast_or_null<SelectInst>(T);
1869      if (SelI != SI) {
1870        Value *NewSel = C.materialize(LoopB, SI->getIterator());
1871        SI->replaceAllUsesWith(NewSel);
1872        RecursivelyDeleteTriviallyDeadInstructions(SI, &TLI);
1873      }
1874      break;
1875    }
1876
1877    if (!convertShiftsToLeft(LoopB, ExitB, IterCount))
1878      return false;
1879    cleanupLoopBody(LoopB);
1880  }
1881
1882  // Scan the loop again, find the generating select instruction.
1883  bool FoundScan = false;
1884  for (Instruction &In : *LoopB) {
1885    SelectInst *SelI = dyn_cast<SelectInst>(&In);
1886    if (!SelI)
1887      continue;
1888    LLVM_DEBUG(dbgs() << "scanSelect: " << *SelI << '\n');
1889    FoundScan = scanSelect(SelI, LoopB, EntryB, CIV, PV, false);
1890    if (FoundScan)
1891      break;
1892  }
1893  assert(FoundScan);
1894
1895  LLVM_DEBUG({
1896    StringRef PP = (PV.M ? "(P+M)" : "P");
1897    if (!PV.Inv)
1898      dbgs() << "Found pmpy idiom: R = " << PP << ".Q\n";
1899    else
1900      dbgs() << "Found inverse pmpy idiom: R = (" << PP << "/Q).Q) + "
1901             << PP << "\n";
1902    dbgs() << "  Res:" << *PV.Res << "\n  P:" << *PV.P << "\n";
1903    if (PV.M)
1904      dbgs() << "  M:" << *PV.M << "\n";
1905    dbgs() << "  Q:" << *PV.Q << "\n";
1906    dbgs() << "  Iteration count:" << PV.IterCount << "\n";
1907  });
1908
1909  BasicBlock::iterator At(EntryB->getTerminator());
1910  Value *PM = generate(At, PV);
1911  if (PM == nullptr)
1912    return false;
1913
1914  if (PM->getType() != PV.Res->getType())
1915    PM = IRBuilder<>(&*At).CreateIntCast(PM, PV.Res->getType(), false);
1916
1917  PV.Res->replaceAllUsesWith(PM);
1918  PV.Res->eraseFromParent();
1919  return true;
1920}
1921
1922int HexagonLoopIdiomRecognize::getSCEVStride(const SCEVAddRecExpr *S) {
1923  if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getOperand(1)))
1924    return SC->getAPInt().getSExtValue();
1925  return 0;
1926}
1927
1928bool HexagonLoopIdiomRecognize::isLegalStore(Loop *CurLoop, StoreInst *SI) {
1929  // Allow volatile stores if HexagonVolatileMemcpy is enabled.
1930  if (!(SI->isVolatile() && HexagonVolatileMemcpy) && !SI->isSimple())
1931    return false;
1932
1933  Value *StoredVal = SI->getValueOperand();
1934  Value *StorePtr = SI->getPointerOperand();
1935
1936  // Reject stores that are so large that they overflow an unsigned.
1937  uint64_t SizeInBits = DL->getTypeSizeInBits(StoredVal->getType());
1938  if ((SizeInBits & 7) || (SizeInBits >> 32) != 0)
1939    return false;
1940
1941  // See if the pointer expression is an AddRec like {base,+,1} on the current
1942  // loop, which indicates a strided store.  If we have something else, it's a
1943  // random store we can't handle.
1944  auto *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
1945  if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
1946    return false;
1947
1948  // Check to see if the stride matches the size of the store.  If so, then we
1949  // know that every byte is touched in the loop.
1950  int Stride = getSCEVStride(StoreEv);
1951  if (Stride == 0)
1952    return false;
1953  unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType());
1954  if (StoreSize != unsigned(std::abs(Stride)))
1955    return false;
1956
1957  // The store must be feeding a non-volatile load.
1958  LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand());
1959  if (!LI || !LI->isSimple())
1960    return false;
1961
1962  // See if the pointer expression is an AddRec like {base,+,1} on the current
1963  // loop, which indicates a strided load.  If we have something else, it's a
1964  // random load we can't handle.
1965  Value *LoadPtr = LI->getPointerOperand();
1966  auto *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LoadPtr));
1967  if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
1968    return false;
1969
1970  // The store and load must share the same stride.
1971  if (StoreEv->getOperand(1) != LoadEv->getOperand(1))
1972    return false;
1973
1974  // Success.  This store can be converted into a memcpy.
1975  return true;
1976}
1977
1978/// mayLoopAccessLocation - Return true if the specified loop might access the
1979/// specified pointer location, which is a loop-strided access.  The 'Access'
1980/// argument specifies what the verboten forms of access are (read or write).
1981static bool
1982mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
1983                      const SCEV *BECount, unsigned StoreSize,
1984                      AliasAnalysis &AA,
1985                      SmallPtrSetImpl<Instruction *> &Ignored) {
1986  // Get the location that may be stored across the loop.  Since the access
1987  // is strided positively through memory, we say that the modified location
1988  // starts at the pointer and has infinite size.
1989  LocationSize AccessSize = LocationSize::afterPointer();
1990
1991  // If the loop iterates a fixed number of times, we can refine the access
1992  // size to be exactly the size of the memset, which is (BECount+1)*StoreSize
1993  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
1994    AccessSize = LocationSize::precise((BECst->getValue()->getZExtValue() + 1) *
1995                                       StoreSize);
1996
1997  // TODO: For this to be really effective, we have to dive into the pointer
1998  // operand in the store.  Store to &A[i] of 100 will always return may alias
1999  // with store of &A[100], we need to StoreLoc to be "A" with size of 100,
2000  // which will then no-alias a store to &A[100].
2001  MemoryLocation StoreLoc(Ptr, AccessSize);
2002
2003  for (auto *B : L->blocks())
2004    for (auto &I : *B)
2005      if (Ignored.count(&I) == 0 &&
2006          isModOrRefSet(AA.getModRefInfo(&I, StoreLoc) & Access))
2007        return true;
2008
2009  return false;
2010}
2011
2012void HexagonLoopIdiomRecognize::collectStores(Loop *CurLoop, BasicBlock *BB,
2013      SmallVectorImpl<StoreInst*> &Stores) {
2014  Stores.clear();
2015  for (Instruction &I : *BB)
2016    if (StoreInst *SI = dyn_cast<StoreInst>(&I))
2017      if (isLegalStore(CurLoop, SI))
2018        Stores.push_back(SI);
2019}
2020
2021bool HexagonLoopIdiomRecognize::processCopyingStore(Loop *CurLoop,
2022      StoreInst *SI, const SCEV *BECount) {
2023  assert((SI->isSimple() || (SI->isVolatile() && HexagonVolatileMemcpy)) &&
2024         "Expected only non-volatile stores, or Hexagon-specific memcpy"
2025         "to volatile destination.");
2026
2027  Value *StorePtr = SI->getPointerOperand();
2028  auto *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
2029  unsigned Stride = getSCEVStride(StoreEv);
2030  unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType());
2031  if (Stride != StoreSize)
2032    return false;
2033
2034  // See if the pointer expression is an AddRec like {base,+,1} on the current
2035  // loop, which indicates a strided load.  If we have something else, it's a
2036  // random load we can't handle.
2037  auto *LI = cast<LoadInst>(SI->getValueOperand());
2038  auto *LoadEv = cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand()));
2039
2040  // The trip count of the loop and the base pointer of the addrec SCEV is
2041  // guaranteed to be loop invariant, which means that it should dominate the
2042  // header.  This allows us to insert code for it in the preheader.
2043  BasicBlock *Preheader = CurLoop->getLoopPreheader();
2044  Instruction *ExpPt = Preheader->getTerminator();
2045  IRBuilder<> Builder(ExpPt);
2046  SCEVExpander Expander(*SE, *DL, "hexagon-loop-idiom");
2047
2048  Type *IntPtrTy = Builder.getIntPtrTy(*DL, SI->getPointerAddressSpace());
2049
2050  // Okay, we have a strided store "p[i]" of a loaded value.  We can turn
2051  // this into a memcpy/memmove in the loop preheader now if we want.  However,
2052  // this would be unsafe to do if there is anything else in the loop that may
2053  // read or write the memory region we're storing to.  For memcpy, this
2054  // includes the load that feeds the stores.  Check for an alias by generating
2055  // the base address and checking everything.
2056  Value *StoreBasePtr = Expander.expandCodeFor(StoreEv->getStart(),
2057      Builder.getPtrTy(SI->getPointerAddressSpace()), ExpPt);
2058  Value *LoadBasePtr = nullptr;
2059
2060  bool Overlap = false;
2061  bool DestVolatile = SI->isVolatile();
2062  Type *BECountTy = BECount->getType();
2063
2064  if (DestVolatile) {
2065    // The trip count must fit in i32, since it is the type of the "num_words"
2066    // argument to hexagon_memcpy_forward_vp4cp4n2.
2067    if (StoreSize != 4 || DL->getTypeSizeInBits(BECountTy) > 32) {
2068CleanupAndExit:
2069      // If we generated new code for the base pointer, clean up.
2070      Expander.clear();
2071      if (StoreBasePtr && (LoadBasePtr != StoreBasePtr)) {
2072        RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI);
2073        StoreBasePtr = nullptr;
2074      }
2075      if (LoadBasePtr) {
2076        RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI);
2077        LoadBasePtr = nullptr;
2078      }
2079      return false;
2080    }
2081  }
2082
2083  SmallPtrSet<Instruction*, 2> Ignore1;
2084  Ignore1.insert(SI);
2085  if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount,
2086                            StoreSize, *AA, Ignore1)) {
2087    // Check if the load is the offending instruction.
2088    Ignore1.insert(LI);
2089    if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop,
2090                              BECount, StoreSize, *AA, Ignore1)) {
2091      // Still bad. Nothing we can do.
2092      goto CleanupAndExit;
2093    }
2094    // It worked with the load ignored.
2095    Overlap = true;
2096  }
2097
2098  if (!Overlap) {
2099    if (DisableMemcpyIdiom || !HasMemcpy)
2100      goto CleanupAndExit;
2101  } else {
2102    // Don't generate memmove if this function will be inlined. This is
2103    // because the caller will undergo this transformation after inlining.
2104    Function *Func = CurLoop->getHeader()->getParent();
2105    if (Func->hasFnAttribute(Attribute::AlwaysInline))
2106      goto CleanupAndExit;
2107
2108    // In case of a memmove, the call to memmove will be executed instead
2109    // of the loop, so we need to make sure that there is nothing else in
2110    // the loop than the load, store and instructions that these two depend
2111    // on.
2112    SmallVector<Instruction*,2> Insts;
2113    Insts.push_back(SI);
2114    Insts.push_back(LI);
2115    if (!coverLoop(CurLoop, Insts))
2116      goto CleanupAndExit;
2117
2118    if (DisableMemmoveIdiom || !HasMemmove)
2119      goto CleanupAndExit;
2120    bool IsNested = CurLoop->getParentLoop() != nullptr;
2121    if (IsNested && OnlyNonNestedMemmove)
2122      goto CleanupAndExit;
2123  }
2124
2125  // For a memcpy, we have to make sure that the input array is not being
2126  // mutated by the loop.
2127  LoadBasePtr = Expander.expandCodeFor(LoadEv->getStart(),
2128      Builder.getPtrTy(LI->getPointerAddressSpace()), ExpPt);
2129
2130  SmallPtrSet<Instruction*, 2> Ignore2;
2131  Ignore2.insert(SI);
2132  if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount,
2133                            StoreSize, *AA, Ignore2))
2134    goto CleanupAndExit;
2135
2136  // Check the stride.
2137  bool StridePos = getSCEVStride(LoadEv) >= 0;
2138
2139  // Currently, the volatile memcpy only emulates traversing memory forward.
2140  if (!StridePos && DestVolatile)
2141    goto CleanupAndExit;
2142
2143  bool RuntimeCheck = (Overlap || DestVolatile);
2144
2145  BasicBlock *ExitB;
2146  if (RuntimeCheck) {
2147    // The runtime check needs a single exit block.
2148    SmallVector<BasicBlock*, 8> ExitBlocks;
2149    CurLoop->getUniqueExitBlocks(ExitBlocks);
2150    if (ExitBlocks.size() != 1)
2151      goto CleanupAndExit;
2152    ExitB = ExitBlocks[0];
2153  }
2154
2155  // The # stored bytes is (BECount+1)*Size.  Expand the trip count out to
2156  // pointer size if it isn't already.
2157  LLVMContext &Ctx = SI->getContext();
2158  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy);
2159  DebugLoc DLoc = SI->getDebugLoc();
2160
2161  const SCEV *NumBytesS =
2162      SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW);
2163  if (StoreSize != 1)
2164    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize),
2165                               SCEV::FlagNUW);
2166  Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, ExpPt);
2167  if (Instruction *In = dyn_cast<Instruction>(NumBytes))
2168    if (Value *Simp = simplifyInstruction(In, {*DL, TLI, DT}))
2169      NumBytes = Simp;
2170
2171  CallInst *NewCall;
2172
2173  if (RuntimeCheck) {
2174    unsigned Threshold = RuntimeMemSizeThreshold;
2175    if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) {
2176      uint64_t C = CI->getZExtValue();
2177      if (Threshold != 0 && C < Threshold)
2178        goto CleanupAndExit;
2179      if (C < CompileTimeMemSizeThreshold)
2180        goto CleanupAndExit;
2181    }
2182
2183    BasicBlock *Header = CurLoop->getHeader();
2184    Function *Func = Header->getParent();
2185    Loop *ParentL = LF->getLoopFor(Preheader);
2186    StringRef HeaderName = Header->getName();
2187
2188    // Create a new (empty) preheader, and update the PHI nodes in the
2189    // header to use the new preheader.
2190    BasicBlock *NewPreheader = BasicBlock::Create(Ctx, HeaderName+".rtli.ph",
2191                                                  Func, Header);
2192    if (ParentL)
2193      ParentL->addBasicBlockToLoop(NewPreheader, *LF);
2194    IRBuilder<>(NewPreheader).CreateBr(Header);
2195    for (auto &In : *Header) {
2196      PHINode *PN = dyn_cast<PHINode>(&In);
2197      if (!PN)
2198        break;
2199      int bx = PN->getBasicBlockIndex(Preheader);
2200      if (bx >= 0)
2201        PN->setIncomingBlock(bx, NewPreheader);
2202    }
2203    DT->addNewBlock(NewPreheader, Preheader);
2204    DT->changeImmediateDominator(Header, NewPreheader);
2205
2206    // Check for safe conditions to execute memmove.
2207    // If stride is positive, copying things from higher to lower addresses
2208    // is equivalent to memmove.  For negative stride, it's the other way
2209    // around.  Copying forward in memory with positive stride may not be
2210    // same as memmove since we may be copying values that we just stored
2211    // in some previous iteration.
2212    Value *LA = Builder.CreatePtrToInt(LoadBasePtr, IntPtrTy);
2213    Value *SA = Builder.CreatePtrToInt(StoreBasePtr, IntPtrTy);
2214    Value *LowA = StridePos ? SA : LA;
2215    Value *HighA = StridePos ? LA : SA;
2216    Value *CmpA = Builder.CreateICmpULT(LowA, HighA);
2217    Value *Cond = CmpA;
2218
2219    // Check for distance between pointers. Since the case LowA < HighA
2220    // is checked for above, assume LowA >= HighA.
2221    Value *Dist = Builder.CreateSub(LowA, HighA);
2222    Value *CmpD = Builder.CreateICmpSLE(NumBytes, Dist);
2223    Value *CmpEither = Builder.CreateOr(Cond, CmpD);
2224    Cond = CmpEither;
2225
2226    if (Threshold != 0) {
2227      Type *Ty = NumBytes->getType();
2228      Value *Thr = ConstantInt::get(Ty, Threshold);
2229      Value *CmpB = Builder.CreateICmpULT(Thr, NumBytes);
2230      Value *CmpBoth = Builder.CreateAnd(Cond, CmpB);
2231      Cond = CmpBoth;
2232    }
2233    BasicBlock *MemmoveB = BasicBlock::Create(Ctx, Header->getName()+".rtli",
2234                                              Func, NewPreheader);
2235    if (ParentL)
2236      ParentL->addBasicBlockToLoop(MemmoveB, *LF);
2237    Instruction *OldT = Preheader->getTerminator();
2238    Builder.CreateCondBr(Cond, MemmoveB, NewPreheader);
2239    OldT->eraseFromParent();
2240    Preheader->setName(Preheader->getName()+".old");
2241    DT->addNewBlock(MemmoveB, Preheader);
2242    // Find the new immediate dominator of the exit block.
2243    BasicBlock *ExitD = Preheader;
2244    for (BasicBlock *PB : predecessors(ExitB)) {
2245      ExitD = DT->findNearestCommonDominator(ExitD, PB);
2246      if (!ExitD)
2247        break;
2248    }
2249    // If the prior immediate dominator of ExitB was dominated by the
2250    // old preheader, then the old preheader becomes the new immediate
2251    // dominator.  Otherwise don't change anything (because the newly
2252    // added blocks are dominated by the old preheader).
2253    if (ExitD && DT->dominates(Preheader, ExitD)) {
2254      DomTreeNode *BN = DT->getNode(ExitB);
2255      DomTreeNode *DN = DT->getNode(ExitD);
2256      BN->setIDom(DN);
2257    }
2258
2259    // Add a call to memmove to the conditional block.
2260    IRBuilder<> CondBuilder(MemmoveB);
2261    CondBuilder.CreateBr(ExitB);
2262    CondBuilder.SetInsertPoint(MemmoveB->getTerminator());
2263
2264    if (DestVolatile) {
2265      Type *Int32Ty = Type::getInt32Ty(Ctx);
2266      Type *PtrTy = PointerType::get(Ctx, 0);
2267      Type *VoidTy = Type::getVoidTy(Ctx);
2268      Module *M = Func->getParent();
2269      FunctionCallee Fn = M->getOrInsertFunction(
2270          HexagonVolatileMemcpyName, VoidTy, PtrTy, PtrTy, Int32Ty);
2271
2272      const SCEV *OneS = SE->getConstant(Int32Ty, 1);
2273      const SCEV *BECount32 = SE->getTruncateOrZeroExtend(BECount, Int32Ty);
2274      const SCEV *NumWordsS = SE->getAddExpr(BECount32, OneS, SCEV::FlagNUW);
2275      Value *NumWords = Expander.expandCodeFor(NumWordsS, Int32Ty,
2276                                               MemmoveB->getTerminator());
2277      if (Instruction *In = dyn_cast<Instruction>(NumWords))
2278        if (Value *Simp = simplifyInstruction(In, {*DL, TLI, DT}))
2279          NumWords = Simp;
2280
2281      NewCall = CondBuilder.CreateCall(Fn,
2282                                       {StoreBasePtr, LoadBasePtr, NumWords});
2283    } else {
2284      NewCall = CondBuilder.CreateMemMove(
2285          StoreBasePtr, SI->getAlign(), LoadBasePtr, LI->getAlign(), NumBytes);
2286    }
2287  } else {
2288    NewCall = Builder.CreateMemCpy(StoreBasePtr, SI->getAlign(), LoadBasePtr,
2289                                   LI->getAlign(), NumBytes);
2290    // Okay, the memcpy has been formed.  Zap the original store and
2291    // anything that feeds into it.
2292    RecursivelyDeleteTriviallyDeadInstructions(SI, TLI);
2293  }
2294
2295  NewCall->setDebugLoc(DLoc);
2296
2297  LLVM_DEBUG(dbgs() << "  Formed " << (Overlap ? "memmove: " : "memcpy: ")
2298                    << *NewCall << "\n"
2299                    << "    from load ptr=" << *LoadEv << " at: " << *LI << "\n"
2300                    << "    from store ptr=" << *StoreEv << " at: " << *SI
2301                    << "\n");
2302
2303  return true;
2304}
2305
2306// Check if the instructions in Insts, together with their dependencies
2307// cover the loop in the sense that the loop could be safely eliminated once
2308// the instructions in Insts are removed.
2309bool HexagonLoopIdiomRecognize::coverLoop(Loop *L,
2310      SmallVectorImpl<Instruction*> &Insts) const {
2311  SmallSet<BasicBlock*,8> LoopBlocks;
2312  for (auto *B : L->blocks())
2313    LoopBlocks.insert(B);
2314
2315  SetVector<Instruction*> Worklist(Insts.begin(), Insts.end());
2316
2317  // Collect all instructions from the loop that the instructions in Insts
2318  // depend on (plus their dependencies, etc.).  These instructions will
2319  // constitute the expression trees that feed those in Insts, but the trees
2320  // will be limited only to instructions contained in the loop.
2321  for (unsigned i = 0; i < Worklist.size(); ++i) {
2322    Instruction *In = Worklist[i];
2323    for (auto I = In->op_begin(), E = In->op_end(); I != E; ++I) {
2324      Instruction *OpI = dyn_cast<Instruction>(I);
2325      if (!OpI)
2326        continue;
2327      BasicBlock *PB = OpI->getParent();
2328      if (!LoopBlocks.count(PB))
2329        continue;
2330      Worklist.insert(OpI);
2331    }
2332  }
2333
2334  // Scan all instructions in the loop, if any of them have a user outside
2335  // of the loop, or outside of the expressions collected above, then either
2336  // the loop has a side-effect visible outside of it, or there are
2337  // instructions in it that are not involved in the original set Insts.
2338  for (auto *B : L->blocks()) {
2339    for (auto &In : *B) {
2340      if (isa<BranchInst>(In) || isa<DbgInfoIntrinsic>(In))
2341        continue;
2342      if (!Worklist.count(&In) && In.mayHaveSideEffects())
2343        return false;
2344      for (auto *K : In.users()) {
2345        Instruction *UseI = dyn_cast<Instruction>(K);
2346        if (!UseI)
2347          continue;
2348        BasicBlock *UseB = UseI->getParent();
2349        if (LF->getLoopFor(UseB) != L)
2350          return false;
2351      }
2352    }
2353  }
2354
2355  return true;
2356}
2357
2358/// runOnLoopBlock - Process the specified block, which lives in a counted loop
2359/// with the specified backedge count.  This block is known to be in the current
2360/// loop and not in any subloops.
2361bool HexagonLoopIdiomRecognize::runOnLoopBlock(Loop *CurLoop, BasicBlock *BB,
2362      const SCEV *BECount, SmallVectorImpl<BasicBlock*> &ExitBlocks) {
2363  // We can only promote stores in this block if they are unconditionally
2364  // executed in the loop.  For a block to be unconditionally executed, it has
2365  // to dominate all the exit blocks of the loop.  Verify this now.
2366  auto DominatedByBB = [this,BB] (BasicBlock *EB) -> bool {
2367    return DT->dominates(BB, EB);
2368  };
2369  if (!all_of(ExitBlocks, DominatedByBB))
2370    return false;
2371
2372  bool MadeChange = false;
2373  // Look for store instructions, which may be optimized to memset/memcpy.
2374  SmallVector<StoreInst*,8> Stores;
2375  collectStores(CurLoop, BB, Stores);
2376
2377  // Optimize the store into a memcpy, if it feeds an similarly strided load.
2378  for (auto &SI : Stores)
2379    MadeChange |= processCopyingStore(CurLoop, SI, BECount);
2380
2381  return MadeChange;
2382}
2383
2384bool HexagonLoopIdiomRecognize::runOnCountableLoop(Loop *L) {
2385  PolynomialMultiplyRecognize PMR(L, *DL, *DT, *TLI, *SE);
2386  if (PMR.recognize())
2387    return true;
2388
2389  if (!HasMemcpy && !HasMemmove)
2390    return false;
2391
2392  const SCEV *BECount = SE->getBackedgeTakenCount(L);
2393  assert(!isa<SCEVCouldNotCompute>(BECount) &&
2394         "runOnCountableLoop() called on a loop without a predictable"
2395         "backedge-taken count");
2396
2397  SmallVector<BasicBlock *, 8> ExitBlocks;
2398  L->getUniqueExitBlocks(ExitBlocks);
2399
2400  bool Changed = false;
2401
2402  // Scan all the blocks in the loop that are not in subloops.
2403  for (auto *BB : L->getBlocks()) {
2404    // Ignore blocks in subloops.
2405    if (LF->getLoopFor(BB) != L)
2406      continue;
2407    Changed |= runOnLoopBlock(L, BB, BECount, ExitBlocks);
2408  }
2409
2410  return Changed;
2411}
2412
2413bool HexagonLoopIdiomRecognize::run(Loop *L) {
2414  const Module &M = *L->getHeader()->getParent()->getParent();
2415  if (Triple(M.getTargetTriple()).getArch() != Triple::hexagon)
2416    return false;
2417
2418  // If the loop could not be converted to canonical form, it must have an
2419  // indirectbr in it, just give up.
2420  if (!L->getLoopPreheader())
2421    return false;
2422
2423  // Disable loop idiom recognition if the function's name is a common idiom.
2424  StringRef Name = L->getHeader()->getParent()->getName();
2425  if (Name == "memset" || Name == "memcpy" || Name == "memmove")
2426    return false;
2427
2428  DL = &L->getHeader()->getModule()->getDataLayout();
2429
2430  HasMemcpy = TLI->has(LibFunc_memcpy);
2431  HasMemmove = TLI->has(LibFunc_memmove);
2432
2433  if (SE->hasLoopInvariantBackedgeTakenCount(L))
2434    return runOnCountableLoop(L);
2435  return false;
2436}
2437
2438bool HexagonLoopIdiomRecognizeLegacyPass::runOnLoop(Loop *L,
2439                                                    LPPassManager &LPM) {
2440  if (skipLoop(L))
2441    return false;
2442
2443  auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
2444  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
2445  auto *LF = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
2446  auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(
2447      *L->getHeader()->getParent());
2448  auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
2449  return HexagonLoopIdiomRecognize(AA, DT, LF, TLI, SE).run(L);
2450}
2451
2452Pass *llvm::createHexagonLoopIdiomPass() {
2453  return new HexagonLoopIdiomRecognizeLegacyPass();
2454}
2455
2456PreservedAnalyses
2457HexagonLoopIdiomRecognitionPass::run(Loop &L, LoopAnalysisManager &AM,
2458                                     LoopStandardAnalysisResults &AR,
2459                                     LPMUpdater &U) {
2460  return HexagonLoopIdiomRecognize(&AR.AA, &AR.DT, &AR.LI, &AR.TLI, &AR.SE)
2461                 .run(&L)
2462             ? getLoopPassPreservedAnalyses()
2463             : PreservedAnalyses::all();
2464}
2465