BranchProbabilityInfo.cpp revision 360784
1//===- BranchProbabilityInfo.cpp - Branch Probability Analysis ------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Loops should be simplified before this analysis.
10//
11//===----------------------------------------------------------------------===//
12
13#include "llvm/Analysis/BranchProbabilityInfo.h"
14#include "llvm/ADT/PostOrderIterator.h"
15#include "llvm/ADT/SCCIterator.h"
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/Analysis/LoopInfo.h"
19#include "llvm/Analysis/PostDominators.h"
20#include "llvm/Analysis/TargetLibraryInfo.h"
21#include "llvm/IR/Attributes.h"
22#include "llvm/IR/BasicBlock.h"
23#include "llvm/IR/CFG.h"
24#include "llvm/IR/Constants.h"
25#include "llvm/IR/Dominators.h"
26#include "llvm/IR/Function.h"
27#include "llvm/IR/InstrTypes.h"
28#include "llvm/IR/Instruction.h"
29#include "llvm/IR/Instructions.h"
30#include "llvm/IR/LLVMContext.h"
31#include "llvm/IR/Metadata.h"
32#include "llvm/IR/PassManager.h"
33#include "llvm/IR/Type.h"
34#include "llvm/IR/Value.h"
35#include "llvm/InitializePasses.h"
36#include "llvm/Pass.h"
37#include "llvm/Support/BranchProbability.h"
38#include "llvm/Support/Casting.h"
39#include "llvm/Support/CommandLine.h"
40#include "llvm/Support/Debug.h"
41#include "llvm/Support/raw_ostream.h"
42#include <cassert>
43#include <cstdint>
44#include <iterator>
45#include <utility>
46
47using namespace llvm;
48
49#define DEBUG_TYPE "branch-prob"
50
51static cl::opt<bool> PrintBranchProb(
52    "print-bpi", cl::init(false), cl::Hidden,
53    cl::desc("Print the branch probability info."));
54
55cl::opt<std::string> PrintBranchProbFuncName(
56    "print-bpi-func-name", cl::Hidden,
57    cl::desc("The option to specify the name of the function "
58             "whose branch probability info is printed."));
59
60INITIALIZE_PASS_BEGIN(BranchProbabilityInfoWrapperPass, "branch-prob",
61                      "Branch Probability Analysis", false, true)
62INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
63INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
64INITIALIZE_PASS_END(BranchProbabilityInfoWrapperPass, "branch-prob",
65                    "Branch Probability Analysis", false, true)
66
67BranchProbabilityInfoWrapperPass::BranchProbabilityInfoWrapperPass()
68    : FunctionPass(ID) {
69  initializeBranchProbabilityInfoWrapperPassPass(
70      *PassRegistry::getPassRegistry());
71}
72
73char BranchProbabilityInfoWrapperPass::ID = 0;
74
75// Weights are for internal use only. They are used by heuristics to help to
76// estimate edges' probability. Example:
77//
78// Using "Loop Branch Heuristics" we predict weights of edges for the
79// block BB2.
80//         ...
81//          |
82//          V
83//         BB1<-+
84//          |   |
85//          |   | (Weight = 124)
86//          V   |
87//         BB2--+
88//          |
89//          | (Weight = 4)
90//          V
91//         BB3
92//
93// Probability of the edge BB2->BB1 = 124 / (124 + 4) = 0.96875
94// Probability of the edge BB2->BB3 = 4 / (124 + 4) = 0.03125
// Loop branch heuristic weights: the weight of a taken edge and of a
// not-taken edge, respectively.
static constexpr uint32_t LBH_TAKEN_WEIGHT = 124;
static constexpr uint32_t LBH_NONTAKEN_WEIGHT = 4;
// Unlikely edges within a loop are half as likely as other edges.
static constexpr uint32_t LBH_UNLIKELY_WEIGHT = 62;
99
100/// Unreachable-terminating branch taken probability.
101///
102/// This is the probability for a branch being taken to a block that terminates
103/// (eventually) in unreachable. These are predicted as unlikely as possible.
104/// All reachable probability will equally share the remaining part.
105static const BranchProbability UR_TAKEN_PROB = BranchProbability::getRaw(1);
106
/// Weight for a branch taken going into a cold block.
///
/// A block is marked cold if it is post-dominated by a block containing a call
/// to a cold function, i.e. one carrying the 'cold' attribute.
static constexpr uint32_t CC_TAKEN_WEIGHT = 4;

/// Weight for a branch not taken into a cold block.
static constexpr uint32_t CC_NONTAKEN_WEIGHT = 64;

/// Pointer heuristic weights (likely / unlikely edge of a pointer equality
/// test).
static constexpr uint32_t PH_TAKEN_WEIGHT = 20;
static constexpr uint32_t PH_NONTAKEN_WEIGHT = 12;

/// Zero heuristic weights (likely / unlikely edge of an integer comparison
/// against a small constant).
static constexpr uint32_t ZH_TAKEN_WEIGHT = 20;
static constexpr uint32_t ZH_NONTAKEN_WEIGHT = 12;

/// Floating point heuristic weights used for equality comparisons.
static constexpr uint32_t FPH_TAKEN_WEIGHT = 20;
static constexpr uint32_t FPH_NONTAKEN_WEIGHT = 12;

/// Weight for an ordered floating point comparison.
static constexpr uint32_t FPH_ORD_WEIGHT = 1024 * 1024 - 1;
/// Weight for an unordered floating point comparison, meaning at least one of
/// the operands is NaN. Usually it is used to test for an exceptional case, so
/// the result is unlikely.
static constexpr uint32_t FPH_UNO_WEIGHT = 1;

/// Invoke-terminating normal branch taken weight.
///
/// Weight for branching to the normal destination of an invoke instruction.
/// We expect this to happen most of the time. The value is absurdly high so
/// that nested loops subsume it.
static constexpr uint32_t IH_TAKEN_WEIGHT = 1024 * 1024 - 1;

/// Invoke-terminating normal branch not-taken weight.
///
/// Weight for branching to the unwind destination of an invoke instruction.
/// This is essentially never taken.
static constexpr uint32_t IH_NONTAKEN_WEIGHT = 1;
149
150static void UpdatePDTWorklist(const BasicBlock *BB, PostDominatorTree *PDT,
151                              SmallVectorImpl<const BasicBlock *> &WorkList,
152                              SmallPtrSetImpl<const BasicBlock *> &TargetSet) {
153  SmallVector<BasicBlock *, 8> Descendants;
154  SmallPtrSet<const BasicBlock *, 16> NewItems;
155
156  PDT->getDescendants(const_cast<BasicBlock *>(BB), Descendants);
157  for (auto *BB : Descendants)
158    if (TargetSet.insert(BB).second)
159      for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI)
160        if (!TargetSet.count(*PI))
161          NewItems.insert(*PI);
162  WorkList.insert(WorkList.end(), NewItems.begin(), NewItems.end());
163}
164
165/// Compute a set of basic blocks that are post-dominated by unreachables.
166void BranchProbabilityInfo::computePostDominatedByUnreachable(
167    const Function &F, PostDominatorTree *PDT) {
168  SmallVector<const BasicBlock *, 8> WorkList;
169  for (auto &BB : F) {
170    const Instruction *TI = BB.getTerminator();
171    if (TI->getNumSuccessors() == 0) {
172      if (isa<UnreachableInst>(TI) ||
173          // If this block is terminated by a call to
174          // @llvm.experimental.deoptimize then treat it like an unreachable
175          // since the @llvm.experimental.deoptimize call is expected to
176          // practically never execute.
177          BB.getTerminatingDeoptimizeCall())
178        UpdatePDTWorklist(&BB, PDT, WorkList, PostDominatedByUnreachable);
179    }
180  }
181
182  while (!WorkList.empty()) {
183    const BasicBlock *BB = WorkList.pop_back_val();
184    if (PostDominatedByUnreachable.count(BB))
185      continue;
186    // If the terminator is an InvokeInst, check only the normal destination
187    // block as the unwind edge of InvokeInst is also very unlikely taken.
188    if (auto *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
189      if (PostDominatedByUnreachable.count(II->getNormalDest()))
190        UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByUnreachable);
191    }
192    // If all the successors are unreachable, BB is unreachable as well.
193    else if (!successors(BB).empty() &&
194             llvm::all_of(successors(BB), [this](const BasicBlock *Succ) {
195               return PostDominatedByUnreachable.count(Succ);
196             }))
197      UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByUnreachable);
198  }
199}
200
201/// compute a set of basic blocks that are post-dominated by ColdCalls.
202void BranchProbabilityInfo::computePostDominatedByColdCall(
203    const Function &F, PostDominatorTree *PDT) {
204  SmallVector<const BasicBlock *, 8> WorkList;
205  for (auto &BB : F)
206    for (auto &I : BB)
207      if (const CallInst *CI = dyn_cast<CallInst>(&I))
208        if (CI->hasFnAttr(Attribute::Cold))
209          UpdatePDTWorklist(&BB, PDT, WorkList, PostDominatedByColdCall);
210
211  while (!WorkList.empty()) {
212    const BasicBlock *BB = WorkList.pop_back_val();
213
214    // If the terminator is an InvokeInst, check only the normal destination
215    // block as the unwind edge of InvokeInst is also very unlikely taken.
216    if (auto *II = dyn_cast<InvokeInst>(BB->getTerminator())) {
217      if (PostDominatedByColdCall.count(II->getNormalDest()))
218        UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByColdCall);
219    }
220    // If all of successor are post dominated then BB is also done.
221    else if (!successors(BB).empty() &&
222             llvm::all_of(successors(BB), [this](const BasicBlock *Succ) {
223               return PostDominatedByColdCall.count(Succ);
224             }))
225      UpdatePDTWorklist(BB, PDT, WorkList, PostDominatedByColdCall);
226  }
227}
228
229/// Calculate edge weights for successors lead to unreachable.
230///
231/// Predict that a successor which leads necessarily to an
232/// unreachable-terminated block as extremely unlikely.
233bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) {
234  const Instruction *TI = BB->getTerminator();
235  (void) TI;
236  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
237  assert(!isa<InvokeInst>(TI) &&
238         "Invokes should have already been handled by calcInvokeHeuristics");
239
240  SmallVector<unsigned, 4> UnreachableEdges;
241  SmallVector<unsigned, 4> ReachableEdges;
242
243  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
244    if (PostDominatedByUnreachable.count(*I))
245      UnreachableEdges.push_back(I.getSuccessorIndex());
246    else
247      ReachableEdges.push_back(I.getSuccessorIndex());
248
249  // Skip probabilities if all were reachable.
250  if (UnreachableEdges.empty())
251    return false;
252
253  if (ReachableEdges.empty()) {
254    BranchProbability Prob(1, UnreachableEdges.size());
255    for (unsigned SuccIdx : UnreachableEdges)
256      setEdgeProbability(BB, SuccIdx, Prob);
257    return true;
258  }
259
260  auto UnreachableProb = UR_TAKEN_PROB;
261  auto ReachableProb =
262      (BranchProbability::getOne() - UR_TAKEN_PROB * UnreachableEdges.size()) /
263      ReachableEdges.size();
264
265  for (unsigned SuccIdx : UnreachableEdges)
266    setEdgeProbability(BB, SuccIdx, UnreachableProb);
267  for (unsigned SuccIdx : ReachableEdges)
268    setEdgeProbability(BB, SuccIdx, ReachableProb);
269
270  return true;
271}
272
273// Propagate existing explicit probabilities from either profile data or
274// 'expect' intrinsic processing. Examine metadata against unreachable
275// heuristic. The probability of the edge coming to unreachable block is
276// set to min of metadata and unreachable heuristic.
277bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) {
278  const Instruction *TI = BB->getTerminator();
279  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
280  if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI)))
281    return false;
282
283  MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof);
284  if (!WeightsNode)
285    return false;
286
287  // Check that the number of successors is manageable.
288  assert(TI->getNumSuccessors() < UINT32_MAX && "Too many successors");
289
290  // Ensure there are weights for all of the successors. Note that the first
291  // operand to the metadata node is a name, not a weight.
292  if (WeightsNode->getNumOperands() != TI->getNumSuccessors() + 1)
293    return false;
294
295  // Build up the final weights that will be used in a temporary buffer.
296  // Compute the sum of all weights to later decide whether they need to
297  // be scaled to fit in 32 bits.
298  uint64_t WeightSum = 0;
299  SmallVector<uint32_t, 2> Weights;
300  SmallVector<unsigned, 2> UnreachableIdxs;
301  SmallVector<unsigned, 2> ReachableIdxs;
302  Weights.reserve(TI->getNumSuccessors());
303  for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) {
304    ConstantInt *Weight =
305        mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i));
306    if (!Weight)
307      return false;
308    assert(Weight->getValue().getActiveBits() <= 32 &&
309           "Too many bits for uint32_t");
310    Weights.push_back(Weight->getZExtValue());
311    WeightSum += Weights.back();
312    if (PostDominatedByUnreachable.count(TI->getSuccessor(i - 1)))
313      UnreachableIdxs.push_back(i - 1);
314    else
315      ReachableIdxs.push_back(i - 1);
316  }
317  assert(Weights.size() == TI->getNumSuccessors() && "Checked above");
318
319  // If the sum of weights does not fit in 32 bits, scale every weight down
320  // accordingly.
321  uint64_t ScalingFactor =
322      (WeightSum > UINT32_MAX) ? WeightSum / UINT32_MAX + 1 : 1;
323
324  if (ScalingFactor > 1) {
325    WeightSum = 0;
326    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) {
327      Weights[i] /= ScalingFactor;
328      WeightSum += Weights[i];
329    }
330  }
331  assert(WeightSum <= UINT32_MAX &&
332         "Expected weights to scale down to 32 bits");
333
334  if (WeightSum == 0 || ReachableIdxs.size() == 0) {
335    for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
336      Weights[i] = 1;
337    WeightSum = TI->getNumSuccessors();
338  }
339
340  // Set the probability.
341  SmallVector<BranchProbability, 2> BP;
342  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
343    BP.push_back({ Weights[i], static_cast<uint32_t>(WeightSum) });
344
345  // Examine the metadata against unreachable heuristic.
346  // If the unreachable heuristic is more strong then we use it for this edge.
347  if (UnreachableIdxs.size() > 0 && ReachableIdxs.size() > 0) {
348    auto ToDistribute = BranchProbability::getZero();
349    auto UnreachableProb = UR_TAKEN_PROB;
350    for (auto i : UnreachableIdxs)
351      if (UnreachableProb < BP[i]) {
352        ToDistribute += BP[i] - UnreachableProb;
353        BP[i] = UnreachableProb;
354      }
355
356    // If we modified the probability of some edges then we must distribute
357    // the difference between reachable blocks.
358    if (ToDistribute > BranchProbability::getZero()) {
359      BranchProbability PerEdge = ToDistribute / ReachableIdxs.size();
360      for (auto i : ReachableIdxs)
361        BP[i] += PerEdge;
362    }
363  }
364
365  for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i)
366    setEdgeProbability(BB, i, BP[i]);
367
368  return true;
369}
370
371/// Calculate edge weights for edges leading to cold blocks.
372///
373/// A cold block is one post-dominated by  a block with a call to a
374/// cold function.  Those edges are unlikely to be taken, so we give
375/// them relatively low weight.
376///
377/// Return true if we could compute the weights for cold edges.
378/// Return false, otherwise.
379bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) {
380  const Instruction *TI = BB->getTerminator();
381  (void) TI;
382  assert(TI->getNumSuccessors() > 1 && "expected more than one successor!");
383  assert(!isa<InvokeInst>(TI) &&
384         "Invokes should have already been handled by calcInvokeHeuristics");
385
386  // Determine which successors are post-dominated by a cold block.
387  SmallVector<unsigned, 4> ColdEdges;
388  SmallVector<unsigned, 4> NormalEdges;
389  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I)
390    if (PostDominatedByColdCall.count(*I))
391      ColdEdges.push_back(I.getSuccessorIndex());
392    else
393      NormalEdges.push_back(I.getSuccessorIndex());
394
395  // Skip probabilities if no cold edges.
396  if (ColdEdges.empty())
397    return false;
398
399  if (NormalEdges.empty()) {
400    BranchProbability Prob(1, ColdEdges.size());
401    for (unsigned SuccIdx : ColdEdges)
402      setEdgeProbability(BB, SuccIdx, Prob);
403    return true;
404  }
405
406  auto ColdProb = BranchProbability::getBranchProbability(
407      CC_TAKEN_WEIGHT,
408      (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(ColdEdges.size()));
409  auto NormalProb = BranchProbability::getBranchProbability(
410      CC_NONTAKEN_WEIGHT,
411      (CC_TAKEN_WEIGHT + CC_NONTAKEN_WEIGHT) * uint64_t(NormalEdges.size()));
412
413  for (unsigned SuccIdx : ColdEdges)
414    setEdgeProbability(BB, SuccIdx, ColdProb);
415  for (unsigned SuccIdx : NormalEdges)
416    setEdgeProbability(BB, SuccIdx, NormalProb);
417
418  return true;
419}
420
421// Calculate Edge Weights using "Pointer Heuristics". Predict a comparison
422// between two pointer or pointer and NULL will fail.
423bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) {
424  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
425  if (!BI || !BI->isConditional())
426    return false;
427
428  Value *Cond = BI->getCondition();
429  ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
430  if (!CI || !CI->isEquality())
431    return false;
432
433  Value *LHS = CI->getOperand(0);
434
435  if (!LHS->getType()->isPointerTy())
436    return false;
437
438  assert(CI->getOperand(1)->getType()->isPointerTy());
439
440  // p != 0   ->   isProb = true
441  // p == 0   ->   isProb = false
442  // p != q   ->   isProb = true
443  // p == q   ->   isProb = false;
444  unsigned TakenIdx = 0, NonTakenIdx = 1;
445  bool isProb = CI->getPredicate() == ICmpInst::ICMP_NE;
446  if (!isProb)
447    std::swap(TakenIdx, NonTakenIdx);
448
449  BranchProbability TakenProb(PH_TAKEN_WEIGHT,
450                              PH_TAKEN_WEIGHT + PH_NONTAKEN_WEIGHT);
451  setEdgeProbability(BB, TakenIdx, TakenProb);
452  setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
453  return true;
454}
455
456static int getSCCNum(const BasicBlock *BB,
457                     const BranchProbabilityInfo::SccInfo &SccI) {
458  auto SccIt = SccI.SccNums.find(BB);
459  if (SccIt == SccI.SccNums.end())
460    return -1;
461  return SccIt->second;
462}
463
464// Consider any block that is an entry point to the SCC as a header.
465static bool isSCCHeader(const BasicBlock *BB, int SccNum,
466                        BranchProbabilityInfo::SccInfo &SccI) {
467  assert(getSCCNum(BB, SccI) == SccNum);
468
469  // Lazily compute the set of headers for a given SCC and cache the results
470  // in the SccHeaderMap.
471  if (SccI.SccHeaders.size() <= static_cast<unsigned>(SccNum))
472    SccI.SccHeaders.resize(SccNum + 1);
473  auto &HeaderMap = SccI.SccHeaders[SccNum];
474  bool Inserted;
475  BranchProbabilityInfo::SccHeaderMap::iterator HeaderMapIt;
476  std::tie(HeaderMapIt, Inserted) = HeaderMap.insert(std::make_pair(BB, false));
477  if (Inserted) {
478    bool IsHeader = llvm::any_of(make_range(pred_begin(BB), pred_end(BB)),
479                                 [&](const BasicBlock *Pred) {
480                                   return getSCCNum(Pred, SccI) != SccNum;
481                                 });
482    HeaderMapIt->second = IsHeader;
483    return IsHeader;
484  } else
485    return HeaderMapIt->second;
486}
487
488// Compute the unlikely successors to the block BB in the loop L, specifically
489// those that are unlikely because this is a loop, and add them to the
490// UnlikelyBlocks set.
491static void
492computeUnlikelySuccessors(const BasicBlock *BB, Loop *L,
493                          SmallPtrSetImpl<const BasicBlock*> &UnlikelyBlocks) {
494  // Sometimes in a loop we have a branch whose condition is made false by
495  // taking it. This is typically something like
496  //  int n = 0;
497  //  while (...) {
498  //    if (++n >= MAX) {
499  //      n = 0;
500  //    }
501  //  }
502  // In this sort of situation taking the branch means that at the very least it
503  // won't be taken again in the next iteration of the loop, so we should
504  // consider it less likely than a typical branch.
505  //
506  // We detect this by looking back through the graph of PHI nodes that sets the
507  // value that the condition depends on, and seeing if we can reach a successor
508  // block which can be determined to make the condition false.
509  //
510  // FIXME: We currently consider unlikely blocks to be half as likely as other
511  // blocks, but if we consider the example above the likelyhood is actually
512  // 1/MAX. We could therefore be more precise in how unlikely we consider
513  // blocks to be, but it would require more careful examination of the form
514  // of the comparison expression.
515  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
516  if (!BI || !BI->isConditional())
517    return;
518
519  // Check if the branch is based on an instruction compared with a constant
520  CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition());
521  if (!CI || !isa<Instruction>(CI->getOperand(0)) ||
522      !isa<Constant>(CI->getOperand(1)))
523    return;
524
525  // Either the instruction must be a PHI, or a chain of operations involving
526  // constants that ends in a PHI which we can then collapse into a single value
527  // if the PHI value is known.
528  Instruction *CmpLHS = dyn_cast<Instruction>(CI->getOperand(0));
529  PHINode *CmpPHI = dyn_cast<PHINode>(CmpLHS);
530  Constant *CmpConst = dyn_cast<Constant>(CI->getOperand(1));
531  // Collect the instructions until we hit a PHI
532  SmallVector<BinaryOperator *, 1> InstChain;
533  while (!CmpPHI && CmpLHS && isa<BinaryOperator>(CmpLHS) &&
534         isa<Constant>(CmpLHS->getOperand(1))) {
535    // Stop if the chain extends outside of the loop
536    if (!L->contains(CmpLHS))
537      return;
538    InstChain.push_back(cast<BinaryOperator>(CmpLHS));
539    CmpLHS = dyn_cast<Instruction>(CmpLHS->getOperand(0));
540    if (CmpLHS)
541      CmpPHI = dyn_cast<PHINode>(CmpLHS);
542  }
543  if (!CmpPHI || !L->contains(CmpPHI))
544    return;
545
546  // Trace the phi node to find all values that come from successors of BB
547  SmallPtrSet<PHINode*, 8> VisitedInsts;
548  SmallVector<PHINode*, 8> WorkList;
549  WorkList.push_back(CmpPHI);
550  VisitedInsts.insert(CmpPHI);
551  while (!WorkList.empty()) {
552    PHINode *P = WorkList.back();
553    WorkList.pop_back();
554    for (BasicBlock *B : P->blocks()) {
555      // Skip blocks that aren't part of the loop
556      if (!L->contains(B))
557        continue;
558      Value *V = P->getIncomingValueForBlock(B);
559      // If the source is a PHI add it to the work list if we haven't
560      // already visited it.
561      if (PHINode *PN = dyn_cast<PHINode>(V)) {
562        if (VisitedInsts.insert(PN).second)
563          WorkList.push_back(PN);
564        continue;
565      }
566      // If this incoming value is a constant and B is a successor of BB, then
567      // we can constant-evaluate the compare to see if it makes the branch be
568      // taken or not.
569      Constant *CmpLHSConst = dyn_cast<Constant>(V);
570      if (!CmpLHSConst ||
571          std::find(succ_begin(BB), succ_end(BB), B) == succ_end(BB))
572        continue;
573      // First collapse InstChain
574      for (Instruction *I : llvm::reverse(InstChain)) {
575        CmpLHSConst = ConstantExpr::get(I->getOpcode(), CmpLHSConst,
576                                        cast<Constant>(I->getOperand(1)), true);
577        if (!CmpLHSConst)
578          break;
579      }
580      if (!CmpLHSConst)
581        continue;
582      // Now constant-evaluate the compare
583      Constant *Result = ConstantExpr::getCompare(CI->getPredicate(),
584                                                  CmpLHSConst, CmpConst, true);
585      // If the result means we don't branch to the block then that block is
586      // unlikely.
587      if (Result &&
588          ((Result->isZeroValue() && B == BI->getSuccessor(0)) ||
589           (Result->isOneValue() && B == BI->getSuccessor(1))))
590        UnlikelyBlocks.insert(B);
591    }
592  }
593}
594
595// Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges
596// as taken, exiting edges as not-taken.
597bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB,
598                                                     const LoopInfo &LI,
599                                                     SccInfo &SccI) {
600  int SccNum;
601  Loop *L = LI.getLoopFor(BB);
602  if (!L) {
603    SccNum = getSCCNum(BB, SccI);
604    if (SccNum < 0)
605      return false;
606  }
607
608  SmallPtrSet<const BasicBlock*, 8> UnlikelyBlocks;
609  if (L)
610    computeUnlikelySuccessors(BB, L, UnlikelyBlocks);
611
612  SmallVector<unsigned, 8> BackEdges;
613  SmallVector<unsigned, 8> ExitingEdges;
614  SmallVector<unsigned, 8> InEdges; // Edges from header to the loop.
615  SmallVector<unsigned, 8> UnlikelyEdges;
616
617  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
618    // Use LoopInfo if we have it, otherwise fall-back to SCC info to catch
619    // irreducible loops.
620    if (L) {
621      if (UnlikelyBlocks.count(*I) != 0)
622        UnlikelyEdges.push_back(I.getSuccessorIndex());
623      else if (!L->contains(*I))
624        ExitingEdges.push_back(I.getSuccessorIndex());
625      else if (L->getHeader() == *I)
626        BackEdges.push_back(I.getSuccessorIndex());
627      else
628        InEdges.push_back(I.getSuccessorIndex());
629    } else {
630      if (getSCCNum(*I, SccI) != SccNum)
631        ExitingEdges.push_back(I.getSuccessorIndex());
632      else if (isSCCHeader(*I, SccNum, SccI))
633        BackEdges.push_back(I.getSuccessorIndex());
634      else
635        InEdges.push_back(I.getSuccessorIndex());
636    }
637  }
638
639  if (BackEdges.empty() && ExitingEdges.empty() && UnlikelyEdges.empty())
640    return false;
641
642  // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and
643  // normalize them so that they sum up to one.
644  unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
645                   (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) +
646                   (UnlikelyEdges.empty() ? 0 : LBH_UNLIKELY_WEIGHT) +
647                   (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT);
648
649  if (uint32_t numBackEdges = BackEdges.size()) {
650    BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
651    auto Prob = TakenProb / numBackEdges;
652    for (unsigned SuccIdx : BackEdges)
653      setEdgeProbability(BB, SuccIdx, Prob);
654  }
655
656  if (uint32_t numInEdges = InEdges.size()) {
657    BranchProbability TakenProb = BranchProbability(LBH_TAKEN_WEIGHT, Denom);
658    auto Prob = TakenProb / numInEdges;
659    for (unsigned SuccIdx : InEdges)
660      setEdgeProbability(BB, SuccIdx, Prob);
661  }
662
663  if (uint32_t numExitingEdges = ExitingEdges.size()) {
664    BranchProbability NotTakenProb = BranchProbability(LBH_NONTAKEN_WEIGHT,
665                                                       Denom);
666    auto Prob = NotTakenProb / numExitingEdges;
667    for (unsigned SuccIdx : ExitingEdges)
668      setEdgeProbability(BB, SuccIdx, Prob);
669  }
670
671  if (uint32_t numUnlikelyEdges = UnlikelyEdges.size()) {
672    BranchProbability UnlikelyProb = BranchProbability(LBH_UNLIKELY_WEIGHT,
673                                                       Denom);
674    auto Prob = UnlikelyProb / numUnlikelyEdges;
675    for (unsigned SuccIdx : UnlikelyEdges)
676      setEdgeProbability(BB, SuccIdx, Prob);
677  }
678
679  return true;
680}
681
682bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB,
683                                               const TargetLibraryInfo *TLI) {
684  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
685  if (!BI || !BI->isConditional())
686    return false;
687
688  Value *Cond = BI->getCondition();
689  ICmpInst *CI = dyn_cast<ICmpInst>(Cond);
690  if (!CI)
691    return false;
692
693  auto GetConstantInt = [](Value *V) {
694    if (auto *I = dyn_cast<BitCastInst>(V))
695      return dyn_cast<ConstantInt>(I->getOperand(0));
696    return dyn_cast<ConstantInt>(V);
697  };
698
699  Value *RHS = CI->getOperand(1);
700  ConstantInt *CV = GetConstantInt(RHS);
701  if (!CV)
702    return false;
703
704  // If the LHS is the result of AND'ing a value with a single bit bitmask,
705  // we don't have information about probabilities.
706  if (Instruction *LHS = dyn_cast<Instruction>(CI->getOperand(0)))
707    if (LHS->getOpcode() == Instruction::And)
708      if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(LHS->getOperand(1)))
709        if (AndRHS->getValue().isPowerOf2())
710          return false;
711
712  // Check if the LHS is the return value of a library function
713  LibFunc Func = NumLibFuncs;
714  if (TLI)
715    if (CallInst *Call = dyn_cast<CallInst>(CI->getOperand(0)))
716      if (Function *CalledFn = Call->getCalledFunction())
717        TLI->getLibFunc(*CalledFn, Func);
718
719  bool isProb;
720  if (Func == LibFunc_strcasecmp ||
721      Func == LibFunc_strcmp ||
722      Func == LibFunc_strncasecmp ||
723      Func == LibFunc_strncmp ||
724      Func == LibFunc_memcmp) {
725    // strcmp and similar functions return zero, negative, or positive, if the
726    // first string is equal, less, or greater than the second. We consider it
727    // likely that the strings are not equal, so a comparison with zero is
728    // probably false, but also a comparison with any other number is also
729    // probably false given that what exactly is returned for nonzero values is
730    // not specified. Any kind of comparison other than equality we know
731    // nothing about.
732    switch (CI->getPredicate()) {
733    case CmpInst::ICMP_EQ:
734      isProb = false;
735      break;
736    case CmpInst::ICMP_NE:
737      isProb = true;
738      break;
739    default:
740      return false;
741    }
742  } else if (CV->isZero()) {
743    switch (CI->getPredicate()) {
744    case CmpInst::ICMP_EQ:
745      // X == 0   ->  Unlikely
746      isProb = false;
747      break;
748    case CmpInst::ICMP_NE:
749      // X != 0   ->  Likely
750      isProb = true;
751      break;
752    case CmpInst::ICMP_SLT:
753      // X < 0   ->  Unlikely
754      isProb = false;
755      break;
756    case CmpInst::ICMP_SGT:
757      // X > 0   ->  Likely
758      isProb = true;
759      break;
760    default:
761      return false;
762    }
763  } else if (CV->isOne() && CI->getPredicate() == CmpInst::ICMP_SLT) {
764    // InstCombine canonicalizes X <= 0 into X < 1.
765    // X <= 0   ->  Unlikely
766    isProb = false;
767  } else if (CV->isMinusOne()) {
768    switch (CI->getPredicate()) {
769    case CmpInst::ICMP_EQ:
770      // X == -1  ->  Unlikely
771      isProb = false;
772      break;
773    case CmpInst::ICMP_NE:
774      // X != -1  ->  Likely
775      isProb = true;
776      break;
777    case CmpInst::ICMP_SGT:
778      // InstCombine canonicalizes X >= 0 into X > -1.
779      // X >= 0   ->  Likely
780      isProb = true;
781      break;
782    default:
783      return false;
784    }
785  } else {
786    return false;
787  }
788
789  unsigned TakenIdx = 0, NonTakenIdx = 1;
790
791  if (!isProb)
792    std::swap(TakenIdx, NonTakenIdx);
793
794  BranchProbability TakenProb(ZH_TAKEN_WEIGHT,
795                              ZH_TAKEN_WEIGHT + ZH_NONTAKEN_WEIGHT);
796  setEdgeProbability(BB, TakenIdx, TakenProb);
797  setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
798  return true;
799}
800
801bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) {
802  const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
803  if (!BI || !BI->isConditional())
804    return false;
805
806  Value *Cond = BI->getCondition();
807  FCmpInst *FCmp = dyn_cast<FCmpInst>(Cond);
808  if (!FCmp)
809    return false;
810
811  uint32_t TakenWeight = FPH_TAKEN_WEIGHT;
812  uint32_t NontakenWeight = FPH_NONTAKEN_WEIGHT;
813  bool isProb;
814  if (FCmp->isEquality()) {
815    // f1 == f2 -> Unlikely
816    // f1 != f2 -> Likely
817    isProb = !FCmp->isTrueWhenEqual();
818  } else if (FCmp->getPredicate() == FCmpInst::FCMP_ORD) {
819    // !isnan -> Likely
820    isProb = true;
821    TakenWeight = FPH_ORD_WEIGHT;
822    NontakenWeight = FPH_UNO_WEIGHT;
823  } else if (FCmp->getPredicate() == FCmpInst::FCMP_UNO) {
824    // isnan -> Unlikely
825    isProb = false;
826    TakenWeight = FPH_ORD_WEIGHT;
827    NontakenWeight = FPH_UNO_WEIGHT;
828  } else {
829    return false;
830  }
831
832  unsigned TakenIdx = 0, NonTakenIdx = 1;
833
834  if (!isProb)
835    std::swap(TakenIdx, NonTakenIdx);
836
837  BranchProbability TakenProb(TakenWeight, TakenWeight + NontakenWeight);
838  setEdgeProbability(BB, TakenIdx, TakenProb);
839  setEdgeProbability(BB, NonTakenIdx, TakenProb.getCompl());
840  return true;
841}
842
843bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) {
844  const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator());
845  if (!II)
846    return false;
847
848  BranchProbability TakenProb(IH_TAKEN_WEIGHT,
849                              IH_TAKEN_WEIGHT + IH_NONTAKEN_WEIGHT);
850  setEdgeProbability(BB, 0 /*Index for Normal*/, TakenProb);
851  setEdgeProbability(BB, 1 /*Index for Unwind*/, TakenProb.getCompl());
852  return true;
853}
854
855void BranchProbabilityInfo::releaseMemory() {
856  Probs.clear();
857}
858
859void BranchProbabilityInfo::print(raw_ostream &OS) const {
860  OS << "---- Branch Probabilities ----\n";
861  // We print the probabilities from the last function the analysis ran over,
862  // or the function it is currently running over.
863  assert(LastF && "Cannot print prior to running over a function");
864  for (const auto &BI : *LastF) {
865    for (succ_const_iterator SI = succ_begin(&BI), SE = succ_end(&BI); SI != SE;
866         ++SI) {
867      printEdgeProbability(OS << "  ", &BI, *SI);
868    }
869  }
870}
871
872bool BranchProbabilityInfo::
873isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const {
874  // Hot probability is at least 4/5 = 80%
875  // FIXME: Compare against a static "hot" BranchProbability.
876  return getEdgeProbability(Src, Dst) > BranchProbability(4, 5);
877}
878
879const BasicBlock *
880BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const {
881  auto MaxProb = BranchProbability::getZero();
882  const BasicBlock *MaxSucc = nullptr;
883
884  for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) {
885    const BasicBlock *Succ = *I;
886    auto Prob = getEdgeProbability(BB, Succ);
887    if (Prob > MaxProb) {
888      MaxProb = Prob;
889      MaxSucc = Succ;
890    }
891  }
892
893  // Hot probability is at least 4/5 = 80%
894  if (MaxProb > BranchProbability(4, 5))
895    return MaxSucc;
896
897  return nullptr;
898}
899
900/// Get the raw edge probability for the edge. If can't find it, return a
901/// default probability 1/N where N is the number of successors. Here an edge is
902/// specified using PredBlock and an
903/// index to the successors.
904BranchProbability
905BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
906                                          unsigned IndexInSuccessors) const {
907  auto I = Probs.find(std::make_pair(Src, IndexInSuccessors));
908
909  if (I != Probs.end())
910    return I->second;
911
912  return {1, static_cast<uint32_t>(succ_size(Src))};
913}
914
915BranchProbability
916BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
917                                          succ_const_iterator Dst) const {
918  return getEdgeProbability(Src, Dst.getSuccessorIndex());
919}
920
921/// Get the raw edge probability calculated for the block pair. This returns the
922/// sum of all raw edge probabilities from Src to Dst.
923BranchProbability
924BranchProbabilityInfo::getEdgeProbability(const BasicBlock *Src,
925                                          const BasicBlock *Dst) const {
926  auto Prob = BranchProbability::getZero();
927  bool FoundProb = false;
928  for (succ_const_iterator I = succ_begin(Src), E = succ_end(Src); I != E; ++I)
929    if (*I == Dst) {
930      auto MapI = Probs.find(std::make_pair(Src, I.getSuccessorIndex()));
931      if (MapI != Probs.end()) {
932        FoundProb = true;
933        Prob += MapI->second;
934      }
935    }
936  uint32_t succ_num = std::distance(succ_begin(Src), succ_end(Src));
937  return FoundProb ? Prob : BranchProbability(1, succ_num);
938}
939
940/// Set the edge probability for a given edge specified by PredBlock and an
941/// index to the successors.
942void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src,
943                                               unsigned IndexInSuccessors,
944                                               BranchProbability Prob) {
945  Probs[std::make_pair(Src, IndexInSuccessors)] = Prob;
946  Handles.insert(BasicBlockCallbackVH(Src, this));
947  LLVM_DEBUG(dbgs() << "set edge " << Src->getName() << " -> "
948                    << IndexInSuccessors << " successor probability to " << Prob
949                    << "\n");
950}
951
952raw_ostream &
953BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS,
954                                            const BasicBlock *Src,
955                                            const BasicBlock *Dst) const {
956  const BranchProbability Prob = getEdgeProbability(Src, Dst);
957  OS << "edge " << Src->getName() << " -> " << Dst->getName()
958     << " probability is " << Prob
959     << (isEdgeHot(Src, Dst) ? " [HOT edge]\n" : "\n");
960
961  return OS;
962}
963
964void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) {
965  for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) {
966    auto Key = I->first;
967    if (Key.first == BB)
968      Probs.erase(Key);
969  }
970}
971
972void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI,
973                                      const TargetLibraryInfo *TLI) {
974  LLVM_DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName()
975                    << " ----\n\n");
976  LastF = &F; // Store the last function we ran on for printing.
977  assert(PostDominatedByUnreachable.empty());
978  assert(PostDominatedByColdCall.empty());
979
980  // Record SCC numbers of blocks in the CFG to identify irreducible loops.
981  // FIXME: We could only calculate this if the CFG is known to be irreducible
982  // (perhaps cache this info in LoopInfo if we can easily calculate it there?).
983  int SccNum = 0;
984  SccInfo SccI;
985  for (scc_iterator<const Function *> It = scc_begin(&F); !It.isAtEnd();
986       ++It, ++SccNum) {
987    // Ignore single-block SCCs since they either aren't loops or LoopInfo will
988    // catch them.
989    const std::vector<const BasicBlock *> &Scc = *It;
990    if (Scc.size() == 1)
991      continue;
992
993    LLVM_DEBUG(dbgs() << "BPI: SCC " << SccNum << ":");
994    for (auto *BB : Scc) {
995      LLVM_DEBUG(dbgs() << " " << BB->getName());
996      SccI.SccNums[BB] = SccNum;
997    }
998    LLVM_DEBUG(dbgs() << "\n");
999  }
1000
1001  std::unique_ptr<PostDominatorTree> PDT =
1002      std::make_unique<PostDominatorTree>(const_cast<Function &>(F));
1003  computePostDominatedByUnreachable(F, PDT.get());
1004  computePostDominatedByColdCall(F, PDT.get());
1005
1006  // Walk the basic blocks in post-order so that we can build up state about
1007  // the successors of a block iteratively.
1008  for (auto BB : post_order(&F.getEntryBlock())) {
1009    LLVM_DEBUG(dbgs() << "Computing probabilities for " << BB->getName()
1010                      << "\n");
1011    // If there is no at least two successors, no sense to set probability.
1012    if (BB->getTerminator()->getNumSuccessors() < 2)
1013      continue;
1014    if (calcMetadataWeights(BB))
1015      continue;
1016    if (calcInvokeHeuristics(BB))
1017      continue;
1018    if (calcUnreachableHeuristics(BB))
1019      continue;
1020    if (calcColdCallHeuristics(BB))
1021      continue;
1022    if (calcLoopBranchHeuristics(BB, LI, SccI))
1023      continue;
1024    if (calcPointerHeuristics(BB))
1025      continue;
1026    if (calcZeroHeuristics(BB, TLI))
1027      continue;
1028    if (calcFloatingPointHeuristics(BB))
1029      continue;
1030  }
1031
1032  PostDominatedByUnreachable.clear();
1033  PostDominatedByColdCall.clear();
1034
1035  if (PrintBranchProb &&
1036      (PrintBranchProbFuncName.empty() ||
1037       F.getName().equals(PrintBranchProbFuncName))) {
1038    print(dbgs());
1039  }
1040}
1041
1042void BranchProbabilityInfoWrapperPass::getAnalysisUsage(
1043    AnalysisUsage &AU) const {
1044  // We require DT so it's available when LI is available. The LI updating code
1045  // asserts that DT is also present so if we don't make sure that we have DT
1046  // here, that assert will trigger.
1047  AU.addRequired<DominatorTreeWrapperPass>();
1048  AU.addRequired<LoopInfoWrapperPass>();
1049  AU.addRequired<TargetLibraryInfoWrapperPass>();
1050  AU.setPreservesAll();
1051}
1052
1053bool BranchProbabilityInfoWrapperPass::runOnFunction(Function &F) {
1054  const LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1055  const TargetLibraryInfo &TLI =
1056      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
1057  BPI.calculate(F, LI, &TLI);
1058  return false;
1059}
1060
1061void BranchProbabilityInfoWrapperPass::releaseMemory() { BPI.releaseMemory(); }
1062
1063void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS,
1064                                             const Module *) const {
1065  BPI.print(OS);
1066}
1067
1068AnalysisKey BranchProbabilityAnalysis::Key;
1069BranchProbabilityInfo
1070BranchProbabilityAnalysis::run(Function &F, FunctionAnalysisManager &AM) {
1071  BranchProbabilityInfo BPI;
1072  BPI.calculate(F, AM.getResult<LoopAnalysis>(F), &AM.getResult<TargetLibraryAnalysis>(F));
1073  return BPI;
1074}
1075
1076PreservedAnalyses
1077BranchProbabilityPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
1078  OS << "Printing analysis results of BPI for function "
1079     << "'" << F.getName() << "':"
1080     << "\n";
1081  AM.getResult<BranchProbabilityAnalysis>(F).print(OS);
1082  return PreservedAnalyses::all();
1083}
1084