1//===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===// 2// 3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4// See https://llvm.org/LICENSE.txt for license information. 5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6// 7//===----------------------------------------------------------------------===// 8 9#ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H 10#define LLVM_ANALYSIS_MLINLINEADVISOR_H 11 12#include "llvm/Analysis/CallGraph.h" 13#include "llvm/Analysis/InlineAdvisor.h" 14#include "llvm/Analysis/MLModelRunner.h" 15#include "llvm/IR/PassManager.h" 16 17#include <memory> 18#include <unordered_map> 19 20namespace llvm { 21class Module; 22class MLInlineAdvice; 23 24class MLInlineAdvisor : public InlineAdvisor { 25public: 26 MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM, 27 std::unique_ptr<MLModelRunner> ModelRunner); 28 29 CallGraph *callGraph() const { return CG.get(); } 30 virtual ~MLInlineAdvisor() = default; 31 32 void onPassEntry() override; 33 34 int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); } 35 void onSuccessfulInlining(const MLInlineAdvice &Advice, 36 bool CalleeWasDeleted); 37 38 bool isForcedToStop() const { return ForceStop; } 39 int64_t getLocalCalls(Function &F); 40 const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); } 41 42protected: 43 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override; 44 45 std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB, 46 bool Advice) override; 47 48 virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB); 49 50 virtual std::unique_ptr<MLInlineAdvice> 51 getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE); 52 53 std::unique_ptr<MLModelRunner> ModelRunner; 54 55private: 56 int64_t getModuleIRSize() const; 57 58 std::unique_ptr<CallGraph> CG; 59 60 int64_t NodeCount = 0; 61 int64_t EdgeCount = 0; 62 std::map<const Function *, unsigned> FunctionLevels; 63 const int32_t InitialIRSize = 0; 64 int32_t CurrentIRSize = 0; 65 66 bool ForceStop = false; 67}; 68 69/// InlineAdvice that tracks changes post inlining. For that reason, it only 70/// overrides the "successful inlining" extension points. 71class MLInlineAdvice : public InlineAdvice { 72public: 73 MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB, 74 OptimizationRemarkEmitter &ORE, bool Recommendation) 75 : InlineAdvice(Advisor, CB, ORE, Recommendation), 76 CallerIRSize(Advisor->isForcedToStop() ? 0 77 : Advisor->getIRSize(*Caller)), 78 CalleeIRSize(Advisor->isForcedToStop() ? 0 79 : Advisor->getIRSize(*Callee)), 80 CallerAndCalleeEdges(Advisor->isForcedToStop() 81 ? 0 82 : (Advisor->getLocalCalls(*Caller) + 83 Advisor->getLocalCalls(*Callee))) {} 84 virtual ~MLInlineAdvice() = default; 85 86 void recordInliningImpl() override; 87 void recordInliningWithCalleeDeletedImpl() override; 88 void recordUnsuccessfulInliningImpl(const InlineResult &Result) override; 89 void recordUnattemptedInliningImpl() override; 90 91 Function *getCaller() const { return Caller; } 92 Function *getCallee() const { return Callee; } 93 94 const int64_t CallerIRSize; 95 const int64_t CalleeIRSize; 96 const int64_t CallerAndCalleeEdges; 97 98private: 99 void reportContextForRemark(DiagnosticInfoOptimizationBase &OR); 100 101 MLInlineAdvisor *getAdvisor() const { 102 return static_cast<MLInlineAdvisor *>(Advisor); 103 }; 104}; 105 106} // namespace llvm 107 108#endif // LLVM_ANALYSIS_MLINLINEADVISOR_H 109