PGOInstrumentation.cpp revision 292915
1112158Sdas//===-- PGOInstrumentation.cpp - MST-based PGO Instrumentation ------------===//
2112158Sdas//
3112158Sdas//                      The LLVM Compiler Infrastructure
4112158Sdas//
5112158Sdas// This file is distributed under the University of Illinois Open Source
6112158Sdas// License. See LICENSE.TXT for details.
7112158Sdas//
8112158Sdas//===----------------------------------------------------------------------===//
9112158Sdas//
10112158Sdas// This file implements PGO instrumentation using a minimum spanning tree based
11112158Sdas// on the following paper:
12112158Sdas//   [1] Donald E. Knuth, Francis R. Stevenson. Optimal measurement of points
13112158Sdas//   for program frequency counts. BIT Numerical Mathematics 1973, Volume 13,
14112158Sdas//   Issue 3, pp 313-322
// The idea of the algorithm is based on the fact that for each node (except
// for the entry and exit), the sum of incoming edge counts equals the sum of
// outgoing edge counts. The count of an edge on the spanning tree can be
// derived from those edges not on the spanning tree. Knuth proves this method
// instruments the minimum number of edges.
20112158Sdas//
21112158Sdas// The minimal spanning tree here is actually a maximum weight tree -- on-tree
22112158Sdas// edges have higher frequencies (more likely to execute). The idea is to
23112158Sdas// instrument those less frequently executed edges to reduce the runtime
24112158Sdas// overhead of instrumented binaries.
25112158Sdas//
26112158Sdas// This file contains two passes:
27112158Sdas// (1) Pass PGOInstrumentationGen which instruments the IR to generate edge
28112158Sdas// count profile, and
29165743Sdas// (2) Pass PGOInstrumentationUse which reads the edge count profile and
30165743Sdas// annotates the branch weights.
// To get the precise counter information, these two passes need to be invoked
// at the same compilation point (so they see the same IR). For pass
// PGOInstrumentationGen, the real work is done in instrumentOneFunc(). For
// pass PGOInstrumentationUse, the real work is done in class PGOUseFunc; the
// profile is opened at module level and passed to each PGOUseFunc instance.
// The shared code for PGOInstrumentationGen and PGOInstrumentationUse is put
// in class FuncPGOInstrumentation.
38112158Sdas//
// Class PGOEdge represents a CFG edge and some auxiliary information. Class
// BBInfo contains auxiliary information for each BB. These two classes are used
// in pass PGOInstrumentationGen. Classes PGOUseEdge and UseBBInfo are derived
// from PGOEdge and BBInfo, respectively. They contain extra data structures
// used in populating profile counters.
// The MST implementation is in class CFGMST (CFGMST.h).
45112158Sdas//
46112158Sdas//===----------------------------------------------------------------------===//
47112158Sdas
48112158Sdas#include "llvm/Transforms/Instrumentation.h"
49112158Sdas#include "CFGMST.h"
50112158Sdas#include "llvm/ADT/DenseMap.h"
51112158Sdas#include "llvm/ADT/STLExtras.h"
52112158Sdas#include "llvm/ADT/Statistic.h"
53112158Sdas#include "llvm/Analysis/BlockFrequencyInfo.h"
54112158Sdas#include "llvm/Analysis/BranchProbabilityInfo.h"
55112158Sdas#include "llvm/Analysis/CFG.h"
56112158Sdas#include "llvm/IR/DiagnosticInfo.h"
57112158Sdas#include "llvm/IR/IRBuilder.h"
58112158Sdas#include "llvm/IR/InstIterator.h"
59112158Sdas#include "llvm/IR/Instructions.h"
60112158Sdas#include "llvm/IR/IntrinsicInst.h"
61112158Sdas#include "llvm/IR/MDBuilder.h"
62112158Sdas#include "llvm/IR/Module.h"
63112158Sdas#include "llvm/Pass.h"
64#include "llvm/ProfileData/InstrProfReader.h"
65#include "llvm/Support/BranchProbability.h"
66#include "llvm/Support/Debug.h"
67#include "llvm/Support/JamCRC.h"
68#include "llvm/Transforms/Utils/BasicBlockUtils.h"
69#include <string>
70#include <utility>
71#include <vector>
72
73using namespace llvm;
74
75#define DEBUG_TYPE "pgo-instrumentation"
76
77STATISTIC(NumOfPGOInstrument, "Number of edges instrumented.");
78STATISTIC(NumOfPGOEdge, "Number of edges.");
79STATISTIC(NumOfPGOBB, "Number of basic-blocks.");
80STATISTIC(NumOfPGOSplit, "Number of critical edge splits.");
81STATISTIC(NumOfPGOFunc, "Number of functions having valid profile counts.");
82STATISTIC(NumOfPGOMismatch, "Number of functions having mismatch profile.");
83STATISTIC(NumOfPGOMissing, "Number of functions without profile.");
84
85// Command line option to specify the file to read profile from. This is
86// mainly used for testing.
87static cl::opt<std::string>
88    PGOTestProfileFile("pgo-test-profile-file", cl::init(""), cl::Hidden,
89                       cl::value_desc("filename"),
90                       cl::desc("Specify the path of profile data file. This is"
91                                "mainly for test purpose."));
92
93namespace {
94class PGOInstrumentationGen : public ModulePass {
95public:
96  static char ID;
97
98  PGOInstrumentationGen() : ModulePass(ID) {
99    initializePGOInstrumentationGenPass(*PassRegistry::getPassRegistry());
100  }
101
102  const char *getPassName() const override {
103    return "PGOInstrumentationGenPass";
104  }
105
106private:
107  bool runOnModule(Module &M) override;
108
109  void getAnalysisUsage(AnalysisUsage &AU) const override {
110    AU.addRequired<BlockFrequencyInfoWrapperPass>();
111  }
112};
113
114class PGOInstrumentationUse : public ModulePass {
115public:
116  static char ID;
117
118  // Provide the profile filename as the parameter.
119  PGOInstrumentationUse(std::string Filename = "")
120      : ModulePass(ID), ProfileFileName(Filename) {
121    if (!PGOTestProfileFile.empty())
122      ProfileFileName = PGOTestProfileFile;
123    initializePGOInstrumentationUsePass(*PassRegistry::getPassRegistry());
124  }
125
126  const char *getPassName() const override {
127    return "PGOInstrumentationUsePass";
128  }
129
130private:
131  std::string ProfileFileName;
132  std::unique_ptr<IndexedInstrProfReader> PGOReader;
133  bool runOnModule(Module &M) override;
134
135  void getAnalysisUsage(AnalysisUsage &AU) const override {
136    AU.addRequired<BlockFrequencyInfoWrapperPass>();
137  }
138};
139} // end anonymous namespace
140
141char PGOInstrumentationGen::ID = 0;
142INITIALIZE_PASS_BEGIN(PGOInstrumentationGen, "pgo-instr-gen",
143                      "PGO instrumentation.", false, false)
144INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
145INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
146INITIALIZE_PASS_END(PGOInstrumentationGen, "pgo-instr-gen",
147                    "PGO instrumentation.", false, false)
148
149ModulePass *llvm::createPGOInstrumentationGenPass() {
150  return new PGOInstrumentationGen();
151}
152
153char PGOInstrumentationUse::ID = 0;
154INITIALIZE_PASS_BEGIN(PGOInstrumentationUse, "pgo-instr-use",
155                      "Read PGO instrumentation profile.", false, false)
156INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
157INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass)
158INITIALIZE_PASS_END(PGOInstrumentationUse, "pgo-instr-use",
159                    "Read PGO instrumentation profile.", false, false)
160
161ModulePass *llvm::createPGOInstrumentationUsePass(StringRef Filename) {
162  return new PGOInstrumentationUse(Filename.str());
163}
164
165namespace {
166/// \brief An MST based instrumentation for PGO
167///
168/// Implements a Minimum Spanning Tree (MST) based instrumentation for PGO
169/// in the function level.
170struct PGOEdge {
171  // This class implements the CFG edges. Note the CFG can be a multi-graph.
172  // So there might be multiple edges with same SrcBB and DestBB.
173  const BasicBlock *SrcBB;
174  const BasicBlock *DestBB;
175  uint64_t Weight;
176  bool InMST;
177  bool Removed;
178  bool IsCritical;
179  PGOEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1)
180      : SrcBB(Src), DestBB(Dest), Weight(W), InMST(false), Removed(false),
181        IsCritical(false) {}
182  // Return the information string of an edge.
183  const std::string infoString() const {
184    return (Twine(Removed ? "-" : " ") + (InMST ? " " : "*") +
185            (IsCritical ? "c" : " ") + "  W=" + Twine(Weight)).str();
186  }
187};
188
189// This class stores the auxiliary information for each BB.
190struct BBInfo {
191  BBInfo *Group;
192  uint32_t Index;
193  uint32_t Rank;
194
195  BBInfo(unsigned IX) : Group(this), Index(IX), Rank(0) {}
196
197  // Return the information string of this object.
198  const std::string infoString() const {
199    return (Twine("Index=") + Twine(Index)).str();
200  }
201};
202
203// This class implements the CFG edges. Note the CFG can be a multi-graph.
204template <class Edge, class BBInfo> class FuncPGOInstrumentation {
205private:
206  Function &F;
207  void computeCFGHash();
208
209public:
210  std::string FuncName;
211  GlobalVariable *FuncNameVar;
212  // CFG hash value for this function.
213  uint64_t FunctionHash;
214
215  // The Minimum Spanning Tree of function CFG.
216  CFGMST<Edge, BBInfo> MST;
217
218  // Give an edge, find the BB that will be instrumented.
219  // Return nullptr if there is no BB to be instrumented.
220  BasicBlock *getInstrBB(Edge *E);
221
222  // Return the auxiliary BB information.
223  BBInfo &getBBInfo(const BasicBlock *BB) const { return MST.getBBInfo(BB); }
224
225  // Dump edges and BB information.
226  void dumpInfo(std::string Str = "") const {
227    MST.dumpEdges(dbgs(), Twine("Dump Function ") + FuncName + " Hash: " +
228                          Twine(FunctionHash) + "\t" + Str);
229  }
230
231  FuncPGOInstrumentation(Function &Func, bool CreateGlobalVar = false,
232                         BranchProbabilityInfo *BPI = nullptr,
233                         BlockFrequencyInfo *BFI = nullptr)
234      : F(Func), FunctionHash(0), MST(F, BPI, BFI) {
235    FuncName = getPGOFuncName(F);
236    computeCFGHash();
237    DEBUG(dumpInfo("after CFGMST"));
238
239    NumOfPGOBB += MST.BBInfos.size();
240    for (auto &E : MST.AllEdges) {
241      if (E->Removed)
242        continue;
243      NumOfPGOEdge++;
244      if (!E->InMST)
245        NumOfPGOInstrument++;
246    }
247
248    if (CreateGlobalVar)
249      FuncNameVar = createPGOFuncNameVar(F, FuncName);
250  };
251};
252
253// Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index
254// value of each BB in the CFG. The higher 32 bits record the number of edges.
255template <class Edge, class BBInfo>
256void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() {
257  std::vector<char> Indexes;
258  JamCRC JC;
259  for (auto &BB : F) {
260    const TerminatorInst *TI = BB.getTerminator();
261    for (unsigned I = 0, E = TI->getNumSuccessors(); I != E; ++I) {
262      BasicBlock *Succ = TI->getSuccessor(I);
263      uint32_t Index = getBBInfo(Succ).Index;
264      for (int J = 0; J < 4; J++)
265        Indexes.push_back((char)(Index >> (J * 8)));
266    }
267  }
268  JC.update(Indexes);
269  FunctionHash = (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC();
270}
271
272// Given a CFG E to be instrumented, find which BB to place the instrumented
273// code. The function will split the critical edge if necessary.
274template <class Edge, class BBInfo>
275BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) {
276  if (E->InMST || E->Removed)
277    return nullptr;
278
279  BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB);
280  BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB);
281  // For a fake edge, instrument the real BB.
282  if (SrcBB == nullptr)
283    return DestBB;
284  if (DestBB == nullptr)
285    return SrcBB;
286
287  // Instrument the SrcBB if it has a single successor,
288  // otherwise, the DestBB if this is not a critical edge.
289  TerminatorInst *TI = SrcBB->getTerminator();
290  if (TI->getNumSuccessors() <= 1)
291    return SrcBB;
292  if (!E->IsCritical)
293    return DestBB;
294
295  // For a critical edge, we have to split. Instrument the newly
296  // created BB.
297  NumOfPGOSplit++;
298  DEBUG(dbgs() << "Split critical edge: " << getBBInfo(SrcBB).Index << " --> "
299               << getBBInfo(DestBB).Index << "\n");
300  unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB);
301  BasicBlock *InstrBB = SplitCriticalEdge(TI, SuccNum);
302  assert(InstrBB && "Critical edge is not split");
303
304  E->Removed = true;
305  return InstrBB;
306}
307
308// Visit all edge and instrument the edges not in MST.
309// Critical edges will be split.
310static void instrumentOneFunc(Function &F, Module *M,
311                              BranchProbabilityInfo *BPI,
312                              BlockFrequencyInfo *BFI) {
313  unsigned NumCounters = 0;
314  FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, true, BPI, BFI);
315  for (auto &E : FuncInfo.MST.AllEdges) {
316    if (!E->InMST && !E->Removed)
317      NumCounters++;
318  }
319
320  uint32_t I = 0;
321  for (auto &E : FuncInfo.MST.AllEdges) {
322    BasicBlock *InstrBB = FuncInfo.getInstrBB(E.get());
323    if (!InstrBB)
324      continue;
325
326    IRBuilder<> Builder(InstrBB, InstrBB->getFirstInsertionPt());
327    assert(Builder.GetInsertPoint() != InstrBB->end() &&
328           "Cannot get the Instrumentation point");
329    Type *I8PtrTy = Type::getInt8PtrTy(M->getContext());
330    Builder.CreateCall(
331        Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment),
332        {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy),
333         Builder.getInt64(FuncInfo.FunctionHash), Builder.getInt32(NumCounters),
334         Builder.getInt32(I++)});
335  }
336}
337
338// This class represents a CFG edge in profile use compilation.
339struct PGOUseEdge : public PGOEdge {
340  bool CountValid;
341  uint64_t CountValue;
342  PGOUseEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1)
343      : PGOEdge(Src, Dest, W), CountValid(false), CountValue(0) {}
344
345  // Set edge count value
346  void setEdgeCount(uint64_t Value) {
347    CountValue = Value;
348    CountValid = true;
349  }
350
351  // Return the information string for this object.
352  const std::string infoString() const {
353    if (!CountValid)
354      return PGOEdge::infoString();
355    return (Twine(PGOEdge::infoString()) + "  Count=" + Twine(CountValue)).str();
356  }
357};
358
359typedef SmallVector<PGOUseEdge *, 2> DirectEdges;
360
361// This class stores the auxiliary information for each BB.
362struct UseBBInfo : public BBInfo {
363  uint64_t CountValue;
364  bool CountValid;
365  int32_t UnknownCountInEdge;
366  int32_t UnknownCountOutEdge;
367  DirectEdges InEdges;
368  DirectEdges OutEdges;
369  UseBBInfo(unsigned IX)
370      : BBInfo(IX), CountValue(0), CountValid(false), UnknownCountInEdge(0),
371        UnknownCountOutEdge(0) {}
372  UseBBInfo(unsigned IX, uint64_t C)
373      : BBInfo(IX), CountValue(C), CountValid(true), UnknownCountInEdge(0),
374        UnknownCountOutEdge(0) {}
375
376  // Set the profile count value for this BB.
377  void setBBInfoCount(uint64_t Value) {
378    CountValue = Value;
379    CountValid = true;
380  }
381
382  // Return the information string of this object.
383  const std::string infoString() const {
384    if (!CountValid)
385      return BBInfo::infoString();
386    return (Twine(BBInfo::infoString()) + "  Count=" + Twine(CountValue)).str();
387  }
388};
389
390// Sum up the count values for all the edges.
391static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) {
392  uint64_t Total = 0;
393  for (auto &E : Edges) {
394    if (E->Removed)
395      continue;
396    Total += E->CountValue;
397  }
398  return Total;
399}
400
401class PGOUseFunc {
402private:
403  Function &F;
404  Module *M;
405  // This member stores the shared information with class PGOGenFunc.
406  FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo;
407
408  // Return the auxiliary BB information.
409  UseBBInfo &getBBInfo(const BasicBlock *BB) const {
410    return FuncInfo.getBBInfo(BB);
411  }
412
413  // The maximum count value in the profile. This is only used in PGO use
414  // compilation.
415  uint64_t ProgramMaxCount;
416
417  // Find the Instrumented BB and set the value.
418  void setInstrumentedCounts(const std::vector<uint64_t> &CountFromProfile);
419
420  // Set the edge counter value for the unknown edge -- there should be only
421  // one unknown edge.
422  void setEdgeCount(DirectEdges &Edges, uint64_t Value);
423
424  // Return FuncName string;
425  const std::string getFuncName() const { return FuncInfo.FuncName; }
426
427  // Set the hot/cold inline hints based on the count values.
428  // FIXME: This function should be removed once the functionality in
429  // the inliner is implemented.
430  void applyFunctionAttributes(uint64_t EntryCount, uint64_t MaxCount) {
431    if (ProgramMaxCount == 0)
432      return;
433    // Threshold of the hot functions.
434    const BranchProbability HotFunctionThreshold(1, 100);
435    // Threshold of the cold functions.
436    const BranchProbability ColdFunctionThreshold(2, 10000);
437    if (EntryCount >= HotFunctionThreshold.scale(ProgramMaxCount))
438      F.addFnAttr(llvm::Attribute::InlineHint);
439    else if (MaxCount <= ColdFunctionThreshold.scale(ProgramMaxCount))
440      F.addFnAttr(llvm::Attribute::Cold);
441  }
442
443public:
444  PGOUseFunc(Function &Func, Module *Modu, BranchProbabilityInfo *BPI = nullptr,
445             BlockFrequencyInfo *BFI = nullptr)
446      : F(Func), M(Modu), FuncInfo(Func, false, BPI, BFI) {}
447
448  // Read counts for the instrumented BB from profile.
449  bool readCounters(IndexedInstrProfReader *PGOReader);
450
451  // Populate the counts for all BBs.
452  void populateCounters();
453
454  // Set the branch weights based on the count values.
455  void setBranchWeights();
456};
457
458// Visit all the edges and assign the count value for the instrumented
459// edges and the BB.
460void PGOUseFunc::setInstrumentedCounts(
461    const std::vector<uint64_t> &CountFromProfile) {
462
463  // Use a worklist as we will update the vector during the iteration.
464  std::vector<PGOUseEdge *> WorkList;
465  for (auto &E : FuncInfo.MST.AllEdges)
466    WorkList.push_back(E.get());
467
468  uint32_t I = 0;
469  for (auto &E : WorkList) {
470    BasicBlock *InstrBB = FuncInfo.getInstrBB(E);
471    if (!InstrBB)
472      continue;
473    uint64_t CountValue = CountFromProfile[I++];
474    if (!E->Removed) {
475      getBBInfo(InstrBB).setBBInfoCount(CountValue);
476      E->setEdgeCount(CountValue);
477      continue;
478    }
479
480    // Need to add two new edges.
481    BasicBlock *SrcBB = const_cast<BasicBlock *>(E->SrcBB);
482    BasicBlock *DestBB = const_cast<BasicBlock *>(E->DestBB);
483    // Add new edge of SrcBB->InstrBB.
484    PGOUseEdge &NewEdge = FuncInfo.MST.addEdge(SrcBB, InstrBB, 0);
485    NewEdge.setEdgeCount(CountValue);
486    // Add new edge of InstrBB->DestBB.
487    PGOUseEdge &NewEdge1 = FuncInfo.MST.addEdge(InstrBB, DestBB, 0);
488    NewEdge1.setEdgeCount(CountValue);
489    NewEdge1.InMST = true;
490    getBBInfo(InstrBB).setBBInfoCount(CountValue);
491  }
492}
493
494// Set the count value for the unknown edge. There should be one and only one
495// unknown edge in Edges vector.
496void PGOUseFunc::setEdgeCount(DirectEdges &Edges, uint64_t Value) {
497  for (auto &E : Edges) {
498    if (E->CountValid)
499      continue;
500    E->setEdgeCount(Value);
501
502    getBBInfo(E->SrcBB).UnknownCountOutEdge--;
503    getBBInfo(E->DestBB).UnknownCountInEdge--;
504    return;
505  }
506  llvm_unreachable("Cannot find the unknown count edge");
507}
508
509// Read the profile from ProfileFileName and assign the value to the
510// instrumented BB and the edges. This function also updates ProgramMaxCount.
511// Return true if the profile are successfully read, and false on errors.
512bool PGOUseFunc::readCounters(IndexedInstrProfReader *PGOReader) {
513  auto &Ctx = M->getContext();
514  ErrorOr<InstrProfRecord> Result =
515      PGOReader->getInstrProfRecord(FuncInfo.FuncName, FuncInfo.FunctionHash);
516  if (std::error_code EC = Result.getError()) {
517    if (EC == instrprof_error::unknown_function)
518      NumOfPGOMissing++;
519    else if (EC == instrprof_error::hash_mismatch ||
520             EC == llvm::instrprof_error::malformed)
521      NumOfPGOMismatch++;
522
523    std::string Msg = EC.message() + std::string(" ") + F.getName().str();
524    Ctx.diagnose(
525        DiagnosticInfoPGOProfile(M->getName().data(), Msg, DS_Warning));
526    return false;
527  }
528  std::vector<uint64_t> &CountFromProfile = Result.get().Counts;
529
530  NumOfPGOFunc++;
531  DEBUG(dbgs() << CountFromProfile.size() << " counts\n");
532  uint64_t ValueSum = 0;
533  for (unsigned I = 0, S = CountFromProfile.size(); I < S; I++) {
534    DEBUG(dbgs() << "  " << I << ": " << CountFromProfile[I] << "\n");
535    ValueSum += CountFromProfile[I];
536  }
537
538  DEBUG(dbgs() << "SUM =  " << ValueSum << "\n");
539
540  getBBInfo(nullptr).UnknownCountOutEdge = 2;
541  getBBInfo(nullptr).UnknownCountInEdge = 2;
542
543  setInstrumentedCounts(CountFromProfile);
544  ProgramMaxCount = PGOReader->getMaximumFunctionCount();
545  return true;
546}
547
548// Populate the counters from instrumented BBs to all BBs.
549// In the end of this operation, all BBs should have a valid count value.
550void PGOUseFunc::populateCounters() {
551  // First set up Count variable for all BBs.
552  for (auto &E : FuncInfo.MST.AllEdges) {
553    if (E->Removed)
554      continue;
555
556    const BasicBlock *SrcBB = E->SrcBB;
557    const BasicBlock *DestBB = E->DestBB;
558    UseBBInfo &SrcInfo = getBBInfo(SrcBB);
559    UseBBInfo &DestInfo = getBBInfo(DestBB);
560    SrcInfo.OutEdges.push_back(E.get());
561    DestInfo.InEdges.push_back(E.get());
562    SrcInfo.UnknownCountOutEdge++;
563    DestInfo.UnknownCountInEdge++;
564
565    if (!E->CountValid)
566      continue;
567    DestInfo.UnknownCountInEdge--;
568    SrcInfo.UnknownCountOutEdge--;
569  }
570
571  bool Changes = true;
572  unsigned NumPasses = 0;
573  while (Changes) {
574    NumPasses++;
575    Changes = false;
576
577    // For efficient traversal, it's better to start from the end as most
578    // of the instrumented edges are at the end.
579    for (auto &BB : reverse(F)) {
580      UseBBInfo &Count = getBBInfo(&BB);
581      if (!Count.CountValid) {
582        if (Count.UnknownCountOutEdge == 0) {
583          Count.CountValue = sumEdgeCount(Count.OutEdges);
584          Count.CountValid = true;
585          Changes = true;
586        } else if (Count.UnknownCountInEdge == 0) {
587          Count.CountValue = sumEdgeCount(Count.InEdges);
588          Count.CountValid = true;
589          Changes = true;
590        }
591      }
592      if (Count.CountValid) {
593        if (Count.UnknownCountOutEdge == 1) {
594          uint64_t Total = Count.CountValue - sumEdgeCount(Count.OutEdges);
595          setEdgeCount(Count.OutEdges, Total);
596          Changes = true;
597        }
598        if (Count.UnknownCountInEdge == 1) {
599          uint64_t Total = Count.CountValue - sumEdgeCount(Count.InEdges);
600          setEdgeCount(Count.InEdges, Total);
601          Changes = true;
602        }
603      }
604    }
605  }
606
607  DEBUG(dbgs() << "Populate counts in " << NumPasses << " passes.\n");
608  // Assert every BB has a valid counter.
609  uint64_t FuncEntryCount = getBBInfo(&*F.begin()).CountValue;
610  uint64_t FuncMaxCount = FuncEntryCount;
611  for (auto &BB : F) {
612    assert(getBBInfo(&BB).CountValid && "BB count is not valid");
613    uint64_t Count = getBBInfo(&BB).CountValue;
614    if (Count > FuncMaxCount)
615      FuncMaxCount = Count;
616  }
617  applyFunctionAttributes(FuncEntryCount, FuncMaxCount);
618
619  DEBUG(FuncInfo.dumpInfo("after reading profile."));
620}
621
622// Assign the scaled count values to the BB with multiple out edges.
623void PGOUseFunc::setBranchWeights() {
624  // Generate MD_prof metadata for every branch instruction.
625  DEBUG(dbgs() << "\nSetting branch weights.\n");
626  MDBuilder MDB(M->getContext());
627  for (auto &BB : F) {
628    TerminatorInst *TI = BB.getTerminator();
629    if (TI->getNumSuccessors() < 2)
630      continue;
631    if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI))
632      continue;
633    if (getBBInfo(&BB).CountValue == 0)
634      continue;
635
636    // We have a non-zero Branch BB.
637    const UseBBInfo &BBCountInfo = getBBInfo(&BB);
638    unsigned Size = BBCountInfo.OutEdges.size();
639    SmallVector<unsigned, 2> EdgeCounts(Size, 0);
640    uint64_t MaxCount = 0;
641    for (unsigned s = 0; s < Size; s++) {
642      const PGOUseEdge *E = BBCountInfo.OutEdges[s];
643      const BasicBlock *SrcBB = E->SrcBB;
644      const BasicBlock *DestBB = E->DestBB;
645      if (DestBB == 0)
646        continue;
647      unsigned SuccNum = GetSuccessorNumber(SrcBB, DestBB);
648      uint64_t EdgeCount = E->CountValue;
649      if (EdgeCount > MaxCount)
650        MaxCount = EdgeCount;
651      EdgeCounts[SuccNum] = EdgeCount;
652    }
653    assert(MaxCount > 0 && "Bad max count");
654    uint64_t Scale = calculateCountScale(MaxCount);
655    SmallVector<unsigned, 4> Weights;
656    for (const auto &ECI : EdgeCounts)
657      Weights.push_back(scaleBranchCount(ECI, Scale));
658
659    TI->setMetadata(llvm::LLVMContext::MD_prof,
660                    MDB.createBranchWeights(Weights));
661    DEBUG(dbgs() << "Weight is: ";
662          for (const auto &W : Weights) { dbgs() << W << " "; }
663          dbgs() << "\n";);
664  }
665}
666} // end anonymous namespace
667
668bool PGOInstrumentationGen::runOnModule(Module &M) {
669  for (auto &F : M) {
670    if (F.isDeclaration())
671      continue;
672    BranchProbabilityInfo *BPI =
673        &(getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI());
674    BlockFrequencyInfo *BFI =
675        &(getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI());
676    instrumentOneFunc(F, &M, BPI, BFI);
677  }
678  return true;
679}
680
681static void setPGOCountOnFunc(PGOUseFunc &Func,
682                              IndexedInstrProfReader *PGOReader) {
683  if (Func.readCounters(PGOReader)) {
684    Func.populateCounters();
685    Func.setBranchWeights();
686  }
687}
688
689bool PGOInstrumentationUse::runOnModule(Module &M) {
690  DEBUG(dbgs() << "Read in profile counters: ");
691  auto &Ctx = M.getContext();
692  // Read the counter array from file.
693  auto ReaderOrErr = IndexedInstrProfReader::create(ProfileFileName);
694  if (std::error_code EC = ReaderOrErr.getError()) {
695    Ctx.diagnose(
696        DiagnosticInfoPGOProfile(ProfileFileName.data(), EC.message()));
697    return false;
698  }
699
700  PGOReader = std::move(ReaderOrErr.get());
701  if (!PGOReader) {
702    Ctx.diagnose(DiagnosticInfoPGOProfile(ProfileFileName.data(),
703                                          "Cannot get PGOReader"));
704    return false;
705  }
706
707  for (auto &F : M) {
708    if (F.isDeclaration())
709      continue;
710    BranchProbabilityInfo *BPI =
711        &(getAnalysis<BranchProbabilityInfoWrapperPass>(F).getBPI());
712    BlockFrequencyInfo *BFI =
713        &(getAnalysis<BlockFrequencyInfoWrapperPass>(F).getBFI());
714    PGOUseFunc Func(F, &M, BPI, BFI);
715    setPGOCountOnFunc(Func, PGOReader.get());
716  }
717  return true;
718}
719