//===- NaryReassociate.cpp - Reassociate n-ary expressions ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass reassociates n-ary add expressions and eliminates the redundancy
// exposed by the reassociation.
//
// A motivating example:
//
//   void foo(int a, int b) {
//     bar(a + b);
//     bar((a + 2) + b);
//   }
//
// An ideal compiler should reassociate (a + 2) + b to (a + b) + 2 and simplify
// the above code to
//
//   int t = a + b;
//   bar(t);
//   bar(t + 2);
//
// However, the Reassociate pass is unable to do that because it processes each
// instruction individually and believes (a + 2) + b is the best form according
// to its rank system.
//
// To address this limitation, NaryReassociate reassociates an expression in a
// form that reuses existing instructions. As a result, NaryReassociate can
// reassociate (a + 2) + b in the example to (a + b) + 2 because it detects that
// (a + b) is computed before.
//
// NaryReassociate works as follows. For every instruction in the form of (a +
// b) + c, it checks whether a + c or b + c is already computed by a dominating
// instruction. If so, it then reassociates (a + b) + c into (a + c) + b or (b +
// c) + a and removes the redundancy accordingly. To efficiently look up whether
// an expression is computed before, we store each instruction seen and its SCEV
// into an SCEV-to-instruction map.
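// (Concretely, this map is the SeenExprs member declared in
// llvm/Transforms/Scalar/NaryReassociate.h; it maps a const SCEV * to a stack
// of WeakTrackingVHs, one per seen instruction that computes the expression.)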
//
// Although the algorithm pattern-matches only ternary additions, it
// automatically handles many >3-ary expressions by walking through the
// function in depth-first order. For example, given
//
//   (a + c) + d
//   ((a + b) + c) + d
//
// NaryReassociate first rewrites (a + b) + c to (a + c) + b, and then rewrites
// ((a + c) + b) + d into ((a + c) + d) + b.
//
// Finally, the above dominator-based algorithm may need to run for multiple
// iterations before emitting optimal code. One source of this need is that we
// only split an operand when it is used only once. The above algorithm can
// eliminate an instruction and decrease the usage count of its operands. As a
// result, an instruction that previously had multiple uses may become a
// single-use instruction and thus eligible for split consideration. For
// example,
//
//   ac = a + c
//   ab = a + b
//   abc = ab + c
//   ab2 = ab + b
//   ab2c = ab2 + c
//
// In the first iteration, we cannot reassociate abc to ac+b because ab is used
// twice. However, we can reassociate ab2c to abc+b in the first iteration. As a
// result, ab2 becomes dead and ab will be used only once in the second
// iteration.
//
// Limitations and TODO items:
//
// 1) We only consider n-ary adds and muls for now. This should be extended
// and generalized.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/NaryReassociate.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/ValueHandle.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
#include <cassert>
#include <cstdint>
#include <type_traits>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "nary-reassociate"

namespace {

class NaryReassociateLegacyPass : public FunctionPass {
public:
  static char ID;

  NaryReassociateLegacyPass() : FunctionPass(ID) {
    initializeNaryReassociateLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  bool doInitialization(Module &M) override {
    return false;
  }

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addPreserved<DominatorTreeWrapperPass>();
    AU.addPreserved<ScalarEvolutionWrapperPass>();
    AU.addPreserved<TargetLibraryInfoWrapperPass>();
    AU.addRequired<AssumptionCacheTracker>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<ScalarEvolutionWrapperPass>();
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.setPreservesCFG();
  }

private:
  NaryReassociatePass Impl;
};

} // end anonymous namespace

char NaryReassociateLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(NaryReassociateLegacyPass, "nary-reassociate",
                      "Nary reassociation", false, false)
INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(NaryReassociateLegacyPass, "nary-reassociate",
                    "Nary reassociation", false, false)

FunctionPass *llvm::createNaryReassociatePass() {
  return new NaryReassociateLegacyPass();
}

bool NaryReassociateLegacyPass::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
  auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
  auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  return Impl.runImpl(F, AC, DT, SE, TLI, TTI);
}

PreservedAnalyses NaryReassociatePass::run(Function &F,
                                           FunctionAnalysisManager &AM) {
  auto *AC = &AM.getResult<AssumptionAnalysis>(F);
  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
  auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
  auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
  auto *TTI = &AM.getResult<TargetIRAnalysis>(F);

  if (!runImpl(F, AC, DT, SE, TLI, TTI))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<ScalarEvolutionAnalysis>();
  return PA;
}

bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_,
                                  DominatorTree *DT_, ScalarEvolution *SE_,
                                  TargetLibraryInfo *TLI_,
                                  TargetTransformInfo *TTI_) {
  AC = AC_;
  DT = DT_;
  SE = SE_;
  TLI = TLI_;
  TTI = TTI_;
  DL = &F.getParent()->getDataLayout();

  bool Changed = false, ChangedInThisIteration;
  do {
    ChangedInThisIteration = doOneIteration(F);
    Changed |= ChangedInThisIteration;
  } while (ChangedInThisIteration);
  return Changed;
}

bool NaryReassociatePass::doOneIteration(Function &F) {
  bool Changed = false;
  SeenExprs.clear();
  // Process the basic blocks in a depth first traversal of the dominator
  // tree. This order ensures that all bases of a candidate are in Candidates
  // when we process it.
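  // For example (an illustrative sketch; SSA names are hypothetical), if the
  // walk visits
  //   %ac = add i32 %a, %c
  // in a block that dominates
  //   %ab = add i32 %a, %b
  //   %x  = add i32 %ab, %c   ; I = (a + b) + c
  // then the SCEV of (a + c) is already in SeenExprs when %x is processed,
  // and tryReassociate can rewrite %x into (a + c) + b, reusing %ac.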
  SmallVector<WeakTrackingVH, 16> DeadInsts;
  for (const auto Node : depth_first(DT)) {
    BasicBlock *BB = Node->getBlock();
    for (Instruction &OrigI : *BB) {
      const SCEV *OrigSCEV = nullptr;
      if (Instruction *NewI = tryReassociate(&OrigI, OrigSCEV)) {
        Changed = true;
        OrigI.replaceAllUsesWith(NewI);

        // Add 'OrigI' to the list of dead instructions.
        DeadInsts.push_back(WeakTrackingVH(&OrigI));
        // Add the rewritten instruction to SeenExprs; the original
        // instruction is deleted.
        const SCEV *NewSCEV = SE->getSCEV(NewI);
        SeenExprs[NewSCEV].push_back(WeakTrackingVH(NewI));

        // Ideally, NewSCEV should equal OrigSCEV because tryReassociate(I)
        // is equivalent to I. However, ScalarEvolution::getSCEV may
        // weaken nsw, causing NewSCEV not to equal OrigSCEV. For example,
        // suppose we reassociate
        //   I = &a[sext(i +nsw j)] // assuming sizeof(a[0]) = 4
        // to
        //   NewI = &a[sext(i)] + sext(j).
        //
        // ScalarEvolution computes
        //   getSCEV(I)    = a + 4 * sext(i + j)
        //   getSCEV(NewI) = a + 4 * sext(i) + 4 * sext(j)
        // which are different SCEVs.
        //
        // To alleviate this issue of ScalarEvolution not always capturing
        // equivalence, we add I to SeenExprs[OrigSCEV] as well, so that we
        // can map both the SCEVs before and after tryReassociate(I) to I.
        //
        // This improvement is exercised in @reassociate_gep_nsw in
        // nary-gep.ll.
        if (NewSCEV != OrigSCEV)
          SeenExprs[OrigSCEV].push_back(WeakTrackingVH(NewI));
      } else if (OrigSCEV)
        SeenExprs[OrigSCEV].push_back(WeakTrackingVH(&OrigI));
    }
  }
  // Delete all dead instructions from 'DeadInsts'.
  // Note that ScalarEvolution is updated along the way via the callback.
  RecursivelyDeleteTriviallyDeadInstructionsPermissive(
      DeadInsts, TLI, nullptr, [this](Value *V) { SE->forgetValue(V); });

  return Changed;
}

template <typename PredT>
Instruction *
NaryReassociatePass::matchAndReassociateMinOrMax(Instruction *I,
                                                 const SCEV *&OrigSCEV) {
  Value *LHS = nullptr;
  Value *RHS = nullptr;

  auto MinMaxMatcher =
      MaxMin_match<ICmpInst, bind_ty<Value>, bind_ty<Value>, PredT>(
          m_Value(LHS), m_Value(RHS));
  if (match(I, MinMaxMatcher)) {
    OrigSCEV = SE->getSCEV(I);
    return dyn_cast_or_null<Instruction>(
        tryReassociateMinOrMax(I, MinMaxMatcher, LHS, RHS));
  }
  return nullptr;
}

Instruction *NaryReassociatePass::tryReassociate(Instruction *I,
                                                 const SCEV *&OrigSCEV) {
  if (!SE->isSCEVable(I->getType()))
    return nullptr;

  switch (I->getOpcode()) {
  case Instruction::Add:
  case Instruction::Mul:
    OrigSCEV = SE->getSCEV(I);
    return tryReassociateBinaryOp(cast<BinaryOperator>(I));
  case Instruction::GetElementPtr:
    OrigSCEV = SE->getSCEV(I);
    return tryReassociateGEP(cast<GetElementPtrInst>(I));
  default:
    break;
  }

  // Try to match signed/unsigned min/max.
  Instruction *ResI = nullptr;
  // TODO: Currently min/max reassociation is restricted to integer types only
  // due to the use of SCEVExpander, which may introduce incompatible forms of
  // min/max for pointer types.
  if (I->getType()->isIntegerTy())
    if ((ResI = matchAndReassociateMinOrMax<umin_pred_ty>(I, OrigSCEV)) ||
        (ResI = matchAndReassociateMinOrMax<smin_pred_ty>(I, OrigSCEV)) ||
        (ResI = matchAndReassociateMinOrMax<umax_pred_ty>(I, OrigSCEV)) ||
        (ResI = matchAndReassociateMinOrMax<smax_pred_ty>(I, OrigSCEV)))
      return ResI;

  return nullptr;
}

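// A GEP is "free" when the target reports TCC_Free for it, i.e. the whole
// address computation is expected to fold into the addressing mode of the
// memory access (for example, a base-plus-scaled-index-plus-displacement
// mode on x86). Reassociating such a GEP would only add instructions.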
static bool isGEPFoldable(GetElementPtrInst *GEP,
                          const TargetTransformInfo *TTI) {
  SmallVector<const Value *, 4> Indices(GEP->indices());
  return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(),
                         Indices) == TargetTransformInfo::TCC_Free;
}

Instruction *NaryReassociatePass::tryReassociateGEP(GetElementPtrInst *GEP) {
  // Not worth reassociating GEP if it is foldable.
  if (isGEPFoldable(GEP, TTI))
    return nullptr;

  gep_type_iterator GTI = gep_type_begin(*GEP);
  for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
    if (GTI.isSequential()) {
      if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I - 1,
                                                  GTI.getIndexedType())) {
        return NewGEP;
      }
    }
  }
  return nullptr;
}

bool NaryReassociatePass::requiresSignExtension(Value *Index,
                                                GetElementPtrInst *GEP) {
  unsigned PointerSizeInBits =
      DL->getPointerSizeInBits(GEP->getType()->getPointerAddressSpace());
  return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits;
}

GetElementPtrInst *
NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
                                              unsigned I, Type *IndexedType) {
  Value *IndexToSplit = GEP->getOperand(I + 1);
  if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) {
    IndexToSplit = SExt->getOperand(0);
  } else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) {
    // zext can be treated as sext if the source is non-negative.
    if (isKnownNonNegative(ZExt->getOperand(0), *DL, 0, AC, GEP, DT))
      IndexToSplit = ZExt->getOperand(0);
  }

  if (AddOperator *AO = dyn_cast<AddOperator>(IndexToSplit)) {
    // If the I-th index needs sext and the underlying add is not equipped with
    // nsw, we cannot split the add because
    //   sext(LHS + RHS) != sext(LHS) + sext(RHS).
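    // For example, with i8 operands, sext(i8 127 + i8 1) = sext(i8 -128)
    // = -128, whereas sext(i8 127) + sext(i8 1) = 127 + 1 = 128.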
    if (requiresSignExtension(IndexToSplit, GEP) &&
        computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) !=
            OverflowResult::NeverOverflows)
      return nullptr;

    Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1);
    // IndexToSplit = LHS + RHS.
    if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I, LHS, RHS, IndexedType))
      return NewGEP;
    // Symmetrically, try IndexToSplit = RHS + LHS.
    if (LHS != RHS) {
      if (auto *NewGEP =
              tryReassociateGEPAtIndex(GEP, I, RHS, LHS, IndexedType))
        return NewGEP;
    }
  }
  return nullptr;
}

GetElementPtrInst *
NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP,
                                              unsigned I, Value *LHS,
                                              Value *RHS, Type *IndexedType) {
  // Look for GEP's closest dominator that has the same SCEV as GEP except that
  // the I-th index is replaced with LHS.
  SmallVector<const SCEV *, 4> IndexExprs;
  for (Use &Index : GEP->indices())
    IndexExprs.push_back(SE->getSCEV(Index));
  // Replace the I-th index with LHS.
  IndexExprs[I] = SE->getSCEV(LHS);
  if (isKnownNonNegative(LHS, *DL, 0, AC, GEP, DT) &&
      DL->getTypeSizeInBits(LHS->getType()).getFixedSize() <
          DL->getTypeSizeInBits(GEP->getOperand(I)->getType()).getFixedSize()) {
    // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext to
    // zext if the source operand is proved non-negative. We should do that
    // consistently so that CandidateExpr more likely appears before. See
    // @reassociate_gep_assume for an example of this canonicalization.
    IndexExprs[I] =
        SE->getZeroExtendExpr(IndexExprs[I], GEP->getOperand(I)->getType());
  }
  const SCEV *CandidateExpr = SE->getGEPExpr(cast<GEPOperator>(GEP),
                                             IndexExprs);

  Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP);
  if (Candidate == nullptr)
    return nullptr;

  IRBuilder<> Builder(GEP);
  // Candidate does not necessarily have the same pointer type as GEP. Use
  // bitcast or pointer cast to make sure they have the same type, so that the
  // later RAUW doesn't complain.
  Candidate = Builder.CreateBitOrPointerCast(Candidate, GEP->getType());
  assert(Candidate->getType() == GEP->getType());

  // NewGEP = (char *)Candidate + RHS * sizeof(IndexedType)
  uint64_t IndexedSize = DL->getTypeAllocSize(IndexedType);
  Type *ElementType = GEP->getResultElementType();
  uint64_t ElementSize = DL->getTypeAllocSize(ElementType);
  // Because I is not necessarily the last index of the GEP, the size of the
  // type at the I-th index (IndexedSize) is not necessarily divisible by
  // ElementSize. For example,
  //
  // #pragma pack(1)
  // struct S {
  //   int a[3];
  //   int64 b[8];
  // };
  // #pragma pack()
  //
  // sizeof(S) = 76 is not divisible by sizeof(int64) = 8.
  //
  // TODO: bail out on this case for now. We could emit uglygep.
  if (IndexedSize % ElementSize != 0)
    return nullptr;

  // NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0]))]
  Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
  if (RHS->getType() != IntPtrTy)
    RHS = Builder.CreateSExtOrTrunc(RHS, IntPtrTy);
  if (IndexedSize != ElementSize) {
    RHS = Builder.CreateMul(
        RHS, ConstantInt::get(IntPtrTy, IndexedSize / ElementSize));
  }
  GetElementPtrInst *NewGEP = cast<GetElementPtrInst>(
      Builder.CreateGEP(GEP->getResultElementType(), Candidate, RHS));
  NewGEP->setIsInBounds(GEP->isInBounds());
  NewGEP->takeName(GEP);
  return NewGEP;
}

Instruction *NaryReassociatePass::tryReassociateBinaryOp(BinaryOperator *I) {
  Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
  // There is no need to reassociate 0.
  if (SE->getSCEV(I)->isZero())
    return nullptr;
  if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I))
    return NewI;
  if (auto *NewI = tryReassociateBinaryOp(RHS, LHS, I))
    return NewI;
  return nullptr;
}

Instruction *NaryReassociatePass::tryReassociateBinaryOp(Value *LHS, Value *RHS,
                                                         BinaryOperator *I) {
  Value *A = nullptr, *B = nullptr;
  // To be conservative, we reassociate I only when it is the sole user of
  // (A op B).
  if (LHS->hasOneUse() && matchTernaryOp(I, LHS, A, B)) {
    // I = (A op B) op RHS
    //   = (A op RHS) op B or (B op RHS) op A
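    //
    // For example (an illustrative sketch), given
    //   %ab = add i32 %a, %b   ; LHS, single use
    //   %i  = add i32 %ab, %c  ; I = (a + b) + c
    // and a dominating %ac = add i32 %a, %c, I is rewritten below into
    //   %i = add i32 %ac, %b
    // reusing %ac.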
    const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B);
    const SCEV *RHSExpr = SE->getSCEV(RHS);
    if (BExpr != RHSExpr) {
      if (auto *NewI =
              tryReassociatedBinaryOp(getBinarySCEV(I, AExpr, RHSExpr), B, I))
        return NewI;
    }
    if (AExpr != RHSExpr) {
      if (auto *NewI =
              tryReassociatedBinaryOp(getBinarySCEV(I, BExpr, RHSExpr), A, I))
        return NewI;
    }
  }
  return nullptr;
}

Instruction *NaryReassociatePass::tryReassociatedBinaryOp(const SCEV *LHSExpr,
                                                          Value *RHS,
                                                          BinaryOperator *I) {
  // Look for the closest dominator LHS of I that computes LHSExpr, and replace
  // I with LHS op RHS.
  auto *LHS = findClosestMatchingDominator(LHSExpr, I);
  if (LHS == nullptr)
    return nullptr;

  Instruction *NewI = nullptr;
  switch (I->getOpcode()) {
  case Instruction::Add:
    NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I);
    break;
  case Instruction::Mul:
    NewI = BinaryOperator::CreateMul(LHS, RHS, "", I);
    break;
  default:
    llvm_unreachable("Unexpected instruction.");
  }
  NewI->takeName(I);
  return NewI;
}

bool NaryReassociatePass::matchTernaryOp(BinaryOperator *I, Value *V,
                                         Value *&Op1, Value *&Op2) {
  switch (I->getOpcode()) {
  case Instruction::Add:
    return match(V, m_Add(m_Value(Op1), m_Value(Op2)));
  case Instruction::Mul:
    return match(V, m_Mul(m_Value(Op1), m_Value(Op2)));
  default:
    llvm_unreachable("Unexpected instruction.");
  }
  return false;
}

const SCEV *NaryReassociatePass::getBinarySCEV(BinaryOperator *I,
                                               const SCEV *LHS,
                                               const SCEV *RHS) {
  switch (I->getOpcode()) {
  case Instruction::Add:
    return SE->getAddExpr(LHS, RHS);
  case Instruction::Mul:
    return SE->getMulExpr(LHS, RHS);
  default:
    llvm_unreachable("Unexpected instruction.");
  }
  return nullptr;
}

Instruction *
NaryReassociatePass::findClosestMatchingDominator(const SCEV *CandidateExpr,
                                                  Instruction *Dominatee) {
  auto Pos = SeenExprs.find(CandidateExpr);
  if (Pos == SeenExprs.end())
    return nullptr;

  auto &Candidates = Pos->second;
  // Because we process the basic blocks in pre-order of the dominator tree, a
  // candidate that doesn't dominate the current instruction won't dominate any
  // future instruction either. Therefore, we pop it out of the stack. This
  // optimization makes the algorithm O(n).
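  // For example, once the pre-order walk leaves the dominator subtree of a
  // block B, nothing in B dominates any instruction visited later, so a
  // candidate defined in B can be popped permanently the first time a lookup
  // sees that it fails to dominate the query point.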
  while (!Candidates.empty()) {
    // Candidates stores WeakTrackingVHs, so a candidate can be nullptr if
    // it's removed during rewriting.
    if (Value *Candidate = Candidates.back()) {
      Instruction *CandidateInstruction = cast<Instruction>(Candidate);
      if (DT->dominates(CandidateInstruction, Dominatee))
        return CandidateInstruction;
    }
    Candidates.pop_back();
  }
  return nullptr;
}

template <typename MaxMinT> static SCEVTypes convertToSCEVType(MaxMinT &MM) {
  if (std::is_same<smax_pred_ty, typename MaxMinT::PredType>::value)
    return scSMaxExpr;
  else if (std::is_same<umax_pred_ty, typename MaxMinT::PredType>::value)
    return scUMaxExpr;
  else if (std::is_same<smin_pred_ty, typename MaxMinT::PredType>::value)
    return scSMinExpr;
  else if (std::is_same<umin_pred_ty, typename MaxMinT::PredType>::value)
    return scUMinExpr;

  llvm_unreachable("Can't convert MinMax pattern to SCEV type");
  return scUnknown;
}

// Parameters:
//  I - instruction matched by MaxMinMatch matcher
//  MaxMinMatch - min/max idiom matcher
//  LHS - first operand of I
//  RHS - second operand of I
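//
// For example (an illustrative sketch), given
//   %m1 = smin(%a, %b)      ; LHS, used only by I
//   %i  = smin(%m1, %c)     ; I = smin(smin(a, b), c)
// and a dominating instruction computing smin(%a, %c), the code below uses
// SCEVExpander to rewrite I into smin(smin(%a, %c), %b), after which %m1 is
// expected to become dead.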
template <typename MaxMinT>
Value *NaryReassociatePass::tryReassociateMinOrMax(Instruction *I,
                                                   MaxMinT MaxMinMatch,
                                                   Value *LHS, Value *RHS) {
  Value *A = nullptr, *B = nullptr;
  MaxMinT m_MaxMin(m_Value(A), m_Value(B));
  for (unsigned int i = 0; i < 2; ++i) {
    if (!LHS->hasNUsesOrMore(3) && match(LHS, m_MaxMin)) {
      const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B);
      const SCEV *RHSExpr = SE->getSCEV(RHS);
      for (unsigned int j = 0; j < 2; ++j) {
        if (j == 0) {
          if (BExpr == RHSExpr)
            continue;
          // Transform 'I = (A op B) op RHS' to 'I = (A op RHS) op B' on the
          // first iteration.
          std::swap(BExpr, RHSExpr);
        } else {
          if (AExpr == RHSExpr)
            continue;
          // Transform 'I = (A op RHS) op B' to 'I = (B op RHS) op A' on the
          // second iteration.
          std::swap(AExpr, RHSExpr);
        }

        // The optimization is profitable only if LHS can be removed in the
        // end. In other words, LHS should be used (directly or indirectly) by
        // I only.
        if (llvm::any_of(LHS->users(), [&](auto *U) {
              return U != I && !(U->hasOneUser() && *U->users().begin() == I);
            }))
          continue;

        SCEVExpander Expander(*SE, *DL, "nary-reassociate");
        SmallVector<const SCEV *, 2> Ops1{BExpr, AExpr};
        const SCEVTypes SCEVType = convertToSCEVType(m_MaxMin);
        const SCEV *R1Expr = SE->getMinMaxExpr(SCEVType, Ops1);

        Instruction *R1MinMax = findClosestMatchingDominator(R1Expr, I);

        if (!R1MinMax)
          continue;

        LLVM_DEBUG(dbgs() << "NARY: Found common sub-expr: " << *R1MinMax
                          << "\n");

        R1Expr = SE->getUnknown(R1MinMax);
        SmallVector<const SCEV *, 2> Ops2{RHSExpr, R1Expr};
        const SCEV *R2Expr = SE->getMinMaxExpr(SCEVType, Ops2);

        Value *NewMinMax = Expander.expandCodeFor(R2Expr, I->getType(), I);
        NewMinMax->setName(Twine(I->getName()).concat(".nary"));

        LLVM_DEBUG(dbgs() << "NARY: Deleting:  " << *I << "\n"
                          << "NARY: Inserting: " << *NewMinMax << "\n");
        return NewMinMax;
      }
    }
    std::swap(LHS, RHS);
  }
  return nullptr;
}