1212793Sdim//===- CorrelatedValuePropagation.cpp - Propagate CFG-derived info --------===//
2212793Sdim//
3212793Sdim//                     The LLVM Compiler Infrastructure
4212793Sdim//
5212793Sdim// This file is distributed under the University of Illinois Open Source
6212793Sdim// License. See LICENSE.TXT for details.
7212793Sdim//
8212793Sdim//===----------------------------------------------------------------------===//
9212793Sdim//
10212793Sdim// This file implements the Correlated Value Propagation pass.
11212793Sdim//
12212793Sdim//===----------------------------------------------------------------------===//
13212793Sdim
14212793Sdim#define DEBUG_TYPE "correlated-value-propagation"
15212793Sdim#include "llvm/Transforms/Scalar.h"
16249423Sdim#include "llvm/ADT/Statistic.h"
17218893Sdim#include "llvm/Analysis/InstructionSimplify.h"
18212793Sdim#include "llvm/Analysis/LazyValueInfo.h"
19249423Sdim#include "llvm/IR/Constants.h"
20249423Sdim#include "llvm/IR/Function.h"
21249423Sdim#include "llvm/IR/Instructions.h"
22249423Sdim#include "llvm/Pass.h"
23212793Sdim#include "llvm/Support/CFG.h"
24249423Sdim#include "llvm/Support/Debug.h"
25249423Sdim#include "llvm/Support/raw_ostream.h"
26212793Sdim#include "llvm/Transforms/Utils/Local.h"
27212793Sdimusing namespace llvm;
28212793Sdim
// Counters recording how often each transformation in this pass fired.
// Bumped once per PHI whose incoming values were rewritten or which was erased.
STATISTIC(NumPhis,      "Number of phis propagated");
// Bumped once per select replaced by one of its operands and erased.
STATISTIC(NumSelects,   "Number of selects propagated");
// Bumped once per load/store whose pointer operand was replaced by a constant.
STATISTIC(NumMemAccess, "Number of memory access targets propagated");
// Bumped once per comparison folded to an i1 true/false constant and erased.
STATISTIC(NumCmps,      "Number of comparisons propagated");
// Bumped per switch case removed as never firing; when one case is proven to
// always fire, the switch's entire case count is added at once.
STATISTIC(NumDeadCases, "Number of switch cases removed");
34212793Sdim
35212793Sdimnamespace {
36212793Sdim  class CorrelatedValuePropagation : public FunctionPass {
37212793Sdim    LazyValueInfo *LVI;
38218893Sdim
39212793Sdim    bool processSelect(SelectInst *SI);
40212793Sdim    bool processPHI(PHINode *P);
41212793Sdim    bool processMemAccess(Instruction *I);
42212793Sdim    bool processCmp(CmpInst *C);
43234353Sdim    bool processSwitch(SwitchInst *SI);
44218893Sdim
45212793Sdim  public:
46212793Sdim    static char ID;
47218893Sdim    CorrelatedValuePropagation(): FunctionPass(ID) {
48218893Sdim     initializeCorrelatedValuePropagationPass(*PassRegistry::getPassRegistry());
49218893Sdim    }
50218893Sdim
51212793Sdim    bool runOnFunction(Function &F);
52218893Sdim
53212793Sdim    virtual void getAnalysisUsage(AnalysisUsage &AU) const {
54212793Sdim      AU.addRequired<LazyValueInfo>();
55212793Sdim    }
56212793Sdim  };
57212793Sdim}
58212793Sdim
// Pass identifier; its address is handed to the FunctionPass base constructor.
char CorrelatedValuePropagation::ID = 0;
// Register the pass (argument "correlated-propagation", description
// "Value Propagation") together with its LazyValueInfo dependency.
// NOTE(review): the two trailing 'false' arguments are presumably the
// "CFG only" / "is analysis" flags - confirm against llvm/PassSupport.h.
INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation",
                "Value Propagation", false, false)
INITIALIZE_PASS_DEPENDENCY(LazyValueInfo)
INITIALIZE_PASS_END(CorrelatedValuePropagation, "correlated-propagation",
                "Value Propagation", false, false)
65212793Sdim
66212793Sdim// Public interface to the Value Propagation pass
67212793SdimPass *llvm::createCorrelatedValuePropagationPass() {
68212793Sdim  return new CorrelatedValuePropagation();
69212793Sdim}
70212793Sdim
71212793Sdimbool CorrelatedValuePropagation::processSelect(SelectInst *S) {
72212793Sdim  if (S->getType()->isVectorTy()) return false;
73212793Sdim  if (isa<Constant>(S->getOperand(0))) return false;
74218893Sdim
75212793Sdim  Constant *C = LVI->getConstant(S->getOperand(0), S->getParent());
76212793Sdim  if (!C) return false;
77218893Sdim
78212793Sdim  ConstantInt *CI = dyn_cast<ConstantInt>(C);
79212793Sdim  if (!CI) return false;
80218893Sdim
81218893Sdim  Value *ReplaceWith = S->getOperand(1);
82218893Sdim  Value *Other = S->getOperand(2);
83218893Sdim  if (!CI->isOne()) std::swap(ReplaceWith, Other);
84218893Sdim  if (ReplaceWith == S) ReplaceWith = UndefValue::get(S->getType());
85218893Sdim
86218893Sdim  S->replaceAllUsesWith(ReplaceWith);
87212793Sdim  S->eraseFromParent();
88212793Sdim
89212793Sdim  ++NumSelects;
90218893Sdim
91212793Sdim  return true;
92212793Sdim}
93212793Sdim
94212793Sdimbool CorrelatedValuePropagation::processPHI(PHINode *P) {
95212793Sdim  bool Changed = false;
96218893Sdim
97212793Sdim  BasicBlock *BB = P->getParent();
98212793Sdim  for (unsigned i = 0, e = P->getNumIncomingValues(); i < e; ++i) {
99212793Sdim    Value *Incoming = P->getIncomingValue(i);
100212793Sdim    if (isa<Constant>(Incoming)) continue;
101218893Sdim
102249423Sdim    Value *V = LVI->getConstantOnEdge(Incoming, P->getIncomingBlock(i), BB);
103218893Sdim
104249423Sdim    // Look if the incoming value is a select with a constant but LVI tells us
105249423Sdim    // that the incoming value can never be that constant. In that case replace
106249423Sdim    // the incoming value with the other value of the select. This often allows
107249423Sdim    // us to remove the select later.
108249423Sdim    if (!V) {
109249423Sdim      SelectInst *SI = dyn_cast<SelectInst>(Incoming);
110249423Sdim      if (!SI) continue;
111249423Sdim
112249423Sdim      Constant *C = dyn_cast<Constant>(SI->getFalseValue());
113249423Sdim      if (!C) continue;
114249423Sdim
115249423Sdim      if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C,
116249423Sdim                                  P->getIncomingBlock(i), BB) !=
117249423Sdim          LazyValueInfo::False)
118249423Sdim        continue;
119249423Sdim
120249423Sdim      DEBUG(dbgs() << "CVP: Threading PHI over " << *SI << '\n');
121249423Sdim      V = SI->getTrueValue();
122249423Sdim    }
123249423Sdim
124249423Sdim    P->setIncomingValue(i, V);
125212793Sdim    Changed = true;
126212793Sdim  }
127218893Sdim
128218893Sdim  if (Value *V = SimplifyInstruction(P)) {
129218893Sdim    P->replaceAllUsesWith(V);
130212793Sdim    P->eraseFromParent();
131212793Sdim    Changed = true;
132212793Sdim  }
133218893Sdim
134234353Sdim  if (Changed)
135234353Sdim    ++NumPhis;
136218893Sdim
137212793Sdim  return Changed;
138212793Sdim}
139212793Sdim
140212793Sdimbool CorrelatedValuePropagation::processMemAccess(Instruction *I) {
141212793Sdim  Value *Pointer = 0;
142212793Sdim  if (LoadInst *L = dyn_cast<LoadInst>(I))
143212793Sdim    Pointer = L->getPointerOperand();
144212793Sdim  else
145212793Sdim    Pointer = cast<StoreInst>(I)->getPointerOperand();
146218893Sdim
147212793Sdim  if (isa<Constant>(Pointer)) return false;
148218893Sdim
149212793Sdim  Constant *C = LVI->getConstant(Pointer, I->getParent());
150212793Sdim  if (!C) return false;
151218893Sdim
152212793Sdim  ++NumMemAccess;
153212793Sdim  I->replaceUsesOfWith(Pointer, C);
154212793Sdim  return true;
155212793Sdim}
156212793Sdim
157212793Sdim/// processCmp - If the value of this comparison could be determined locally,
158212793Sdim/// constant propagation would already have figured it out.  Instead, walk
159212793Sdim/// the predecessors and statically evaluate the comparison based on information
160212793Sdim/// available on that edge.  If a given static evaluation is true on ALL
161212793Sdim/// incoming edges, then it's true universally and we can simplify the compare.
162212793Sdimbool CorrelatedValuePropagation::processCmp(CmpInst *C) {
163212793Sdim  Value *Op0 = C->getOperand(0);
164212793Sdim  if (isa<Instruction>(Op0) &&
165212793Sdim      cast<Instruction>(Op0)->getParent() == C->getParent())
166212793Sdim    return false;
167218893Sdim
168212793Sdim  Constant *Op1 = dyn_cast<Constant>(C->getOperand(1));
169212793Sdim  if (!Op1) return false;
170218893Sdim
171212793Sdim  pred_iterator PI = pred_begin(C->getParent()), PE = pred_end(C->getParent());
172212793Sdim  if (PI == PE) return false;
173218893Sdim
174218893Sdim  LazyValueInfo::Tristate Result = LVI->getPredicateOnEdge(C->getPredicate(),
175212793Sdim                                    C->getOperand(0), Op1, *PI, C->getParent());
176212793Sdim  if (Result == LazyValueInfo::Unknown) return false;
177212793Sdim
178212793Sdim  ++PI;
179212793Sdim  while (PI != PE) {
180218893Sdim    LazyValueInfo::Tristate Res = LVI->getPredicateOnEdge(C->getPredicate(),
181212793Sdim                                    C->getOperand(0), Op1, *PI, C->getParent());
182212793Sdim    if (Res != Result) return false;
183212793Sdim    ++PI;
184212793Sdim  }
185218893Sdim
186212793Sdim  ++NumCmps;
187218893Sdim
188212793Sdim  if (Result == LazyValueInfo::True)
189212793Sdim    C->replaceAllUsesWith(ConstantInt::getTrue(C->getContext()));
190212793Sdim  else
191212793Sdim    C->replaceAllUsesWith(ConstantInt::getFalse(C->getContext()));
192218893Sdim
193212793Sdim  C->eraseFromParent();
194212793Sdim
195212793Sdim  return true;
196212793Sdim}
197212793Sdim
198234353Sdim/// processSwitch - Simplify a switch instruction by removing cases which can
199234353Sdim/// never fire.  If the uselessness of a case could be determined locally then
200234353Sdim/// constant propagation would already have figured it out.  Instead, walk the
201234353Sdim/// predecessors and statically evaluate cases based on information available
202234353Sdim/// on that edge.  Cases that cannot fire no matter what the incoming edge can
203234353Sdim/// safely be removed.  If a case fires on every incoming edge then the entire
204234353Sdim/// switch can be removed and replaced with a branch to the case destination.
205234353Sdimbool CorrelatedValuePropagation::processSwitch(SwitchInst *SI) {
206234353Sdim  Value *Cond = SI->getCondition();
207234353Sdim  BasicBlock *BB = SI->getParent();
208234353Sdim
209234353Sdim  // If the condition was defined in same block as the switch then LazyValueInfo
210234353Sdim  // currently won't say anything useful about it, though in theory it could.
211234353Sdim  if (isa<Instruction>(Cond) && cast<Instruction>(Cond)->getParent() == BB)
212234353Sdim    return false;
213234353Sdim
214234353Sdim  // If the switch is unreachable then trying to improve it is a waste of time.
215234353Sdim  pred_iterator PB = pred_begin(BB), PE = pred_end(BB);
216234353Sdim  if (PB == PE) return false;
217234353Sdim
218234353Sdim  // Analyse each switch case in turn.  This is done in reverse order so that
219234353Sdim  // removing a case doesn't cause trouble for the iteration.
220234353Sdim  bool Changed = false;
221234353Sdim  for (SwitchInst::CaseIt CI = SI->case_end(), CE = SI->case_begin(); CI-- != CE;
222234353Sdim       ) {
223234353Sdim    ConstantInt *Case = CI.getCaseValue();
224234353Sdim
225234353Sdim    // Check to see if the switch condition is equal to/not equal to the case
226234353Sdim    // value on every incoming edge, equal/not equal being the same each time.
227234353Sdim    LazyValueInfo::Tristate State = LazyValueInfo::Unknown;
228234353Sdim    for (pred_iterator PI = PB; PI != PE; ++PI) {
229234353Sdim      // Is the switch condition equal to the case value?
230234353Sdim      LazyValueInfo::Tristate Value = LVI->getPredicateOnEdge(CmpInst::ICMP_EQ,
231234353Sdim                                                              Cond, Case, *PI, BB);
232234353Sdim      // Give up on this case if nothing is known.
233234353Sdim      if (Value == LazyValueInfo::Unknown) {
234234353Sdim        State = LazyValueInfo::Unknown;
235234353Sdim        break;
236234353Sdim      }
237234353Sdim
238234353Sdim      // If this was the first edge to be visited, record that all other edges
239234353Sdim      // need to give the same result.
240234353Sdim      if (PI == PB) {
241234353Sdim        State = Value;
242234353Sdim        continue;
243234353Sdim      }
244234353Sdim
245234353Sdim      // If this case is known to fire for some edges and known not to fire for
246234353Sdim      // others then there is nothing we can do - give up.
247234353Sdim      if (Value != State) {
248234353Sdim        State = LazyValueInfo::Unknown;
249234353Sdim        break;
250234353Sdim      }
251234353Sdim    }
252234353Sdim
253234353Sdim    if (State == LazyValueInfo::False) {
254234353Sdim      // This case never fires - remove it.
255234353Sdim      CI.getCaseSuccessor()->removePredecessor(BB);
256234353Sdim      SI->removeCase(CI); // Does not invalidate the iterator.
257243830Sdim
258243830Sdim      // The condition can be modified by removePredecessor's PHI simplification
259243830Sdim      // logic.
260243830Sdim      Cond = SI->getCondition();
261243830Sdim
262234353Sdim      ++NumDeadCases;
263234353Sdim      Changed = true;
264234353Sdim    } else if (State == LazyValueInfo::True) {
265234353Sdim      // This case always fires.  Arrange for the switch to be turned into an
266234353Sdim      // unconditional branch by replacing the switch condition with the case
267234353Sdim      // value.
268234353Sdim      SI->setCondition(Case);
269234353Sdim      NumDeadCases += SI->getNumCases();
270234353Sdim      Changed = true;
271234353Sdim      break;
272234353Sdim    }
273234353Sdim  }
274234353Sdim
275234353Sdim  if (Changed)
276234353Sdim    // If the switch has been simplified to the point where it can be replaced
277234353Sdim    // by a branch then do so now.
278234353Sdim    ConstantFoldTerminator(BB);
279234353Sdim
280234353Sdim  return Changed;
281234353Sdim}
282234353Sdim
283212793Sdimbool CorrelatedValuePropagation::runOnFunction(Function &F) {
284212793Sdim  LVI = &getAnalysis<LazyValueInfo>();
285218893Sdim
286212793Sdim  bool FnChanged = false;
287218893Sdim
288212793Sdim  for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ++FI) {
289212793Sdim    bool BBChanged = false;
290212793Sdim    for (BasicBlock::iterator BI = FI->begin(), BE = FI->end(); BI != BE; ) {
291212793Sdim      Instruction *II = BI++;
292212793Sdim      switch (II->getOpcode()) {
293212793Sdim      case Instruction::Select:
294212793Sdim        BBChanged |= processSelect(cast<SelectInst>(II));
295212793Sdim        break;
296212793Sdim      case Instruction::PHI:
297212793Sdim        BBChanged |= processPHI(cast<PHINode>(II));
298212793Sdim        break;
299212793Sdim      case Instruction::ICmp:
300212793Sdim      case Instruction::FCmp:
301212793Sdim        BBChanged |= processCmp(cast<CmpInst>(II));
302212793Sdim        break;
303212793Sdim      case Instruction::Load:
304212793Sdim      case Instruction::Store:
305212793Sdim        BBChanged |= processMemAccess(II);
306212793Sdim        break;
307212793Sdim      }
308212793Sdim    }
309218893Sdim
310234353Sdim    Instruction *Term = FI->getTerminator();
311234353Sdim    switch (Term->getOpcode()) {
312234353Sdim    case Instruction::Switch:
313234353Sdim      BBChanged |= processSwitch(cast<SwitchInst>(Term));
314234353Sdim      break;
315234353Sdim    }
316234353Sdim
317212793Sdim    FnChanged |= BBChanged;
318212793Sdim  }
319218893Sdim
320212793Sdim  return FnChanged;
321212793Sdim}
322