1//===- CorrelatedValuePropagation.cpp - Propagate CFG-derived info --------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the Correlated Value Propagation pass.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h"
14#include "llvm/ADT/DepthFirstIterator.h"
15#include "llvm/ADT/Optional.h"
16#include "llvm/ADT/SmallVector.h"
17#include "llvm/ADT/Statistic.h"
18#include "llvm/Analysis/DomTreeUpdater.h"
19#include "llvm/Analysis/GlobalsModRef.h"
20#include "llvm/Analysis/InstructionSimplify.h"
21#include "llvm/Analysis/LazyValueInfo.h"
22#include "llvm/IR/Attributes.h"
23#include "llvm/IR/BasicBlock.h"
24#include "llvm/IR/CFG.h"
25#include "llvm/IR/CallSite.h"
26#include "llvm/IR/Constant.h"
27#include "llvm/IR/ConstantRange.h"
28#include "llvm/IR/Constants.h"
29#include "llvm/IR/DerivedTypes.h"
30#include "llvm/IR/Function.h"
31#include "llvm/IR/IRBuilder.h"
32#include "llvm/IR/InstrTypes.h"
33#include "llvm/IR/Instruction.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/IR/IntrinsicInst.h"
36#include "llvm/IR/Operator.h"
37#include "llvm/IR/PassManager.h"
38#include "llvm/IR/Type.h"
39#include "llvm/IR/Value.h"
40#include "llvm/InitializePasses.h"
41#include "llvm/Pass.h"
42#include "llvm/Support/Casting.h"
43#include "llvm/Support/CommandLine.h"
44#include "llvm/Support/Debug.h"
45#include "llvm/Support/raw_ostream.h"
46#include "llvm/Transforms/Scalar.h"
47#include "llvm/Transforms/Utils/Local.h"
48#include <cassert>
49#include <utility>
50
51using namespace llvm;
52
// Tag for this file's debug output; presumably also picked up by the
// LLVM_DEBUG and STATISTIC macros used below (LLVM convention -- confirm).
#define DEBUG_TYPE "correlated-value-propagation"

// Counters for the individual transforms implemented in this file. Each one
// is incremented by the corresponding process* routine when it fires.
STATISTIC(NumPhis,      "Number of phis propagated");
STATISTIC(NumPhiCommon, "Number of phis deleted via common incoming value");
STATISTIC(NumSelects,   "Number of selects propagated");
STATISTIC(NumMemAccess, "Number of memory access targets propagated");
STATISTIC(NumCmps,      "Number of comparisons propagated");
STATISTIC(NumReturns,   "Number of return values propagated");
STATISTIC(NumDeadCases, "Number of switch cases removed");
STATISTIC(NumSDivs,     "Number of sdiv converted to udiv");
STATISTIC(NumUDivs,     "Number of udivs whose width was decreased");
STATISTIC(NumAShrs,     "Number of ashr converted to lshr");
STATISTIC(NumSRems,     "Number of srem converted to urem");
STATISTIC(NumSExt,      "Number of sext converted to zext");
STATISTIC(NumAnd,       "Number of ands removed");
// No-wrap flag deductions. setDeducedOverflowingFlags() bumps the generic
// counter, the signed/unsigned counter and the matching per-opcode counters
// together for every flag it deduces.
STATISTIC(NumNW,        "Number of no-wrap deductions");
STATISTIC(NumNSW,       "Number of no-signed-wrap deductions");
STATISTIC(NumNUW,       "Number of no-unsigned-wrap deductions");
STATISTIC(NumAddNW,     "Number of no-wrap deductions for add");
STATISTIC(NumAddNSW,    "Number of no-signed-wrap deductions for add");
STATISTIC(NumAddNUW,    "Number of no-unsigned-wrap deductions for add");
STATISTIC(NumSubNW,     "Number of no-wrap deductions for sub");
STATISTIC(NumSubNSW,    "Number of no-signed-wrap deductions for sub");
STATISTIC(NumSubNUW,    "Number of no-unsigned-wrap deductions for sub");
STATISTIC(NumMulNW,     "Number of no-wrap deductions for mul");
STATISTIC(NumMulNSW,    "Number of no-signed-wrap deductions for mul");
STATISTIC(NumMulNUW,    "Number of no-unsigned-wrap deductions for mul");
STATISTIC(NumShlNW,     "Number of no-wrap deductions for shl");
STATISTIC(NumShlNSW,    "Number of no-signed-wrap deductions for shl");
STATISTIC(NumShlNUW,    "Number of no-unsigned-wrap deductions for shl");
// Intrinsic simplifications (with.overflow and saturating arithmetic).
STATISTIC(NumOverflows, "Number of overflow checks removed");
STATISTIC(NumSaturating,
    "Number of saturating arithmetics converted to normal arithmetics");
86
87static cl::opt<bool> DontAddNoWrapFlags("cvp-dont-add-nowrap-flags", cl::init(false));
88
89namespace {
90
91  class CorrelatedValuePropagation : public FunctionPass {
92  public:
93    static char ID;
94
95    CorrelatedValuePropagation(): FunctionPass(ID) {
96     initializeCorrelatedValuePropagationPass(*PassRegistry::getPassRegistry());
97    }
98
99    bool runOnFunction(Function &F) override;
100
101    void getAnalysisUsage(AnalysisUsage &AU) const override {
102      AU.addRequired<DominatorTreeWrapperPass>();
103      AU.addRequired<LazyValueInfoWrapperPass>();
104      AU.addPreserved<GlobalsAAWrapperPass>();
105      AU.addPreserved<DominatorTreeWrapperPass>();
106      AU.addPreserved<LazyValueInfoWrapperPass>();
107    }
108  };
109
110} // end anonymous namespace
111
// Storage for the legacy pass ID; its address is handed to FunctionPass by the
// constructor above.
char CorrelatedValuePropagation::ID = 0;

// Legacy pass registration. The string arguments are presumably the
// command-line argument and the human-readable pass name, and the two
// trailing `false` values the macro's CFG-only / is-analysis flags -- confirm
// against the INITIALIZE_PASS_* macro definitions. The dependency lines list
// the same two analyses that getAnalysisUsage() marks as required.
INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation",
                "Value Propagation", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
INITIALIZE_PASS_END(CorrelatedValuePropagation, "correlated-propagation",
                "Value Propagation", false, false)
120
121// Public interface to the Value Propagation pass
122Pass *llvm::createCorrelatedValuePropagationPass() {
123  return new CorrelatedValuePropagation();
124}
125
126static bool processSelect(SelectInst *S, LazyValueInfo *LVI) {
127  if (S->getType()->isVectorTy()) return false;
128  if (isa<Constant>(S->getOperand(0))) return false;
129
130  Constant *C = LVI->getConstant(S->getCondition(), S->getParent(), S);
131  if (!C) return false;
132
133  ConstantInt *CI = dyn_cast<ConstantInt>(C);
134  if (!CI) return false;
135
136  Value *ReplaceWith = S->getTrueValue();
137  Value *Other = S->getFalseValue();
138  if (!CI->isOne()) std::swap(ReplaceWith, Other);
139  if (ReplaceWith == S) ReplaceWith = UndefValue::get(S->getType());
140
141  S->replaceAllUsesWith(ReplaceWith);
142  S->eraseFromParent();
143
144  ++NumSelects;
145
146  return true;
147}
148
149/// Try to simplify a phi with constant incoming values that match the edge
150/// values of a non-constant value on all other edges:
151/// bb0:
152///   %isnull = icmp eq i8* %x, null
153///   br i1 %isnull, label %bb2, label %bb1
154/// bb1:
155///   br label %bb2
156/// bb2:
157///   %r = phi i8* [ %x, %bb1 ], [ null, %bb0 ]
158/// -->
159///   %r = %x
160static bool simplifyCommonValuePhi(PHINode *P, LazyValueInfo *LVI,
161                                   DominatorTree *DT) {
162  // Collect incoming constants and initialize possible common value.
163  SmallVector<std::pair<Constant *, unsigned>, 4> IncomingConstants;
164  Value *CommonValue = nullptr;
165  for (unsigned i = 0, e = P->getNumIncomingValues(); i != e; ++i) {
166    Value *Incoming = P->getIncomingValue(i);
167    if (auto *IncomingConstant = dyn_cast<Constant>(Incoming)) {
168      IncomingConstants.push_back(std::make_pair(IncomingConstant, i));
169    } else if (!CommonValue) {
170      // The potential common value is initialized to the first non-constant.
171      CommonValue = Incoming;
172    } else if (Incoming != CommonValue) {
173      // There can be only one non-constant common value.
174      return false;
175    }
176  }
177
178  if (!CommonValue || IncomingConstants.empty())
179    return false;
180
181  // The common value must be valid in all incoming blocks.
182  BasicBlock *ToBB = P->getParent();
183  if (auto *CommonInst = dyn_cast<Instruction>(CommonValue))
184    if (!DT->dominates(CommonInst, ToBB))
185      return false;
186
187  // We have a phi with exactly 1 variable incoming value and 1 or more constant
188  // incoming values. See if all constant incoming values can be mapped back to
189  // the same incoming variable value.
190  for (auto &IncomingConstant : IncomingConstants) {
191    Constant *C = IncomingConstant.first;
192    BasicBlock *IncomingBB = P->getIncomingBlock(IncomingConstant.second);
193    if (C != LVI->getConstantOnEdge(CommonValue, IncomingBB, ToBB, P))
194      return false;
195  }
196
197  // All constant incoming values map to the same variable along the incoming
198  // edges of the phi. The phi is unnecessary. However, we must drop all
199  // poison-generating flags to ensure that no poison is propagated to the phi
200  // location by performing this substitution.
201  // Warning: If the underlying analysis changes, this may not be enough to
202  //          guarantee that poison is not propagated.
203  // TODO: We may be able to re-infer flags by re-analyzing the instruction.
204  if (auto *CommonInst = dyn_cast<Instruction>(CommonValue))
205    CommonInst->dropPoisonGeneratingFlags();
206  P->replaceAllUsesWith(CommonValue);
207  P->eraseFromParent();
208  ++NumPhiCommon;
209  return true;
210}
211
212static bool processPHI(PHINode *P, LazyValueInfo *LVI, DominatorTree *DT,
213                       const SimplifyQuery &SQ) {
214  bool Changed = false;
215
216  BasicBlock *BB = P->getParent();
217  for (unsigned i = 0, e = P->getNumIncomingValues(); i < e; ++i) {
218    Value *Incoming = P->getIncomingValue(i);
219    if (isa<Constant>(Incoming)) continue;
220
221    Value *V = LVI->getConstantOnEdge(Incoming, P->getIncomingBlock(i), BB, P);
222
223    // Look if the incoming value is a select with a scalar condition for which
224    // LVI can tells us the value. In that case replace the incoming value with
225    // the appropriate value of the select. This often allows us to remove the
226    // select later.
227    if (!V) {
228      SelectInst *SI = dyn_cast<SelectInst>(Incoming);
229      if (!SI) continue;
230
231      Value *Condition = SI->getCondition();
232      if (!Condition->getType()->isVectorTy()) {
233        if (Constant *C = LVI->getConstantOnEdge(
234                Condition, P->getIncomingBlock(i), BB, P)) {
235          if (C->isOneValue()) {
236            V = SI->getTrueValue();
237          } else if (C->isZeroValue()) {
238            V = SI->getFalseValue();
239          }
240          // Once LVI learns to handle vector types, we could also add support
241          // for vector type constants that are not all zeroes or all ones.
242        }
243      }
244
245      // Look if the select has a constant but LVI tells us that the incoming
246      // value can never be that constant. In that case replace the incoming
247      // value with the other value of the select. This often allows us to
248      // remove the select later.
249      if (!V) {
250        Constant *C = dyn_cast<Constant>(SI->getFalseValue());
251        if (!C) continue;
252
253        if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C,
254              P->getIncomingBlock(i), BB, P) !=
255            LazyValueInfo::False)
256          continue;
257        V = SI->getTrueValue();
258      }
259
260      LLVM_DEBUG(dbgs() << "CVP: Threading PHI over " << *SI << '\n');
261    }
262
263    P->setIncomingValue(i, V);
264    Changed = true;
265  }
266
267  if (Value *V = SimplifyInstruction(P, SQ)) {
268    P->replaceAllUsesWith(V);
269    P->eraseFromParent();
270    Changed = true;
271  }
272
273  if (!Changed)
274    Changed = simplifyCommonValuePhi(P, LVI, DT);
275
276  if (Changed)
277    ++NumPhis;
278
279  return Changed;
280}
281
282static bool processMemAccess(Instruction *I, LazyValueInfo *LVI) {
283  Value *Pointer = nullptr;
284  if (LoadInst *L = dyn_cast<LoadInst>(I))
285    Pointer = L->getPointerOperand();
286  else
287    Pointer = cast<StoreInst>(I)->getPointerOperand();
288
289  if (isa<Constant>(Pointer)) return false;
290
291  Constant *C = LVI->getConstant(Pointer, I->getParent(), I);
292  if (!C) return false;
293
294  ++NumMemAccess;
295  I->replaceUsesOfWith(Pointer, C);
296  return true;
297}
298
299/// See if LazyValueInfo's ability to exploit edge conditions or range
300/// information is sufficient to prove this comparison. Even for local
301/// conditions, this can sometimes prove conditions instcombine can't by
302/// exploiting range information.
303static bool processCmp(CmpInst *Cmp, LazyValueInfo *LVI) {
304  Value *Op0 = Cmp->getOperand(0);
305  auto *C = dyn_cast<Constant>(Cmp->getOperand(1));
306  if (!C)
307    return false;
308
309  // As a policy choice, we choose not to waste compile time on anything where
310  // the comparison is testing local values.  While LVI can sometimes reason
311  // about such cases, it's not its primary purpose.  We do make sure to do
312  // the block local query for uses from terminator instructions, but that's
313  // handled in the code for each terminator.
314  auto *I = dyn_cast<Instruction>(Op0);
315  if (I && I->getParent() == Cmp->getParent())
316    return false;
317
318  LazyValueInfo::Tristate Result =
319      LVI->getPredicateAt(Cmp->getPredicate(), Op0, C, Cmp);
320  if (Result == LazyValueInfo::Unknown)
321    return false;
322
323  ++NumCmps;
324  Constant *TorF = ConstantInt::get(Type::getInt1Ty(Cmp->getContext()), Result);
325  Cmp->replaceAllUsesWith(TorF);
326  Cmp->eraseFromParent();
327  return true;
328}
329
330/// Simplify a switch instruction by removing cases which can never fire. If the
331/// uselessness of a case could be determined locally then constant propagation
332/// would already have figured it out. Instead, walk the predecessors and
333/// statically evaluate cases based on information available on that edge. Cases
334/// that cannot fire no matter what the incoming edge can safely be removed. If
335/// a case fires on every incoming edge then the entire switch can be removed
336/// and replaced with a branch to the case destination.
337static bool processSwitch(SwitchInst *I, LazyValueInfo *LVI,
338                          DominatorTree *DT) {
339  DomTreeUpdater DTU(*DT, DomTreeUpdater::UpdateStrategy::Lazy);
340  Value *Cond = I->getCondition();
341  BasicBlock *BB = I->getParent();
342
343  // If the condition was defined in same block as the switch then LazyValueInfo
344  // currently won't say anything useful about it, though in theory it could.
345  if (isa<Instruction>(Cond) && cast<Instruction>(Cond)->getParent() == BB)
346    return false;
347
348  // If the switch is unreachable then trying to improve it is a waste of time.
349  pred_iterator PB = pred_begin(BB), PE = pred_end(BB);
350  if (PB == PE) return false;
351
352  // Analyse each switch case in turn.
353  bool Changed = false;
354  DenseMap<BasicBlock*, int> SuccessorsCount;
355  for (auto *Succ : successors(BB))
356    SuccessorsCount[Succ]++;
357
358  { // Scope for SwitchInstProfUpdateWrapper. It must not live during
359    // ConstantFoldTerminator() as the underlying SwitchInst can be changed.
360    SwitchInstProfUpdateWrapper SI(*I);
361
362    for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) {
363      ConstantInt *Case = CI->getCaseValue();
364
365      // Check to see if the switch condition is equal to/not equal to the case
366      // value on every incoming edge, equal/not equal being the same each time.
367      LazyValueInfo::Tristate State = LazyValueInfo::Unknown;
368      for (pred_iterator PI = PB; PI != PE; ++PI) {
369        // Is the switch condition equal to the case value?
370        LazyValueInfo::Tristate Value = LVI->getPredicateOnEdge(CmpInst::ICMP_EQ,
371                                                                Cond, Case, *PI,
372                                                                BB, SI);
373        // Give up on this case if nothing is known.
374        if (Value == LazyValueInfo::Unknown) {
375          State = LazyValueInfo::Unknown;
376          break;
377        }
378
379        // If this was the first edge to be visited, record that all other edges
380        // need to give the same result.
381        if (PI == PB) {
382          State = Value;
383          continue;
384        }
385
386        // If this case is known to fire for some edges and known not to fire for
387        // others then there is nothing we can do - give up.
388        if (Value != State) {
389          State = LazyValueInfo::Unknown;
390          break;
391        }
392      }
393
394      if (State == LazyValueInfo::False) {
395        // This case never fires - remove it.
396        BasicBlock *Succ = CI->getCaseSuccessor();
397        Succ->removePredecessor(BB);
398        CI = SI.removeCase(CI);
399        CE = SI->case_end();
400
401        // The condition can be modified by removePredecessor's PHI simplification
402        // logic.
403        Cond = SI->getCondition();
404
405        ++NumDeadCases;
406        Changed = true;
407        if (--SuccessorsCount[Succ] == 0)
408          DTU.applyUpdatesPermissive({{DominatorTree::Delete, BB, Succ}});
409        continue;
410      }
411      if (State == LazyValueInfo::True) {
412        // This case always fires.  Arrange for the switch to be turned into an
413        // unconditional branch by replacing the switch condition with the case
414        // value.
415        SI->setCondition(Case);
416        NumDeadCases += SI->getNumCases();
417        Changed = true;
418        break;
419      }
420
421      // Increment the case iterator since we didn't delete it.
422      ++CI;
423    }
424  }
425
426  if (Changed)
427    // If the switch has been simplified to the point where it can be replaced
428    // by a branch then do so now.
429    ConstantFoldTerminator(BB, /*DeleteDeadConditions = */ false,
430                           /*TLI = */ nullptr, &DTU);
431  return Changed;
432}
433
434// See if we can prove that the given binary op intrinsic will not overflow.
435static bool willNotOverflow(BinaryOpIntrinsic *BO, LazyValueInfo *LVI) {
436  ConstantRange LRange = LVI->getConstantRange(
437      BO->getLHS(), BO->getParent(), BO);
438  ConstantRange RRange = LVI->getConstantRange(
439      BO->getRHS(), BO->getParent(), BO);
440  ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
441      BO->getBinaryOp(), RRange, BO->getNoWrapKind());
442  return NWRegion.contains(LRange);
443}
444
445static void setDeducedOverflowingFlags(Value *V, Instruction::BinaryOps Opcode,
446                                       bool NewNSW, bool NewNUW) {
447  Statistic *OpcNW, *OpcNSW, *OpcNUW;
448  switch (Opcode) {
449  case Instruction::Add:
450    OpcNW = &NumAddNW;
451    OpcNSW = &NumAddNSW;
452    OpcNUW = &NumAddNUW;
453    break;
454  case Instruction::Sub:
455    OpcNW = &NumSubNW;
456    OpcNSW = &NumSubNSW;
457    OpcNUW = &NumSubNUW;
458    break;
459  case Instruction::Mul:
460    OpcNW = &NumMulNW;
461    OpcNSW = &NumMulNSW;
462    OpcNUW = &NumMulNUW;
463    break;
464  case Instruction::Shl:
465    OpcNW = &NumShlNW;
466    OpcNSW = &NumShlNSW;
467    OpcNUW = &NumShlNUW;
468    break;
469  default:
470    llvm_unreachable("Will not be called with other binops");
471  }
472
473  auto *Inst = dyn_cast<Instruction>(V);
474  if (NewNSW) {
475    ++NumNW;
476    ++*OpcNW;
477    ++NumNSW;
478    ++*OpcNSW;
479    if (Inst)
480      Inst->setHasNoSignedWrap();
481  }
482  if (NewNUW) {
483    ++NumNW;
484    ++*OpcNW;
485    ++NumNUW;
486    ++*OpcNUW;
487    if (Inst)
488      Inst->setHasNoUnsignedWrap();
489  }
490}
491
492static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI);
493
494// Rewrite this with.overflow intrinsic as non-overflowing.
495static void processOverflowIntrinsic(WithOverflowInst *WO, LazyValueInfo *LVI) {
496  IRBuilder<> B(WO);
497  Instruction::BinaryOps Opcode = WO->getBinaryOp();
498  bool NSW = WO->isSigned();
499  bool NUW = !WO->isSigned();
500
501  Value *NewOp =
502      B.CreateBinOp(Opcode, WO->getLHS(), WO->getRHS(), WO->getName());
503  setDeducedOverflowingFlags(NewOp, Opcode, NSW, NUW);
504
505  StructType *ST = cast<StructType>(WO->getType());
506  Constant *Struct = ConstantStruct::get(ST,
507      { UndefValue::get(ST->getElementType(0)),
508        ConstantInt::getFalse(ST->getElementType(1)) });
509  Value *NewI = B.CreateInsertValue(Struct, NewOp, 0);
510  WO->replaceAllUsesWith(NewI);
511  WO->eraseFromParent();
512  ++NumOverflows;
513
514  // See if we can infer the other no-wrap too.
515  if (auto *BO = dyn_cast<BinaryOperator>(NewOp))
516    processBinOp(BO, LVI);
517}
518
519static void processSaturatingInst(SaturatingInst *SI, LazyValueInfo *LVI) {
520  Instruction::BinaryOps Opcode = SI->getBinaryOp();
521  bool NSW = SI->isSigned();
522  bool NUW = !SI->isSigned();
523  BinaryOperator *BinOp = BinaryOperator::Create(
524      Opcode, SI->getLHS(), SI->getRHS(), SI->getName(), SI);
525  BinOp->setDebugLoc(SI->getDebugLoc());
526  setDeducedOverflowingFlags(BinOp, Opcode, NSW, NUW);
527
528  SI->replaceAllUsesWith(BinOp);
529  SI->eraseFromParent();
530  ++NumSaturating;
531
532  // See if we can infer the other no-wrap too.
533  if (auto *BO = dyn_cast<BinaryOperator>(BinOp))
534    processBinOp(BO, LVI);
535}
536
537/// Infer nonnull attributes for the arguments at the specified callsite.
538static bool processCallSite(CallSite CS, LazyValueInfo *LVI) {
539  SmallVector<unsigned, 4> ArgNos;
540  unsigned ArgNo = 0;
541
542  if (auto *WO = dyn_cast<WithOverflowInst>(CS.getInstruction())) {
543    if (WO->getLHS()->getType()->isIntegerTy() && willNotOverflow(WO, LVI)) {
544      processOverflowIntrinsic(WO, LVI);
545      return true;
546    }
547  }
548
549  if (auto *SI = dyn_cast<SaturatingInst>(CS.getInstruction())) {
550    if (SI->getType()->isIntegerTy() && willNotOverflow(SI, LVI)) {
551      processSaturatingInst(SI, LVI);
552      return true;
553    }
554  }
555
556  // Deopt bundle operands are intended to capture state with minimal
557  // perturbance of the code otherwise.  If we can find a constant value for
558  // any such operand and remove a use of the original value, that's
559  // desireable since it may allow further optimization of that value (e.g. via
560  // single use rules in instcombine).  Since deopt uses tend to,
561  // idiomatically, appear along rare conditional paths, it's reasonable likely
562  // we may have a conditional fact with which LVI can fold.
563  if (auto DeoptBundle = CS.getOperandBundle(LLVMContext::OB_deopt)) {
564    bool Progress = false;
565    for (const Use &ConstU : DeoptBundle->Inputs) {
566      Use &U = const_cast<Use&>(ConstU);
567      Value *V = U.get();
568      if (V->getType()->isVectorTy()) continue;
569      if (isa<Constant>(V)) continue;
570
571      Constant *C = LVI->getConstant(V, CS.getParent(), CS.getInstruction());
572      if (!C) continue;
573      U.set(C);
574      Progress = true;
575    }
576    if (Progress)
577      return true;
578  }
579
580  for (Value *V : CS.args()) {
581    PointerType *Type = dyn_cast<PointerType>(V->getType());
582    // Try to mark pointer typed parameters as non-null.  We skip the
583    // relatively expensive analysis for constants which are obviously either
584    // null or non-null to start with.
585    if (Type && !CS.paramHasAttr(ArgNo, Attribute::NonNull) &&
586        !isa<Constant>(V) &&
587        LVI->getPredicateAt(ICmpInst::ICMP_EQ, V,
588                            ConstantPointerNull::get(Type),
589                            CS.getInstruction()) == LazyValueInfo::False)
590      ArgNos.push_back(ArgNo);
591    ArgNo++;
592  }
593
594  assert(ArgNo == CS.arg_size() && "sanity check");
595
596  if (ArgNos.empty())
597    return false;
598
599  AttributeList AS = CS.getAttributes();
600  LLVMContext &Ctx = CS.getInstruction()->getContext();
601  AS = AS.addParamAttribute(Ctx, ArgNos,
602                            Attribute::get(Ctx, Attribute::NonNull));
603  CS.setAttributes(AS);
604
605  return true;
606}
607
608static bool hasPositiveOperands(BinaryOperator *SDI, LazyValueInfo *LVI) {
609  Constant *Zero = ConstantInt::get(SDI->getType(), 0);
610  for (Value *O : SDI->operands()) {
611    auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SGE, O, Zero, SDI);
612    if (Result != LazyValueInfo::True)
613      return false;
614  }
615  return true;
616}
617
618/// Try to shrink a udiv/urem's width down to the smallest power of two that's
619/// sufficient to contain its operands.
620static bool processUDivOrURem(BinaryOperator *Instr, LazyValueInfo *LVI) {
621  assert(Instr->getOpcode() == Instruction::UDiv ||
622         Instr->getOpcode() == Instruction::URem);
623  if (Instr->getType()->isVectorTy())
624    return false;
625
626  // Find the smallest power of two bitwidth that's sufficient to hold Instr's
627  // operands.
628  auto OrigWidth = Instr->getType()->getIntegerBitWidth();
629  ConstantRange OperandRange(OrigWidth, /*isFullSet=*/false);
630  for (Value *Operand : Instr->operands()) {
631    OperandRange = OperandRange.unionWith(
632        LVI->getConstantRange(Operand, Instr->getParent()));
633  }
634  // Don't shrink below 8 bits wide.
635  unsigned NewWidth = std::max<unsigned>(
636      PowerOf2Ceil(OperandRange.getUnsignedMax().getActiveBits()), 8);
637  // NewWidth might be greater than OrigWidth if OrigWidth is not a power of
638  // two.
639  if (NewWidth >= OrigWidth)
640    return false;
641
642  ++NumUDivs;
643  IRBuilder<> B{Instr};
644  auto *TruncTy = Type::getIntNTy(Instr->getContext(), NewWidth);
645  auto *LHS = B.CreateTruncOrBitCast(Instr->getOperand(0), TruncTy,
646                                     Instr->getName() + ".lhs.trunc");
647  auto *RHS = B.CreateTruncOrBitCast(Instr->getOperand(1), TruncTy,
648                                     Instr->getName() + ".rhs.trunc");
649  auto *BO = B.CreateBinOp(Instr->getOpcode(), LHS, RHS, Instr->getName());
650  auto *Zext = B.CreateZExt(BO, Instr->getType(), Instr->getName() + ".zext");
651  if (auto *BinOp = dyn_cast<BinaryOperator>(BO))
652    if (BinOp->getOpcode() == Instruction::UDiv)
653      BinOp->setIsExact(Instr->isExact());
654
655  Instr->replaceAllUsesWith(Zext);
656  Instr->eraseFromParent();
657  return true;
658}
659
660static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) {
661  if (SDI->getType()->isVectorTy() || !hasPositiveOperands(SDI, LVI))
662    return false;
663
664  ++NumSRems;
665  auto *BO = BinaryOperator::CreateURem(SDI->getOperand(0), SDI->getOperand(1),
666                                        SDI->getName(), SDI);
667  BO->setDebugLoc(SDI->getDebugLoc());
668  SDI->replaceAllUsesWith(BO);
669  SDI->eraseFromParent();
670
671  // Try to process our new urem.
672  processUDivOrURem(BO, LVI);
673
674  return true;
675}
676
677/// See if LazyValueInfo's ability to exploit edge conditions or range
678/// information is sufficient to prove the both operands of this SDiv are
679/// positive.  If this is the case, replace the SDiv with a UDiv. Even for local
680/// conditions, this can sometimes prove conditions instcombine can't by
681/// exploiting range information.
682static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) {
683  if (SDI->getType()->isVectorTy() || !hasPositiveOperands(SDI, LVI))
684    return false;
685
686  ++NumSDivs;
687  auto *BO = BinaryOperator::CreateUDiv(SDI->getOperand(0), SDI->getOperand(1),
688                                        SDI->getName(), SDI);
689  BO->setDebugLoc(SDI->getDebugLoc());
690  BO->setIsExact(SDI->isExact());
691  SDI->replaceAllUsesWith(BO);
692  SDI->eraseFromParent();
693
694  // Try to simplify our new udiv.
695  processUDivOrURem(BO, LVI);
696
697  return true;
698}
699
700static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
701  if (SDI->getType()->isVectorTy())
702    return false;
703
704  Constant *Zero = ConstantInt::get(SDI->getType(), 0);
705  if (LVI->getPredicateAt(ICmpInst::ICMP_SGE, SDI->getOperand(0), Zero, SDI) !=
706      LazyValueInfo::True)
707    return false;
708
709  ++NumAShrs;
710  auto *BO = BinaryOperator::CreateLShr(SDI->getOperand(0), SDI->getOperand(1),
711                                        SDI->getName(), SDI);
712  BO->setDebugLoc(SDI->getDebugLoc());
713  BO->setIsExact(SDI->isExact());
714  SDI->replaceAllUsesWith(BO);
715  SDI->eraseFromParent();
716
717  return true;
718}
719
720static bool processSExt(SExtInst *SDI, LazyValueInfo *LVI) {
721  if (SDI->getType()->isVectorTy())
722    return false;
723
724  Value *Base = SDI->getOperand(0);
725
726  Constant *Zero = ConstantInt::get(Base->getType(), 0);
727  if (LVI->getPredicateAt(ICmpInst::ICMP_SGE, Base, Zero, SDI) !=
728      LazyValueInfo::True)
729    return false;
730
731  ++NumSExt;
732  auto *ZExt =
733      CastInst::CreateZExtOrBitCast(Base, SDI->getType(), SDI->getName(), SDI);
734  ZExt->setDebugLoc(SDI->getDebugLoc());
735  SDI->replaceAllUsesWith(ZExt);
736  SDI->eraseFromParent();
737
738  return true;
739}
740
741static bool processBinOp(BinaryOperator *BinOp, LazyValueInfo *LVI) {
742  using OBO = OverflowingBinaryOperator;
743
744  if (DontAddNoWrapFlags)
745    return false;
746
747  if (BinOp->getType()->isVectorTy())
748    return false;
749
750  bool NSW = BinOp->hasNoSignedWrap();
751  bool NUW = BinOp->hasNoUnsignedWrap();
752  if (NSW && NUW)
753    return false;
754
755  BasicBlock *BB = BinOp->getParent();
756
757  Instruction::BinaryOps Opcode = BinOp->getOpcode();
758  Value *LHS = BinOp->getOperand(0);
759  Value *RHS = BinOp->getOperand(1);
760
761  ConstantRange LRange = LVI->getConstantRange(LHS, BB, BinOp);
762  ConstantRange RRange = LVI->getConstantRange(RHS, BB, BinOp);
763
764  bool Changed = false;
765  bool NewNUW = false, NewNSW = false;
766  if (!NUW) {
767    ConstantRange NUWRange = ConstantRange::makeGuaranteedNoWrapRegion(
768        Opcode, RRange, OBO::NoUnsignedWrap);
769    NewNUW = NUWRange.contains(LRange);
770    Changed |= NewNUW;
771  }
772  if (!NSW) {
773    ConstantRange NSWRange = ConstantRange::makeGuaranteedNoWrapRegion(
774        Opcode, RRange, OBO::NoSignedWrap);
775    NewNSW = NSWRange.contains(LRange);
776    Changed |= NewNSW;
777  }
778
779  setDeducedOverflowingFlags(BinOp, Opcode, NewNSW, NewNUW);
780
781  return Changed;
782}
783
784static bool processAnd(BinaryOperator *BinOp, LazyValueInfo *LVI) {
785  if (BinOp->getType()->isVectorTy())
786    return false;
787
788  // Pattern match (and lhs, C) where C includes a superset of bits which might
789  // be set in lhs.  This is a common truncation idiom created by instcombine.
790  BasicBlock *BB = BinOp->getParent();
791  Value *LHS = BinOp->getOperand(0);
792  ConstantInt *RHS = dyn_cast<ConstantInt>(BinOp->getOperand(1));
793  if (!RHS || !RHS->getValue().isMask())
794    return false;
795
796  ConstantRange LRange = LVI->getConstantRange(LHS, BB, BinOp);
797  if (!LRange.getUnsignedMax().ule(RHS->getValue()))
798    return false;
799
800  BinOp->replaceAllUsesWith(LHS);
801  BinOp->eraseFromParent();
802  NumAnd++;
803  return true;
804}
805
806
807static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) {
808  if (Constant *C = LVI->getConstant(V, At->getParent(), At))
809    return C;
810
811  // TODO: The following really should be sunk inside LVI's core algorithm, or
812  // at least the outer shims around such.
813  auto *C = dyn_cast<CmpInst>(V);
814  if (!C) return nullptr;
815
816  Value *Op0 = C->getOperand(0);
817  Constant *Op1 = dyn_cast<Constant>(C->getOperand(1));
818  if (!Op1) return nullptr;
819
820  LazyValueInfo::Tristate Result =
821    LVI->getPredicateAt(C->getPredicate(), Op0, Op1, At);
822  if (Result == LazyValueInfo::Unknown)
823    return nullptr;
824
825  return (Result == LazyValueInfo::True) ?
826    ConstantInt::getTrue(C->getContext()) :
827    ConstantInt::getFalse(C->getContext());
828}
829
830static bool runImpl(Function &F, LazyValueInfo *LVI, DominatorTree *DT,
831                    const SimplifyQuery &SQ) {
832  bool FnChanged = false;
833  // Visiting in a pre-order depth-first traversal causes us to simplify early
834  // blocks before querying later blocks (which require us to analyze early
835  // blocks).  Eagerly simplifying shallow blocks means there is strictly less
836  // work to do for deep blocks.  This also means we don't visit unreachable
837  // blocks.
838  for (BasicBlock *BB : depth_first(&F.getEntryBlock())) {
839    bool BBChanged = false;
840    for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
841      Instruction *II = &*BI++;
842      switch (II->getOpcode()) {
843      case Instruction::Select:
844        BBChanged |= processSelect(cast<SelectInst>(II), LVI);
845        break;
846      case Instruction::PHI:
847        BBChanged |= processPHI(cast<PHINode>(II), LVI, DT, SQ);
848        break;
849      case Instruction::ICmp:
850      case Instruction::FCmp:
851        BBChanged |= processCmp(cast<CmpInst>(II), LVI);
852        break;
853      case Instruction::Load:
854      case Instruction::Store:
855        BBChanged |= processMemAccess(II, LVI);
856        break;
857      case Instruction::Call:
858      case Instruction::Invoke:
859        BBChanged |= processCallSite(CallSite(II), LVI);
860        break;
861      case Instruction::SRem:
862        BBChanged |= processSRem(cast<BinaryOperator>(II), LVI);
863        break;
864      case Instruction::SDiv:
865        BBChanged |= processSDiv(cast<BinaryOperator>(II), LVI);
866        break;
867      case Instruction::UDiv:
868      case Instruction::URem:
869        BBChanged |= processUDivOrURem(cast<BinaryOperator>(II), LVI);
870        break;
871      case Instruction::AShr:
872        BBChanged |= processAShr(cast<BinaryOperator>(II), LVI);
873        break;
874      case Instruction::SExt:
875        BBChanged |= processSExt(cast<SExtInst>(II), LVI);
876        break;
877      case Instruction::Add:
878      case Instruction::Sub:
879      case Instruction::Mul:
880      case Instruction::Shl:
881        BBChanged |= processBinOp(cast<BinaryOperator>(II), LVI);
882        break;
883      case Instruction::And:
884        BBChanged |= processAnd(cast<BinaryOperator>(II), LVI);
885        break;
886      }
887    }
888
889    Instruction *Term = BB->getTerminator();
890    switch (Term->getOpcode()) {
891    case Instruction::Switch:
892      BBChanged |= processSwitch(cast<SwitchInst>(Term), LVI, DT);
893      break;
894    case Instruction::Ret: {
895      auto *RI = cast<ReturnInst>(Term);
896      // Try to determine the return value if we can.  This is mainly here to
897      // simplify the writing of unit tests, but also helps to enable IPO by
898      // constant folding the return values of callees.
899      auto *RetVal = RI->getReturnValue();
900      if (!RetVal) break; // handle "ret void"
901      if (isa<Constant>(RetVal)) break; // nothing to do
902      if (auto *C = getConstantAt(RetVal, RI, LVI)) {
903        ++NumReturns;
904        RI->replaceUsesOfWith(RetVal, C);
905        BBChanged = true;
906      }
907    }
908    }
909
910    FnChanged |= BBChanged;
911  }
912
913  return FnChanged;
914}
915
916bool CorrelatedValuePropagation::runOnFunction(Function &F) {
917  if (skipFunction(F))
918    return false;
919
920  LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
921  DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
922
923  return runImpl(F, LVI, DT, getBestSimplifyQuery(*this, F));
924}
925
926PreservedAnalyses
927CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) {
928  LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F);
929  DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F);
930
931  bool Changed = runImpl(F, LVI, DT, getBestSimplifyQuery(AM, F));
932
933  if (!Changed)
934    return PreservedAnalyses::all();
935  PreservedAnalyses PA;
936  PA.preserve<GlobalsAA>();
937  PA.preserve<DominatorTreeAnalysis>();
938  PA.preserve<LazyValueAnalysis>();
939  return PA;
940}
941