1//===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
10#define LLVM_ANALYSIS_MLINLINEADVISOR_H
11
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/IR/PassManager.h"

#include <map>
#include <memory>
#include <unordered_map>
19
20namespace llvm {
21class Module;
22class MLInlineAdvice;
23
24class MLInlineAdvisor : public InlineAdvisor {
25public:
26  MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
27                  std::unique_ptr<MLModelRunner> ModelRunner);
28
29  CallGraph *callGraph() const { return CG.get(); }
30  virtual ~MLInlineAdvisor() = default;
31
32  void onPassEntry() override;
33
34  int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); }
35  void onSuccessfulInlining(const MLInlineAdvice &Advice,
36                            bool CalleeWasDeleted);
37
38  bool isForcedToStop() const { return ForceStop; }
39  int64_t getLocalCalls(Function &F);
40  const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
41
42protected:
43  std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
44
45  std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
46                                                   bool Advice) override;
47
48  virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);
49
50  virtual std::unique_ptr<MLInlineAdvice>
51  getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
52
53  std::unique_ptr<MLModelRunner> ModelRunner;
54
55private:
56  int64_t getModuleIRSize() const;
57
58  std::unique_ptr<CallGraph> CG;
59
60  int64_t NodeCount = 0;
61  int64_t EdgeCount = 0;
62  std::map<const Function *, unsigned> FunctionLevels;
63  const int32_t InitialIRSize = 0;
64  int32_t CurrentIRSize = 0;
65
66  bool ForceStop = false;
67};
68
69/// InlineAdvice that tracks changes post inlining. For that reason, it only
70/// overrides the "successful inlining" extension points.
71class MLInlineAdvice : public InlineAdvice {
72public:
73  MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
74                 OptimizationRemarkEmitter &ORE, bool Recommendation)
75      : InlineAdvice(Advisor, CB, ORE, Recommendation),
76        CallerIRSize(Advisor->isForcedToStop() ? 0
77                                               : Advisor->getIRSize(*Caller)),
78        CalleeIRSize(Advisor->isForcedToStop() ? 0
79                                               : Advisor->getIRSize(*Callee)),
80        CallerAndCalleeEdges(Advisor->isForcedToStop()
81                                 ? 0
82                                 : (Advisor->getLocalCalls(*Caller) +
83                                    Advisor->getLocalCalls(*Callee))) {}
84  virtual ~MLInlineAdvice() = default;
85
86  void recordInliningImpl() override;
87  void recordInliningWithCalleeDeletedImpl() override;
88  void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
89  void recordUnattemptedInliningImpl() override;
90
91  Function *getCaller() const { return Caller; }
92  Function *getCallee() const { return Callee; }
93
94  const int64_t CallerIRSize;
95  const int64_t CalleeIRSize;
96  const int64_t CallerAndCalleeEdges;
97
98private:
99  void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
100
101  MLInlineAdvisor *getAdvisor() const {
102    return static_cast<MLInlineAdvisor *>(Advisor);
103  };
104};
105
106} // namespace llvm
107
108#endif // LLVM_ANALYSIS_MLINLINEADVISOR_H
109