1//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Instrumentation-based profile-guided optimization
10//
11//===----------------------------------------------------------------------===//
12
13#include "CodeGenPGO.h"
14#include "CodeGenFunction.h"
15#include "CoverageMappingGen.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/AST/StmtVisitor.h"
18#include "llvm/IR/Intrinsics.h"
19#include "llvm/IR/MDBuilder.h"
20#include "llvm/Support/CommandLine.h"
21#include "llvm/Support/Endian.h"
22#include "llvm/Support/FileSystem.h"
23#include "llvm/Support/MD5.h"
24
// Hidden command-line knob gating value-profiling instrumentation in addition
// to the region counters; off by default.
static llvm::cl::opt<bool>
    EnableValueProfiling("enable-value-profiling", llvm::cl::ZeroOrMore,
                         llvm::cl::desc("Enable value profiling"),
                         llvm::cl::Hidden, llvm::cl::init(false));
29
30using namespace clang;
31using namespace CodeGen;
32
33void CodeGenPGO::setFuncName(StringRef Name,
34                             llvm::GlobalValue::LinkageTypes Linkage) {
35  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
36  FuncName = llvm::getPGOFuncName(
37      Name, Linkage, CGM.getCodeGenOpts().MainFileName,
38      PGOReader ? PGOReader->getVersion() : llvm::IndexedInstrProf::Version);
39
40  // If we're generating a profile, create a variable for the name.
41  if (CGM.getCodeGenOpts().hasProfileClangInstr())
42    FuncNameVar = llvm::createPGOFuncNameVar(CGM.getModule(), Linkage, FuncName);
43}
44
45void CodeGenPGO::setFuncName(llvm::Function *Fn) {
46  setFuncName(Fn->getName(), Fn->getLinkage());
47  // Create PGOFuncName meta data.
48  llvm::createPGOFuncNameMetadata(*Fn, FuncName);
49}
50
/// The version of the PGO hash algorithm.
///
/// The hash version is tied to the indexed profile format: when reading an
/// old profile we must keep producing the matching old hash (see
/// getPGOHashVersion below), so these enumerators must never be renumbered.
enum PGOHashVersion : unsigned {
  PGO_HASH_V1,
  PGO_HASH_V2,
  PGO_HASH_V3,

  // Keep this set to the latest hash version.
  PGO_HASH_LATEST = PGO_HASH_V3
};
60
61namespace {
62/// Stable hasher for PGO region counters.
63///
64/// PGOHash produces a stable hash of a given function's control flow.
65///
66/// Changing the output of this hash will invalidate all previously generated
67/// profiles -- i.e., don't do it.
68///
69/// \note  When this hash does eventually change (years?), we still need to
70/// support old hashes.  We'll need to pull in the version number from the
71/// profile data format and use the matching hash function.
class PGOHash {
  // Pending hash types, packed NumBitsPerType bits at a time; digested
  // through MD5 once a full word has accumulated (see combine()).
  uint64_t Working;
  // Number of types combined so far. Used to decide when to flush Working
  // and, in finalize(), whether MD5 was ever needed at all.
  unsigned Count;
  // Which version of the hash algorithm to produce.
  PGOHashVersion HashVersion;
  llvm::MD5 MD5;

  static const int NumBitsPerType = 6;
  static const unsigned NumTypesPerWord = sizeof(uint64_t) * 8 / NumBitsPerType;
  static const unsigned TooBig = 1u << NumBitsPerType;

public:
  /// Hash values for AST nodes.
  ///
  /// Distinct values for AST nodes that have region counters attached.
  ///
  /// These values must be stable.  All new members must be added at the end,
  /// and no members should be removed.  Changing the enumeration value for an
  /// AST node will affect the hash of every function that contains that node.
  enum HashType : unsigned char {
    None = 0,
    LabelStmt = 1,
    WhileStmt,
    DoStmt,
    ForStmt,
    CXXForRangeStmt,
    ObjCForCollectionStmt,
    SwitchStmt,
    CaseStmt,
    DefaultStmt,
    IfStmt,
    CXXTryStmt,
    CXXCatchStmt,
    ConditionalOperator,
    BinaryOperatorLAnd,
    BinaryOperatorLOr,
    BinaryConditionalOperator,
    // The preceding values are available with PGO_HASH_V1.

    EndOfScope,
    IfThenBranch,
    IfElseBranch,
    GotoStmt,
    IndirectGotoStmt,
    BreakStmt,
    ContinueStmt,
    ReturnStmt,
    ThrowExpr,
    UnaryOperatorLNot,
    BinaryOperatorLT,
    BinaryOperatorGT,
    BinaryOperatorLE,
    BinaryOperatorGE,
    BinaryOperatorEQ,
    BinaryOperatorNE,
    // The preceding values are available since PGO_HASH_V2.

    // Keep this last.  It's for the static assert that follows.
    LastHashType
  };
  static_assert(LastHashType <= TooBig, "Too many types in HashType");

  PGOHash(PGOHashVersion HashVersion)
      : Working(0), Count(0), HashVersion(HashVersion), MD5() {}
  /// Mix \p Type into the hash. \p Type must be non-zero and fit in
  /// NumBitsPerType bits.
  void combine(HashType Type);
  /// Finish hashing and return the final hash value.
  uint64_t finalize();
  PGOHashVersion getHashVersion() const { return HashVersion; }
};
// Out-of-line definitions for the ODR-used static data members (required
// before C++17 made constexpr static members implicitly inline).
const int PGOHash::NumBitsPerType;
const unsigned PGOHash::NumTypesPerWord;
const unsigned PGOHash::TooBig;
142
143/// Get the PGO hash version used in the given indexed profile.
144static PGOHashVersion getPGOHashVersion(llvm::IndexedInstrProfReader *PGOReader,
145                                        CodeGenModule &CGM) {
146  if (PGOReader->getVersion() <= 4)
147    return PGO_HASH_V1;
148  if (PGOReader->getVersion() <= 5)
149    return PGO_HASH_V2;
150  return PGO_HASH_V3;
151}
152
/// A RecursiveASTVisitor that fills a map of statements to PGO counters.
struct MapRegionCounters : public RecursiveASTVisitor<MapRegionCounters> {
  using Base = RecursiveASTVisitor<MapRegionCounters>;

  /// The next counter value to assign.
  unsigned NextCounter;
  /// The function hash.
  PGOHash Hash;
  /// The map of statements to counters.
  llvm::DenseMap<const Stmt *, unsigned> &CounterMap;
  /// The profile version.
  uint64_t ProfileVersion;

  MapRegionCounters(PGOHashVersion HashVersion, uint64_t ProfileVersion,
                    llvm::DenseMap<const Stmt *, unsigned> &CounterMap)
      : NextCounter(0), Hash(HashVersion), CounterMap(CounterMap),
        ProfileVersion(ProfileVersion) {}

  // Blocks and lambdas are handled as separate functions, so we need not
  // traverse them in the parent context.
  bool TraverseBlockExpr(BlockExpr *BE) { return true; }
  bool TraverseLambdaExpr(LambdaExpr *LE) {
    // Traverse the captures, but not the body.
    for (auto C : zip(LE->captures(), LE->capture_inits()))
      TraverseLambdaCapture(LE, &std::get<0>(C), std::get<1>(C));
    return true;
  }
  bool TraverseCapturedStmt(CapturedStmt *CS) { return true; }

  /// Assign the first counter to the body of each decl kind that carries a
  /// function body (the entry counter for that function-like decl).
  bool VisitDecl(const Decl *D) {
    switch (D->getKind()) {
    default:
      break;
    case Decl::Function:
    case Decl::CXXMethod:
    case Decl::CXXConstructor:
    case Decl::CXXDestructor:
    case Decl::CXXConversion:
    case Decl::ObjCMethod:
    case Decl::Block:
    case Decl::Captured:
      CounterMap[D->getBody()] = NextCounter++;
      break;
    }
    return true;
  }

  /// If \p S gets a fresh counter, update the counter mappings. Return the
  /// V1 hash of \p S.
  PGOHash::HashType updateCounterMappings(Stmt *S) {
    // Counters are assigned based on the V1 hash type, independent of the
    // hash version actually in use.
    auto Type = getHashType(PGO_HASH_V1, S);
    if (Type != PGOHash::None)
      CounterMap[S] = NextCounter++;
    return Type;
  }

  /// The RHS of all logical operators gets a fresh counter in order to count
  /// how many times the RHS evaluates to true or false, depending on the
  /// semantics of the operator. This is only valid for ">= v7" of the profile
  /// version so that we facilitate backward compatibility.
  bool VisitBinaryOperator(BinaryOperator *S) {
    if (ProfileVersion >= llvm::IndexedInstrProf::Version7)
      if (S->isLogicalOp() &&
          CodeGenFunction::isInstrumentedCondition(S->getRHS()))
        CounterMap[S->getRHS()] = NextCounter++;
    return Base::VisitBinaryOperator(S);
  }

  /// Include \p S in the function hash.
  bool VisitStmt(Stmt *S) {
    auto Type = updateCounterMappings(S);
    // For V2+ hashes, re-derive the type so the additional statement kinds
    // introduced by later hash versions also contribute to the hash.
    if (Hash.getHashVersion() != PGO_HASH_V1)
      Type = getHashType(Hash.getHashVersion(), S);
    if (Type != PGOHash::None)
      Hash.combine(Type);
    return true;
  }

  bool TraverseIfStmt(IfStmt *If) {
    // If we used the V1 hash, use the default traversal.
    if (Hash.getHashVersion() == PGO_HASH_V1)
      return Base::TraverseIfStmt(If);

    // Otherwise, keep track of which branch we're in while traversing.
    VisitStmt(If);
    for (Stmt *CS : If->children()) {
      if (!CS)
        continue;
      if (CS == If->getThen())
        Hash.combine(PGOHash::IfThenBranch);
      else if (CS == If->getElse())
        Hash.combine(PGOHash::IfElseBranch);
      TraverseStmt(CS);
    }
    Hash.combine(PGOHash::EndOfScope);
    return true;
  }

// If the statement type \p N is nestable, and its nesting impacts profile
// stability, define a custom traversal which tracks the end of the statement
// in the hash (provided we're not using the V1 hash).
#define DEFINE_NESTABLE_TRAVERSAL(N)                                           \
  bool Traverse##N(N *S) {                                                     \
    Base::Traverse##N(S);                                                      \
    if (Hash.getHashVersion() != PGO_HASH_V1)                                  \
      Hash.combine(PGOHash::EndOfScope);                                       \
    return true;                                                               \
  }

  DEFINE_NESTABLE_TRAVERSAL(WhileStmt)
  DEFINE_NESTABLE_TRAVERSAL(DoStmt)
  DEFINE_NESTABLE_TRAVERSAL(ForStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXForRangeStmt)
  DEFINE_NESTABLE_TRAVERSAL(ObjCForCollectionStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXTryStmt)
  DEFINE_NESTABLE_TRAVERSAL(CXXCatchStmt)

  /// Get version \p HashVersion of the PGO hash for \p S.
  ///
  /// Returns PGOHash::None for statements that do not participate in the
  /// requested hash version.
  PGOHash::HashType getHashType(PGOHashVersion HashVersion, const Stmt *S) {
    switch (S->getStmtClass()) {
    default:
      break;
    case Stmt::LabelStmtClass:
      return PGOHash::LabelStmt;
    case Stmt::WhileStmtClass:
      return PGOHash::WhileStmt;
    case Stmt::DoStmtClass:
      return PGOHash::DoStmt;
    case Stmt::ForStmtClass:
      return PGOHash::ForStmt;
    case Stmt::CXXForRangeStmtClass:
      return PGOHash::CXXForRangeStmt;
    case Stmt::ObjCForCollectionStmtClass:
      return PGOHash::ObjCForCollectionStmt;
    case Stmt::SwitchStmtClass:
      return PGOHash::SwitchStmt;
    case Stmt::CaseStmtClass:
      return PGOHash::CaseStmt;
    case Stmt::DefaultStmtClass:
      return PGOHash::DefaultStmt;
    case Stmt::IfStmtClass:
      return PGOHash::IfStmt;
    case Stmt::CXXTryStmtClass:
      return PGOHash::CXXTryStmt;
    case Stmt::CXXCatchStmtClass:
      return PGOHash::CXXCatchStmt;
    case Stmt::ConditionalOperatorClass:
      return PGOHash::ConditionalOperator;
    case Stmt::BinaryConditionalOperatorClass:
      return PGOHash::BinaryConditionalOperator;
    case Stmt::BinaryOperatorClass: {
      const BinaryOperator *BO = cast<BinaryOperator>(S);
      if (BO->getOpcode() == BO_LAnd)
        return PGOHash::BinaryOperatorLAnd;
      if (BO->getOpcode() == BO_LOr)
        return PGOHash::BinaryOperatorLOr;
      // Comparison operators only contribute to the V2+ hash.
      if (HashVersion >= PGO_HASH_V2) {
        switch (BO->getOpcode()) {
        default:
          break;
        case BO_LT:
          return PGOHash::BinaryOperatorLT;
        case BO_GT:
          return PGOHash::BinaryOperatorGT;
        case BO_LE:
          return PGOHash::BinaryOperatorLE;
        case BO_GE:
          return PGOHash::BinaryOperatorGE;
        case BO_EQ:
          return PGOHash::BinaryOperatorEQ;
        case BO_NE:
          return PGOHash::BinaryOperatorNE;
        }
      }
      break;
    }
    }

    // Statement kinds that only contribute to the V2+ hash.
    if (HashVersion >= PGO_HASH_V2) {
      switch (S->getStmtClass()) {
      default:
        break;
      case Stmt::GotoStmtClass:
        return PGOHash::GotoStmt;
      case Stmt::IndirectGotoStmtClass:
        return PGOHash::IndirectGotoStmt;
      case Stmt::BreakStmtClass:
        return PGOHash::BreakStmt;
      case Stmt::ContinueStmtClass:
        return PGOHash::ContinueStmt;
      case Stmt::ReturnStmtClass:
        return PGOHash::ReturnStmt;
      case Stmt::CXXThrowExprClass:
        return PGOHash::ThrowExpr;
      case Stmt::UnaryOperatorClass: {
        const UnaryOperator *UO = cast<UnaryOperator>(S);
        if (UO->getOpcode() == UO_LNot)
          return PGOHash::UnaryOperatorLNot;
        break;
      }
      }
    }

    return PGOHash::None;
  }
};
359
/// A StmtVisitor that propagates the raw counts through the AST and
/// records the count at statements where the value may change.
struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
  /// PGO state.
  CodeGenPGO &PGO;

  /// A flag that is set when the current count should be recorded on the
  /// next statement, such as at the exit of a loop.
  bool RecordNextStmtCount;

  /// The count at the current location in the traversal.
  uint64_t CurrentCount;

  /// The map of statements to count values.
  llvm::DenseMap<const Stmt *, uint64_t> &CountMap;

  /// BreakContinueStack - Keep counts of breaks and continues inside loops.
  struct BreakContinue {
    uint64_t BreakCount;
    uint64_t ContinueCount;
    BreakContinue() : BreakCount(0), ContinueCount(0) {}
  };
  SmallVector<BreakContinue, 8> BreakContinueStack;

  ComputeRegionCounts(llvm::DenseMap<const Stmt *, uint64_t> &CountMap,
                      CodeGenPGO &PGO)
      : PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {}

  /// Record the current count on \p S if a previously-visited statement
  /// (e.g. a return or break) requested that the next statement's count be
  /// captured.
  void RecordStmtCount(const Stmt *S) {
    if (RecordNextStmtCount) {
      CountMap[S] = CurrentCount;
      RecordNextStmtCount = false;
    }
  }

  /// Set and return the current count.
  uint64_t setCount(uint64_t Count) {
    CurrentCount = Count;
    return Count;
  }

  /// Default visitor: propagate the current count through all children in
  /// source order.
  void VisitStmt(const Stmt *S) {
    RecordStmtCount(S);
    for (const Stmt *Child : S->children())
      if (Child)
        this->Visit(Child);
  }

  void VisitFunctionDecl(const FunctionDecl *D) {
    // Counter tracks entry to the function body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  // Skip lambda expressions. We visit these as FunctionDecls when we're
  // generating them and aren't interested in the body when generating a
  // parent context.
  void VisitLambdaExpr(const LambdaExpr *LE) {}

  void VisitCapturedDecl(const CapturedDecl *D) {
    // Counter tracks entry to the capture body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitObjCMethodDecl(const ObjCMethodDecl *D) {
    // Counter tracks entry to the method body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitBlockDecl(const BlockDecl *D) {
    // Counter tracks entry to the block body.
    uint64_t BodyCount = setCount(PGO.getRegionCount(D->getBody()));
    CountMap[D->getBody()] = BodyCount;
    Visit(D->getBody());
  }

  void VisitReturnStmt(const ReturnStmt *S) {
    RecordStmtCount(S);
    if (S->getRetValue())
      Visit(S->getRetValue());
    // Control does not fall through a return; anything after it starts at
    // zero until a counter re-establishes the count.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitCXXThrowExpr(const CXXThrowExpr *E) {
    RecordStmtCount(E);
    if (E->getSubExpr())
      Visit(E->getSubExpr());
    // A throw never falls through.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitGotoStmt(const GotoStmt *S) {
    RecordStmtCount(S);
    // A goto never falls through.
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitLabelStmt(const LabelStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the block following the label.
    uint64_t BlockCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = BlockCount;
    Visit(S->getSubStmt());
  }

  void VisitBreakStmt(const BreakStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
    // Accumulate the count leaving through this break into the enclosing
    // loop/switch record; control does not fall through.
    BreakContinueStack.back().BreakCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitContinueStmt(const ContinueStmt *S) {
    RecordStmtCount(S);
    assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
    // Accumulate the count leaving through this continue into the enclosing
    // loop record; control does not fall through.
    BreakContinueStack.back().ContinueCount += CurrentCount;
    CurrentCount = 0;
    RecordNextStmtCount = true;
  }

  void VisitWhileStmt(const WhileStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first so the break/continue adjustments can be
    // included when visiting the condition.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = CurrentCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    // ...then go back and propagate counts through the condition. The count
    // at the start of the condition is the sum of the incoming edges,
    // the backedge from the end of the loop body, and the edges from
    // continue statements.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    // The count after the loop: condition-false exits plus breaks.
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitDoStmt(const DoStmt *S) {
    RecordStmtCount(S);
    uint64_t LoopCount = PGO.getRegionCount(S);

    BreakContinueStack.push_back(BreakContinue());
    // The count doesn't include the fallthrough from the parent scope. Add it.
    uint64_t BodyCount = setCount(LoopCount + CurrentCount);
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;

    BreakContinue BC = BreakContinueStack.pop_back_val();
    // The count at the start of the condition is equal to the count at the
    // end of the body, plus any continues.
    uint64_t CondCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - LoopCount);
    RecordNextStmtCount = true;
  }

  void VisitForStmt(const ForStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());

    uint64_t ParentCount = CurrentCount;

    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    if (S->getInc()) {
      uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
      CountMap[S->getInc()] = IncCount;
      Visit(S->getInc());
    }

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    if (S->getCond()) {
      CountMap[S->getCond()] = CondCount;
      Visit(S->getCond());
    }
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getLoopVarStmt());
    Visit(S->getRangeStmt());
    Visit(S->getBeginStmt());
    Visit(S->getEndStmt());

    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Visit the body region first. (This is basically the same as a while
    // loop; see further comments in VisitWhileStmt.)
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // The increment is essentially part of the body but it needs to include
    // the count for all the continue statements.
    uint64_t IncCount = setCount(BackedgeCount + BC.ContinueCount);
    CountMap[S->getInc()] = IncCount;
    Visit(S->getInc());

    // ...then go back and propagate counts through the condition.
    uint64_t CondCount =
        setCount(ParentCount + BackedgeCount + BC.ContinueCount);
    CountMap[S->getCond()] = CondCount;
    Visit(S->getCond());
    setCount(BC.BreakCount + CondCount - BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
    RecordStmtCount(S);
    Visit(S->getElement());
    uint64_t ParentCount = CurrentCount;
    BreakContinueStack.push_back(BreakContinue());
    // Counter tracks the body of the loop.
    uint64_t BodyCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getBody()] = BodyCount;
    Visit(S->getBody());
    uint64_t BackedgeCount = CurrentCount;
    BreakContinue BC = BreakContinueStack.pop_back_val();

    // There is no explicit condition expression; compute the exit count
    // directly from the incoming, backedge, break and continue counts.
    setCount(BC.BreakCount + ParentCount + BackedgeCount + BC.ContinueCount -
             BodyCount);
    RecordNextStmtCount = true;
  }

  void VisitSwitchStmt(const SwitchStmt *S) {
    RecordStmtCount(S);
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());
    // No edge falls through directly into the body; each case re-establishes
    // its own count.
    CurrentCount = 0;
    BreakContinueStack.push_back(BreakContinue());
    Visit(S->getBody());
    // If the switch is inside a loop, add the continue counts.
    BreakContinue BC = BreakContinueStack.pop_back_val();
    if (!BreakContinueStack.empty())
      BreakContinueStack.back().ContinueCount += BC.ContinueCount;
    // Counter tracks the exit block of the switch.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitSwitchCase(const SwitchCase *S) {
    RecordNextStmtCount = false;
    // Counter for this particular case. This counts only jumps from the
    // switch header and does not include fallthrough from the case before
    // this one.
    uint64_t CaseCount = PGO.getRegionCount(S);
    setCount(CurrentCount + CaseCount);
    // We need the count without fallthrough in the mapping, so it's more useful
    // for branch probabilities.
    CountMap[S] = CaseCount;
    RecordNextStmtCount = true;
    Visit(S->getSubStmt());
  }

  void VisitIfStmt(const IfStmt *S) {
    RecordStmtCount(S);
    uint64_t ParentCount = CurrentCount;
    if (S->getInit())
      Visit(S->getInit());
    Visit(S->getCond());

    // Counter tracks the "then" part of an if statement. The count for
    // the "else" part, if it exists, will be calculated from this counter.
    uint64_t ThenCount = setCount(PGO.getRegionCount(S));
    CountMap[S->getThen()] = ThenCount;
    Visit(S->getThen());
    uint64_t OutCount = CurrentCount;

    // The "else" count is the times the condition was reached minus the
    // times the "then" branch was taken.
    uint64_t ElseCount = ParentCount - ThenCount;
    if (S->getElse()) {
      setCount(ElseCount);
      CountMap[S->getElse()] = ElseCount;
      Visit(S->getElse());
      OutCount += CurrentCount;
    } else
      OutCount += ElseCount;
    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitCXXTryStmt(const CXXTryStmt *S) {
    RecordStmtCount(S);
    Visit(S->getTryBlock());
    for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
      Visit(S->getHandler(I));
    // Counter tracks the continuation block of the try statement.
    setCount(PGO.getRegionCount(S));
    RecordNextStmtCount = true;
  }

  void VisitCXXCatchStmt(const CXXCatchStmt *S) {
    RecordNextStmtCount = false;
    // Counter tracks the catch statement's handler block.
    uint64_t CatchCount = setCount(PGO.getRegionCount(S));
    CountMap[S] = CatchCount;
    Visit(S->getHandlerBlock());
  }

  void VisitAbstractConditionalOperator(const AbstractConditionalOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getCond());

    // Counter tracks the "true" part of a conditional operator. The
    // count in the "false" part will be calculated from this counter.
    uint64_t TrueCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getTrueExpr()] = TrueCount;
    Visit(E->getTrueExpr());
    uint64_t OutCount = CurrentCount;

    uint64_t FalseCount = setCount(ParentCount - TrueCount);
    CountMap[E->getFalseExpr()] = FalseCount;
    Visit(E->getFalseExpr());
    OutCount += CurrentCount;

    setCount(OutCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLAnd(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical and operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }

  void VisitBinLOr(const BinaryOperator *E) {
    RecordStmtCount(E);
    uint64_t ParentCount = CurrentCount;
    Visit(E->getLHS());
    // Counter tracks the right hand side of a logical or operator.
    uint64_t RHSCount = setCount(PGO.getRegionCount(E));
    CountMap[E->getRHS()] = RHSCount;
    Visit(E->getRHS());
    setCount(ParentCount + RHSCount - CurrentCount);
    RecordNextStmtCount = true;
  }
};
739} // end anonymous namespace
740
void PGOHash::combine(HashType Type) {
  // Check that we never combine 0 and only have six bits.
  assert(Type && "Hash is invalid: unexpected type 0");
  assert(unsigned(Type) < TooBig && "Hash is invalid: too many types");

  // Pass through MD5 if enough work has built up. The flush happens on the
  // first combine *after* the word fills, so a function with at most
  // NumTypesPerWord types never touches MD5 (see finalize()).
  if (Count && Count % NumTypesPerWord == 0) {
    using namespace llvm::support;
    // Byte-swap to little-endian so the bytes fed to MD5 are the same on
    // every host, keeping the hash endianness-independent.
    uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
    MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    Working = 0;
  }

  // Accumulate the current type.
  ++Count;
  Working = Working << NumBitsPerType | Type;
}
758
// Produce the final hash value. Every code path here is part of the stable
// hash contract -- do not change the output for any existing hash version.
uint64_t PGOHash::finalize() {
  // Use Working as the hash directly if we never used MD5.
  if (Count <= NumTypesPerWord)
    // No need to byte swap here, since none of the math was endian-dependent.
    // This number will be byte-swapped as required on endianness transitions,
    // so we will see the same value on the other side.
    return Working;

  // Check for remaining work in Working.
  if (Working) {
    // Keep the buggy behavior from v1 and v2 for backward-compatibility. This
    // is buggy because it converts a uint64_t into an array of uint8_t.
    // Only the low byte of Working survives that truncation.
    if (HashVersion < PGO_HASH_V3) {
      MD5.update({(uint8_t)Working});
    } else {
      // V3 fixed the truncation: hash all eight bytes, normalized to
      // little-endian as in combine().
      using namespace llvm::support;
      uint64_t Swapped = endian::byte_swap<uint64_t, little>(Working);
      MD5.update(llvm::makeArrayRef((uint8_t *)&Swapped, sizeof(Swapped)));
    }
  }

  // Finalize the MD5 and return the hash.
  llvm::MD5::MD5Result Result;
  MD5.final(Result);
  return Result.low();
}
785
// Assign counters to the regions of \p GD's body and, when a profile is being
// read, load the counts for \p Fn. No-op for decls that should not be
// instrumented (no body, implicit, non-base ctor/dtor variants, etc.).
void CodeGenPGO::assignRegionCounters(GlobalDecl GD, llvm::Function *Fn) {
  const Decl *D = GD.getDecl();
  if (!D->hasBody())
    return;

  // Skip CUDA/HIP kernel launch stub functions.
  if (CGM.getLangOpts().CUDA && !CGM.getLangOpts().CUDAIsDevice &&
      D->hasAttr<CUDAGlobalAttr>())
    return;

  // Nothing to do unless we're emitting instrumentation or consuming an
  // indexed profile.
  bool InstrumentRegions = CGM.getCodeGenOpts().hasProfileClangInstr();
  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (!InstrumentRegions && !PGOReader)
    return;
  // Compiler-generated declarations have no user-written regions to count.
  if (D->isImplicit())
    return;
  // Constructors and destructors may be represented by several functions in IR.
  // If so, instrument only base variant, others are implemented by delegation
  // to the base one, it would be counted twice otherwise.
  if (CGM.getTarget().getCXXABI().hasConstructorVariants()) {
    if (const auto *CCD = dyn_cast<CXXConstructorDecl>(D))
      if (GD.getCtorType() != Ctor_Base &&
          CodeGenFunction::IsConstructorDelegationValid(CCD))
        return;
  }
  if (isa<CXXDestructorDecl>(D) && GD.getDtorType() != Dtor_Base)
    return;

  // Clear any previously recorded (unused) coverage mapping for this decl
  // before bailing out on the no_profile attribute.
  CGM.ClearUnusedCoverageMapping(D);
  if (Fn->hasFnAttribute(llvm::Attribute::NoProfile))
    return;

  setFuncName(Fn);

  // Number the regions and compute the function hash, emit the coverage
  // mapping if requested, and -- when reading a profile -- load the counts
  // and propagate them over the AST.
  mapRegionCounters(D);
  if (CGM.getCodeGenOpts().CoverageMapping)
    emitCounterRegionMapping(D);
  if (PGOReader) {
    SourceManager &SM = CGM.getContext().getSourceManager();
    loadRegionCounts(PGOReader, SM.isInMainFile(D->getLocation()));
    computeRegionCounts(D);
    applyFunctionAttributes(PGOReader, Fn);
  }
}
830
831void CodeGenPGO::mapRegionCounters(const Decl *D) {
832  // Use the latest hash version when inserting instrumentation, but use the
833  // version in the indexed profile if we're reading PGO data.
834  PGOHashVersion HashVersion = PGO_HASH_LATEST;
835  uint64_t ProfileVersion = llvm::IndexedInstrProf::Version;
836  if (auto *PGOReader = CGM.getPGOReader()) {
837    HashVersion = getPGOHashVersion(PGOReader, CGM);
838    ProfileVersion = PGOReader->getVersion();
839  }
840
841  RegionCounterMap.reset(new llvm::DenseMap<const Stmt *, unsigned>);
842  MapRegionCounters Walker(HashVersion, ProfileVersion, *RegionCounterMap);
843  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
844    Walker.TraverseDecl(const_cast<FunctionDecl *>(FD));
845  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
846    Walker.TraverseDecl(const_cast<ObjCMethodDecl *>(MD));
847  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
848    Walker.TraverseDecl(const_cast<BlockDecl *>(BD));
849  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
850    Walker.TraverseDecl(const_cast<CapturedDecl *>(CD));
851  assert(Walker.NextCounter > 0 && "no entry counter mapped for decl");
852  NumRegionCounters = Walker.NextCounter;
853  FunctionHash = Walker.Hash.finalize();
854}
855
856bool CodeGenPGO::skipRegionMappingForDecl(const Decl *D) {
857  if (!D->getBody())
858    return true;
859
860  // Skip host-only functions in the CUDA device compilation and device-only
861  // functions in the host compilation. Just roughly filter them out based on
862  // the function attributes. If there are effectively host-only or device-only
863  // ones, their coverage mapping may still be generated.
864  if (CGM.getLangOpts().CUDA &&
865      ((CGM.getLangOpts().CUDAIsDevice && !D->hasAttr<CUDADeviceAttr>() &&
866        !D->hasAttr<CUDAGlobalAttr>()) ||
867       (!CGM.getLangOpts().CUDAIsDevice &&
868        (D->hasAttr<CUDAGlobalAttr>() ||
869         (!D->hasAttr<CUDAHostAttr>() && D->hasAttr<CUDADeviceAttr>())))))
870    return true;
871
872  // Don't map the functions in system headers.
873  const auto &SM = CGM.getContext().getSourceManager();
874  auto Loc = D->getBody()->getBeginLoc();
875  return SM.isInSystemHeader(Loc);
876}
877
878void CodeGenPGO::emitCounterRegionMapping(const Decl *D) {
879  if (skipRegionMappingForDecl(D))
880    return;
881
882  std::string CoverageMapping;
883  llvm::raw_string_ostream OS(CoverageMapping);
884  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
885                                CGM.getContext().getSourceManager(),
886                                CGM.getLangOpts(), RegionCounterMap.get());
887  MappingGen.emitCounterMapping(D, OS);
888  OS.flush();
889
890  if (CoverageMapping.empty())
891    return;
892
893  CGM.getCoverageMapping()->addFunctionMappingRecord(
894      FuncNameVar, FuncName, FunctionHash, CoverageMapping);
895}
896
897void
898CodeGenPGO::emitEmptyCounterMapping(const Decl *D, StringRef Name,
899                                    llvm::GlobalValue::LinkageTypes Linkage) {
900  if (skipRegionMappingForDecl(D))
901    return;
902
903  std::string CoverageMapping;
904  llvm::raw_string_ostream OS(CoverageMapping);
905  CoverageMappingGen MappingGen(*CGM.getCoverageMapping(),
906                                CGM.getContext().getSourceManager(),
907                                CGM.getLangOpts());
908  MappingGen.emitEmptyMapping(D, OS);
909  OS.flush();
910
911  if (CoverageMapping.empty())
912    return;
913
914  setFuncName(Name, Linkage);
915  CGM.getCoverageMapping()->addFunctionMappingRecord(
916      FuncNameVar, FuncName, FunctionHash, CoverageMapping, false);
917}
918
919void CodeGenPGO::computeRegionCounts(const Decl *D) {
920  StmtCountMap.reset(new llvm::DenseMap<const Stmt *, uint64_t>);
921  ComputeRegionCounts Walker(*StmtCountMap, *this);
922  if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
923    Walker.VisitFunctionDecl(FD);
924  else if (const ObjCMethodDecl *MD = dyn_cast_or_null<ObjCMethodDecl>(D))
925    Walker.VisitObjCMethodDecl(MD);
926  else if (const BlockDecl *BD = dyn_cast_or_null<BlockDecl>(D))
927    Walker.VisitBlockDecl(BD);
928  else if (const CapturedDecl *CD = dyn_cast_or_null<CapturedDecl>(D))
929    Walker.VisitCapturedDecl(const_cast<CapturedDecl *>(CD));
930}
931
932void
933CodeGenPGO::applyFunctionAttributes(llvm::IndexedInstrProfReader *PGOReader,
934                                    llvm::Function *Fn) {
935  if (!haveRegionCounts())
936    return;
937
938  uint64_t FunctionCount = getRegionCount(nullptr);
939  Fn->setEntryCount(FunctionCount);
940}
941
942void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, const Stmt *S,
943                                      llvm::Value *StepV) {
944  if (!CGM.getCodeGenOpts().hasProfileClangInstr() || !RegionCounterMap)
945    return;
946  if (!Builder.GetInsertBlock())
947    return;
948
949  unsigned Counter = (*RegionCounterMap)[S];
950  auto *I8PtrTy = llvm::Type::getInt8PtrTy(CGM.getLLVMContext());
951
952  llvm::Value *Args[] = {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy),
953                         Builder.getInt64(FunctionHash),
954                         Builder.getInt32(NumRegionCounters),
955                         Builder.getInt32(Counter), StepV};
956  if (!StepV)
957    Builder.CreateCall(CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment),
958                       makeArrayRef(Args, 4));
959  else
960    Builder.CreateCall(
961        CGM.getIntrinsic(llvm::Intrinsic::instrprof_increment_step),
962        makeArrayRef(Args));
963}
964
965void CodeGenPGO::setValueProfilingFlag(llvm::Module &M) {
966  if (CGM.getCodeGenOpts().hasProfileClangInstr())
967    M.addModuleFlag(llvm::Module::Warning, "EnableValueProfiling",
968                    uint32_t(EnableValueProfiling));
969}
970
// This method either inserts a call to the value-profiling intrinsic during
// instrumentation, or attaches previously collected value-profile data as
// metadata at the value site for PGO use. It is a no-op unless the hidden
// -enable-value-profiling flag is set.
void CodeGenPGO::valueProfile(CGBuilderTy &Builder, uint32_t ValueKind,
    llvm::Instruction *ValueSite, llvm::Value *ValuePtr) {

  if (!EnableValueProfiling)
    return;

  // We need a value to observe, an instruction to anchor it to, and a live
  // insertion point (no point means we're in unreachable code).
  if (!ValuePtr || !ValueSite || !Builder.GetInsertBlock())
    return;

  // Constants are statically known; profiling them would be pointless.
  if (isa<llvm::Constant>(ValuePtr))
    return;

  bool InstrumentValueSites = CGM.getCodeGenOpts().hasProfileClangInstr();
  if (InstrumentValueSites && RegionCounterMap) {
    // Instrumentation mode: emit a call to llvm.instrprof.value.profile just
    // before the value site, then restore the builder's insertion point.
    auto BuilderInsertPoint = Builder.saveIP();
    Builder.SetInsertPoint(ValueSite);
    llvm::Value *Args[5] = {
        llvm::ConstantExpr::getBitCast(FuncNameVar, Builder.getInt8PtrTy()),
        Builder.getInt64(FunctionHash),
        Builder.CreatePtrToInt(ValuePtr, Builder.getInt64Ty()),
        Builder.getInt32(ValueKind),
        // Note: post-increment assigns this site its index and advances the
        // per-kind site counter in one step.
        Builder.getInt32(NumValueSites[ValueKind]++)
    };
    Builder.CreateCall(
        CGM.getIntrinsic(llvm::Intrinsic::instrprof_value_profile), Args);
    Builder.restoreIP(BuilderInsertPoint);
    return;
  }

  llvm::IndexedInstrProfReader *PGOReader = CGM.getPGOReader();
  if (PGOReader && haveRegionCounts()) {
    // Profile-use mode: annotate the value site from the loaded record.
    // We record the top most called three functions at each call site.
    // Profile metadata contains "VP" string identifying this metadata
    // as value profiling data, then a uint32_t value for the value profiling
    // kind, a uint64_t value for the total number of times the call is
    // executed, followed by the function hash and execution count (uint64_t)
    // pairs for each function.
    if (NumValueSites[ValueKind] >= ProfRecord->getNumValueSites(ValueKind))
      return;

    llvm::annotateValueSite(CGM.getModule(), *ValueSite, *ProfRecord,
                            (llvm::InstrProfValueKind)ValueKind,
                            NumValueSites[ValueKind]);

    NumValueSites[ValueKind]++;
  }
}
1020
/// Look up this function's record (by name and structural hash) in the
/// indexed profile and cache its counters in RegionCounts. Also updates the
/// per-module PGO statistics used for -fprofile-use diagnostics. On lookup
/// failure, RegionCounts is left empty and ProfRecord is not set.
void CodeGenPGO::loadRegionCounts(llvm::IndexedInstrProfReader *PGOReader,
                                  bool IsInMainFile) {
  CGM.getPGOStats().addVisited(IsInMainFile);
  RegionCounts.clear();
  llvm::Expected<llvm::InstrProfRecord> RecordExpected =
      PGOReader->getInstrProfRecord(FuncName, FunctionHash);
  if (auto E = RecordExpected.takeError()) {
    // The Expected's error must be consumed; convert it to an InstrProfError
    // so we can classify the failure for the statistics counters.
    auto IPE = llvm::InstrProfError::take(std::move(E));
    if (IPE == llvm::instrprof_error::unknown_function)
      CGM.getPGOStats().addMissing(IsInMainFile);
    else if (IPE == llvm::instrprof_error::hash_mismatch)
      CGM.getPGOStats().addMismatched(IsInMainFile);
    else if (IPE == llvm::instrprof_error::malformed)
      // TODO: Consider a more specific warning for this case.
      CGM.getPGOStats().addMismatched(IsInMainFile);
    return;
  }
  // Success: keep the whole record (needed for value profiling) and copy out
  // the region counters.
  ProfRecord =
      std::make_unique<llvm::InstrProfRecord>(std::move(RecordExpected.get()));
  RegionCounts = ProfRecord->Counts;
}
1042
/// Compute the divisor used to bring branch weights into 32-bit range.
///
/// Given the largest weight in a set, return a scale such that dividing every
/// weight in the set by it yields a value that fits in 32 bits.
static uint64_t calculateWeightScale(uint64_t MaxWeight) {
  if (MaxWeight < UINT32_MAX)
    return 1;
  return MaxWeight / UINT32_MAX + 1;
}
1050
/// Scale one 64-bit branch weight down to 32 bits (and add 1).
///
/// Divides \p Weight by \p Scale, then adds 1: per Laplace's Rule of
/// Succession, a weight derived from count-plus-one behaves better than the
/// raw count, so the increment is applied universally.
///
/// \pre \p Scale was produced by \a calculateWeightScale() from a maximum
/// weight no smaller than \p Weight.
static uint32_t scaleBranchWeight(uint64_t Weight, uint64_t Scale) {
  assert(Scale && "scale by 0?");
  const uint64_t Scaled = Weight / Scale + 1;
  assert(Scaled <= UINT32_MAX && "overflow 32-bits");
  return static_cast<uint32_t>(Scaled);
}
1066
1067llvm::MDNode *CodeGenFunction::createProfileWeights(uint64_t TrueCount,
1068                                                    uint64_t FalseCount) const {
1069  // Check for empty weights.
1070  if (!TrueCount && !FalseCount)
1071    return nullptr;
1072
1073  // Calculate how to scale down to 32-bits.
1074  uint64_t Scale = calculateWeightScale(std::max(TrueCount, FalseCount));
1075
1076  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1077  return MDHelper.createBranchWeights(scaleBranchWeight(TrueCount, Scale),
1078                                      scaleBranchWeight(FalseCount, Scale));
1079}
1080
1081llvm::MDNode *
1082CodeGenFunction::createProfileWeights(ArrayRef<uint64_t> Weights) const {
1083  // We need at least two elements to create meaningful weights.
1084  if (Weights.size() < 2)
1085    return nullptr;
1086
1087  // Check for empty weights.
1088  uint64_t MaxWeight = *std::max_element(Weights.begin(), Weights.end());
1089  if (MaxWeight == 0)
1090    return nullptr;
1091
1092  // Calculate how to scale down to 32-bits.
1093  uint64_t Scale = calculateWeightScale(MaxWeight);
1094
1095  SmallVector<uint32_t, 16> ScaledWeights;
1096  ScaledWeights.reserve(Weights.size());
1097  for (uint64_t W : Weights)
1098    ScaledWeights.push_back(scaleBranchWeight(W, Scale));
1099
1100  llvm::MDBuilder MDHelper(CGM.getLLVMContext());
1101  return MDHelper.createBranchWeights(ScaledWeights);
1102}
1103
1104llvm::MDNode *
1105CodeGenFunction::createProfileWeightsForLoop(const Stmt *Cond,
1106                                             uint64_t LoopCount) const {
1107  if (!PGO.haveRegionCounts())
1108    return nullptr;
1109  Optional<uint64_t> CondCount = PGO.getStmtCount(Cond);
1110  if (!CondCount || *CondCount == 0)
1111    return nullptr;
1112  return createProfileWeights(LoopCount,
1113                              std::max(*CondCount, LoopCount) - LoopCount);
1114}
1115