//===- MLRegAllocEvictAdvisor.cpp - ML eviction advisor -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of the ML eviction advisor and reward injection pass
//
//===----------------------------------------------------------------------===//

#include "AllocationOrder.h"
#include "RegAllocEvictionAdvisor.h"
#include "RegAllocGreedy.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/Analysis/TensorSpec.h"
#if defined(LLVM_HAVE_TF_AOT_REGALLOCEVICTMODEL) || defined(LLVM_HAVE_TFLITE)
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
#include "llvm/Analysis/NoInferenceModelRunner.h"
#include "llvm/Analysis/Utils/TrainingLogger.h"
#endif
#include "MLRegallocEvictAdvisor.h"
#include "llvm/Analysis/ReleaseModeModelRunner.h"
#include "llvm/CodeGen/CalcSpillWeights.h"
#include "llvm/CodeGen/LiveRegMatrix.h"
#include "llvm/CodeGen/MachineBlockFrequencyInfo.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineLoopInfo.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/RegisterClassInfo.h"
#include "llvm/CodeGen/VirtRegMap.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"

#include <array>
#include <memory>
#include <optional>

using namespace llvm;

#define DEBUG_TYPE "ml-regalloc"

// Generated header in release (AOT) mode
#if defined(LLVM_HAVE_TF_AOT_REGALLOCEVICTMODEL)
#include "RegallocEvictModel.h"
using CompiledModelType = RegallocEvictModel;
#else
using CompiledModelType = NoopSavedModelImpl;
#endif

// Options that only make sense in development mode
#ifdef LLVM_HAVE_TFLITE
#include "RegAllocScore.h"
#include "llvm/Analysis/Utils/TFUtils.h"

static cl::opt<std::string> TrainingLog(
    "regalloc-training-log", cl::Hidden,
    cl::desc("Training log for the register allocator eviction model"));

static cl::opt<std::string> ModelUnderTraining(
    "regalloc-model", cl::Hidden,
    cl::desc("The model being trained for register allocation eviction"));

static cl::opt<bool> EnableDevelopmentFeatures(
    "regalloc-enable-development-features", cl::Hidden,
    cl::desc("Whether or not to enable features under development for the ML "
             "regalloc advisor"));

#else
static const bool EnableDevelopmentFeatures = false;
#endif // #ifdef LLVM_HAVE_TFLITE

extern cl::opt<unsigned> EvictInterferenceCutoff;

/// The score injection pass.
/// This pass calculates the score for a function and inserts it in the log, but
/// this happens only in development mode. It's a no-op otherwise.
namespace llvm {
class RegAllocScoring : public MachineFunctionPass {
public:
  static char ID;

  RegAllocScoring() : MachineFunctionPass(ID) {
    initializeRegAllocScoringPass(*PassRegistry::getPassRegistry());
  }

  ~RegAllocScoring() override = default;

  StringRef getPassName() const override {
    return "Register Allocation Pass Scoring";
  }

  /// RegAllocReward analysis usage.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
    AU.addRequired<RegAllocEvictionAdvisorAnalysis>();
    AU.addRequired<RegAllocPriorityAdvisorAnalysis>();
    AU.addRequired<MachineBlockFrequencyInfo>();
    MachineFunctionPass::getAnalysisUsage(AU);
  }

  /// Performs this pass
  bool runOnMachineFunction(MachineFunction &) override;
};

char RegAllocScoring::ID = 0;
FunctionPass *createRegAllocScoringPass() { return new RegAllocScoring(); }

} // namespace llvm

INITIALIZE_PASS(RegAllocScoring, "regallocscoringpass",
                "Register Allocation Scoring Pass", false, false)

// ===================================
// Common ML Advisor declarations
// ===================================
namespace {
// The model can only accept a specified number of opcodes and will error if
// fed an opcode it hasn't seen before. This constant sets the current cutoff.
static const int OpcodeValueCutoff = 17716;

// Most features are per-live-range tensors of the shape described below, so
// we reuse this vector when defining them.
static const std::vector<int64_t> PerLiveRangeShape{1, NumberOfInterferences};
// --------------
// Features table
// --------------
// For each interfering live range (incl. the candidate) we collect a number of
// features. However, because the features are of different types (and because
// of ML best practices), we organize the tensors per feature, not per
// candidate. Each such tensor has a scalar value corresponding to the
// interfering live range at that position, in the order given by
// AllocationOrder. The last position corresponds to the virt reg seeking
// allocation. The exception to all this is the progress feature, which is
// just a scalar (see its documentation for details).
// Note on naming: the "_by_max" features are normalized using the largest
// value of that tensor, as observed in the current decision making stage
// (i.e. for the current call to the advisor's tryFindEvictionCandidate).
//
// The feature list format: type, name, shape, documentation.
// Note: we can really just use int64 and float, hence the modeling of some
// bools as int64 values.
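// For example (illustrative values only, assuming NumberOfInterferences were
// 33): the 'mask' feature below would be a {1, 33} tensor in which entry i is
// 1 iff the i-th position in the AllocationOrder is a legal eviction
// candidate, and the last entry (CandidateVirtRegPos) describes the virtual
// register currently seeking allocation.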
#define RA_EVICT_FEATURES_LIST(M)                                              \
  M(int64_t, mask, PerLiveRangeShape,                                          \
    "boolean values, 0 for unavailable candidates (i.e. if a position is 0, "  \
    "it can't be evicted)")                                                    \
  M(int64_t, is_free, PerLiveRangeShape,                                       \
    "boolean values, 1 if this phys reg is actually free (no interferences)")  \
  M(float, nr_urgent, PerLiveRangeShape,                                       \
    "number of 'urgent' intervals, normalized. Urgent are those for which "    \
    "breaking cascades is OK")                                                 \
  M(float, nr_broken_hints, PerLiveRangeShape,                                 \
    "if this position were evicted, how many broken hints would there be")     \
  M(int64_t, is_hint, PerLiveRangeShape,                                       \
    "is this a preferred phys reg for the candidate")                          \
  M(int64_t, is_local, PerLiveRangeShape,                                      \
    "is this live range local to a basic block")                               \
  M(float, nr_rematerializable, PerLiveRangeShape,                             \
    "nr rematerializable ranges")                                              \
  M(float, nr_defs_and_uses, PerLiveRangeShape,                                \
    "bb freq - weighed nr defs and uses")                                      \
  M(float, weighed_reads_by_max, PerLiveRangeShape,                            \
    "bb freq - weighed nr of reads, normalized")                               \
  M(float, weighed_writes_by_max, PerLiveRangeShape,                           \
    "bb freq - weighed nr of writes, normalized")                              \
  M(float, weighed_read_writes_by_max, PerLiveRangeShape,                      \
    "bb freq - weighed nr of uses that are both read and writes, normalized")  \
  M(float, weighed_indvars_by_max, PerLiveRangeShape,                          \
    "bb freq - weighed nr of uses that are indvars, normalized")               \
  M(float, hint_weights_by_max, PerLiveRangeShape,                             \
    "bb freq - weighed nr of uses that are hints, normalized")                 \
  M(float, start_bb_freq_by_max, PerLiveRangeShape,                            \
    "the freq in the start block, normalized")                                 \
  M(float, end_bb_freq_by_max, PerLiveRangeShape,                              \
    "freq of end block, normalized")                                           \
  M(float, hottest_bb_freq_by_max, PerLiveRangeShape,                          \
    "hottest BB freq, normalized")                                             \
  M(float, liverange_size, PerLiveRangeShape,                                  \
    "size (instr index diff) of the LR")                                       \
  M(float, use_def_density, PerLiveRangeShape,                                 \
    "the max weight, as computed by the manual heuristic")                     \
  M(int64_t, max_stage, PerLiveRangeShape,                                     \
    "largest stage of an interval in this LR")                                 \
  M(int64_t, min_stage, PerLiveRangeShape,                                     \
    "lowest stage of an interval in this LR")                                  \
  M(float, progress, {1}, "ratio of current queue size to initial size")
#ifdef LLVM_HAVE_TFLITE
#define RA_EVICT_FIRST_DEVELOPMENT_FEATURE(M)                                  \
  M(int64_t, instructions, InstructionsShape,                                  \
    "Opcodes of the instructions covered by the eviction problem")

#define RA_EVICT_REST_DEVELOPMENT_FEATURES(M)                                  \
  M(int64_t, instructions_mapping, InstructionsMappingShape,                   \
    "A binary matrix mapping LRs to instruction opcodes")                      \
  M(float, mbb_frequencies, MBBFrequencyShape,                                 \
    "A vector of machine basic block frequencies")                             \
  M(int64_t, mbb_mapping, InstructionsShape,                                   \
    "A vector of indices mapping instructions to MBBs")
#else
#define RA_EVICT_FIRST_DEVELOPMENT_FEATURE(M)
#define RA_EVICT_REST_DEVELOPMENT_FEATURES(M)
#endif

// The model learns to pick one of the mask == 1 interferences. This is the
// name of the output tensor. The contract with the model is that the output
// is guaranteed to be a mask == 1 position. Using a macro here to avoid
// 'not used' warnings (and to keep conditional compilation to a minimum).
#define DecisionName "index_to_evict"

// Named features index.
enum FeatureIDs {
#define _FEATURE_IDX_SIMPLE(_, name, __, ___) name
#define _FEATURE_IDX(A, B, C, D) _FEATURE_IDX_SIMPLE(A, B, C, D),
  RA_EVICT_FEATURES_LIST(_FEATURE_IDX) FeatureCount,
#ifdef LLVM_HAVE_TFLITE
  RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_FEATURE_IDX_SIMPLE) = FeatureCount,
#else
  RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_FEATURE_IDX)
#endif // #ifdef LLVM_HAVE_TFLITE
  RA_EVICT_REST_DEVELOPMENT_FEATURES(_FEATURE_IDX) FeaturesWithDevelopmentCount
#undef _FEATURE_IDX
#undef _FEATURE_IDX_SIMPLE
};
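
// For illustration (a sketch of the macro expansion, not an additional
// contract): the enum above expands to 'mask, is_free, ..., progress,
// FeatureCount', so each feature name doubles as the index of its tensor.
// FeatureCount is the total without the TFLite-only development features;
// FeaturesWithDevelopmentCount includes them.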

// The ML advisor will typically have a sparse input to the evaluator, because
// various phys regs won't be available. It's easier (maintenance-wise) to
// bulk-reset the state of the evaluator each time we are about to use it
// again.
template <typename T> size_t getTotalSize(const std::vector<int64_t> &Shape) {
  size_t Ret = sizeof(T);
  for (const auto V : Shape)
    Ret *= V;
  return Ret;
}
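
// For example, getTotalSize<float>(PerLiveRangeShape) is
// sizeof(float) * 1 * NumberOfInterferences - the number of bytes that
// resetInputs below zeroes out for each per-live-range tensor.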

void resetInputs(MLModelRunner &Runner) {
#define _RESET(TYPE, NAME, SHAPE, __)                                          \
  std::memset(Runner.getTensorUntyped(FeatureIDs::NAME), 0,                    \
              getTotalSize<TYPE>(SHAPE));
  RA_EVICT_FEATURES_LIST(_RESET)
  if (EnableDevelopmentFeatures) {
    RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_RESET)
    RA_EVICT_REST_DEVELOPMENT_FEATURES(_RESET)
#undef _RESET
  }
}

// Per-live interval components that get aggregated into the feature values
// that will be passed to the evaluator.
struct LIFeatureComponents {
  double R = 0;             // Freq-weighted count of uses that only read.
  double W = 0;             // Freq-weighted count of uses that only write.
  double RW = 0;            // Freq-weighted count of read-write uses.
  double IndVarUpdates = 0; // Freq-weighted count of likely indvar updates.
  double HintWeights = 0.0; // Freq-weighted count of copy hints.
  int64_t NrDefsAndUses = 0;
  float HottestBlockFreq = 0.0;
  bool IsRemat = false;
};

using CandidateRegList =
    std::array<std::pair<MCRegister, bool>, NumberOfInterferences>;
using FeaturesListNormalizer =
    llvm::SmallVector<float, FeatureIDs::FeatureCount>;

/// The ML evictor (commonalities between release and development mode)
class MLEvictAdvisor : public RegAllocEvictionAdvisor {
public:
  MLEvictAdvisor(const MachineFunction &MF, const RAGreedy &RA,
                 MLModelRunner *Runner, const MachineBlockFrequencyInfo &MBFI,
                 const MachineLoopInfo &Loops);

protected:
  const RegAllocEvictionAdvisor &getDefaultAdvisor() const {
    return static_cast<const RegAllocEvictionAdvisor &>(DefaultAdvisor);
  }

  // The assumption is that if the Runner could not be constructed, we emitted
  // an error, and we shouldn't be asking for it here.
  const MLModelRunner &getRunner() const { return *Runner; }

  /// This just calls Evaluate on the Runner, but in the development mode
  /// case, if we're just capturing the log of the default advisor, it needs
  /// to call the latter instead, so we need to pass all the necessary
  /// parameters for it. In the development case, it will also log.
  virtual int64_t
  tryFindEvictionCandidatePosition(const LiveInterval &VirtReg,
                                   const AllocationOrder &Order,
                                   unsigned OrderLimit, uint8_t CostPerUseLimit,
                                   const SmallVirtRegSet &FixedRegisters) const;

  /// Load the features of the given VirtReg (allocated or not) at column Pos,
  /// but if that can't be evicted, return false instead.
  bool
  loadInterferenceFeatures(const LiveInterval &VirtReg, MCRegister PhysReg,
                           bool IsHint, const SmallVirtRegSet &FixedRegisters,
                           llvm::SmallVectorImpl<float> &Largest, size_t Pos,
                           SmallVectorImpl<LRStartEndInfo> &LRPosInfo) const;

private:
  static float getInitialQueueSize(const MachineFunction &MF);

  MCRegister tryFindEvictionCandidate(
      const LiveInterval &VirtReg, const AllocationOrder &Order,
      uint8_t CostPerUseLimit,
      const SmallVirtRegSet &FixedRegisters) const override;

  void extractFeatures(const SmallVectorImpl<const LiveInterval *> &Intervals,
                       llvm::SmallVectorImpl<float> &Largest, size_t Pos,
                       int64_t IsHint, int64_t LocalIntfsCount, float NrUrgent,
                       SmallVectorImpl<LRStartEndInfo> &LRPosInfo) const;

  // Point-in-time: we didn't learn this, so we always delegate to the
  // default.
  bool canEvictHintInterference(
      const LiveInterval &VirtReg, MCRegister PhysReg,
      const SmallVirtRegSet &FixedRegisters) const override {
    return getDefaultAdvisor().canEvictHintInterference(VirtReg, PhysReg,
                                                        FixedRegisters);
  }

  const LIFeatureComponents &
  getLIFeatureComponents(const LiveInterval &LI) const;

  // Hold on to a default advisor for:
  // 1) the implementation of canEvictHintInterference, because we didn't
  // learn that nuance yet; 2) for bootstrapping (logging) in the development
  // mode case.
  const DefaultEvictionAdvisor DefaultAdvisor;
  MLModelRunner *const Runner;
  const MachineBlockFrequencyInfo &MBFI;
  const MachineLoopInfo &Loops;

  // Indices of those features we don't want to normalize.
  // This could be static and shared, but its initialization is non-trivial.
  std::bitset<FeatureIDs::FeatureCount> DoNotNormalize;
  const float InitialQSize;

  using RegID = unsigned;
  mutable DenseMap<RegID, LIFeatureComponents> CachedFeatures;
};

#define _DECL_FEATURES(type, name, shape, _)                                   \
  TensorSpec::createSpec<type>(#name, shape),

// ===================================
// Release (AOT) - specifics
// ===================================
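// The release-mode advisor wraps the AOT-compiled model when one was built
// in, or a no-op model otherwise (see CompiledModelType above). As an aside
// (an assumption about the surrounding infrastructure, not established in
// this file): the advisor variant is selected at run time via the
// -regalloc-enable-advisor flag handled in RegAllocEvictionAdvisor.cpp.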
class ReleaseModeEvictionAdvisorAnalysis final
    : public RegAllocEvictionAdvisorAnalysis {
public:
  ReleaseModeEvictionAdvisorAnalysis()
      : RegAllocEvictionAdvisorAnalysis(AdvisorMode::Release) {
    if (EnableDevelopmentFeatures) {
      InputFeatures = {RA_EVICT_FEATURES_LIST(
          _DECL_FEATURES) RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_DECL_FEATURES)
                           RA_EVICT_REST_DEVELOPMENT_FEATURES(_DECL_FEATURES)};
    } else {
      InputFeatures = {RA_EVICT_FEATURES_LIST(_DECL_FEATURES)};
    }
  }
  // support for isa<> and dyn_cast.
  static bool classof(const RegAllocEvictionAdvisorAnalysis *R) {
    return R->getAdvisorMode() == AdvisorMode::Release;
  }

private:
  std::vector<TensorSpec> InputFeatures;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MachineBlockFrequencyInfo>();
    AU.addRequired<MachineLoopInfo>();
    RegAllocEvictionAdvisorAnalysis::getAnalysisUsage(AU);
  }

  std::unique_ptr<RegAllocEvictionAdvisor>
  getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
    if (!Runner)
      Runner = std::make_unique<ReleaseModeModelRunner<CompiledModelType>>(
          MF.getFunction().getContext(), InputFeatures, DecisionName);
    return std::make_unique<MLEvictAdvisor>(
        MF, RA, Runner.get(), getAnalysis<MachineBlockFrequencyInfo>(),
        getAnalysis<MachineLoopInfo>());
  }
  std::unique_ptr<ReleaseModeModelRunner<CompiledModelType>> Runner;
};

// ===================================
// Development mode-specifics
// ===================================
//
// Features we log
#ifdef LLVM_HAVE_TFLITE
static const TensorSpec Output =
    TensorSpec::createSpec<int64_t>(DecisionName, {1});
static const TensorSpec Reward = TensorSpec::createSpec<float>("reward", {1});

// Features we bind to the model. The tensor names have a prefix, and we also
// need to include some tensors that the training algo expects to be present.
// TODO: can we just get rid of these?
#define _DECL_TRAIN_FEATURES(type, name, shape, _)                             \
  TensorSpec::createSpec<type>(std::string("action_") + #name, shape),

class DevelopmentModeEvictAdvisor : public MLEvictAdvisor {
public:
  DevelopmentModeEvictAdvisor(const MachineFunction &MF, const RAGreedy &RA,
                              MLModelRunner *Runner,
                              const MachineBlockFrequencyInfo &MBFI,
                              const MachineLoopInfo &Loops, Logger *Log)
      : MLEvictAdvisor(MF, RA, Runner, MBFI, Loops), Log(Log) {}

private:
  int64_t tryFindEvictionCandidatePosition(
      const LiveInterval &VirtReg, const AllocationOrder &Order,
      unsigned OrderLimit, uint8_t CostPerUseLimit,
      const SmallVirtRegSet &FixedRegisters) const override;

  Logger *const Log;
};

class DevelopmentModeEvictionAdvisorAnalysis final
    : public RegAllocEvictionAdvisorAnalysis {
public:
  DevelopmentModeEvictionAdvisorAnalysis()
      : RegAllocEvictionAdvisorAnalysis(AdvisorMode::Development) {
    if (EnableDevelopmentFeatures) {
      InputFeatures = {RA_EVICT_FEATURES_LIST(
          _DECL_FEATURES) RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_DECL_FEATURES)
                           RA_EVICT_REST_DEVELOPMENT_FEATURES(_DECL_FEATURES)};
      TrainingInputFeatures = {
          RA_EVICT_FEATURES_LIST(_DECL_TRAIN_FEATURES)
              RA_EVICT_FIRST_DEVELOPMENT_FEATURE(_DECL_TRAIN_FEATURES)
                  RA_EVICT_REST_DEVELOPMENT_FEATURES(_DECL_TRAIN_FEATURES)
                      TensorSpec::createSpec<float>("action_discount", {1}),
          TensorSpec::createSpec<int32_t>("action_step_type", {1}),
          TensorSpec::createSpec<float>("action_reward", {1})};
    } else {
      InputFeatures = {RA_EVICT_FEATURES_LIST(_DECL_FEATURES)};
      TrainingInputFeatures = {
          RA_EVICT_FEATURES_LIST(_DECL_TRAIN_FEATURES)
              TensorSpec::createSpec<float>("action_discount", {1}),
          TensorSpec::createSpec<int32_t>("action_step_type", {1}),
          TensorSpec::createSpec<float>("action_reward", {1})};
    }
  }
  // support for isa<> and dyn_cast.
  static bool classof(const RegAllocEvictionAdvisorAnalysis *R) {
    return R->getAdvisorMode() == AdvisorMode::Development;
  }

  void logRewardIfNeeded(const MachineFunction &MF,
                         llvm::function_ref<float()> GetReward) override {
    if (!Log)
      return;
    // The function pass manager would run all the function passes for a
    // function, so we assume the last context belongs to this function. If
    // this invariant ever changes, we can implement context switching at
    // that point. For now, a mismatch is an error.
    if (Log->currentContext() != MF.getName()) {
      MF.getFunction().getContext().emitError(
          "The training log context shouldn't have changed.");
    }
    if (Log->hasObservationInProgress())
      Log->logReward<float>(GetReward());
  }

private:
  std::vector<TensorSpec> InputFeatures;
  std::vector<TensorSpec> TrainingInputFeatures;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<MachineBlockFrequencyInfo>();
    AU.addRequired<MachineLoopInfo>();
    RegAllocEvictionAdvisorAnalysis::getAnalysisUsage(AU);
  }

  bool doInitialization(Module &M) override {
    LLVMContext &Ctx = M.getContext();
    if (ModelUnderTraining.empty() && TrainingLog.empty()) {
      Ctx.emitError("Regalloc development mode should be requested with at "
                    "least logging enabled and/or a training model");
      return false;
    }
    if (ModelUnderTraining.empty())
      Runner = std::make_unique<NoInferenceModelRunner>(Ctx, InputFeatures);
    else
      Runner = ModelUnderTrainingRunner::createAndEnsureValid(
          Ctx, ModelUnderTraining, DecisionName, TrainingInputFeatures);
    if (!Runner) {
      Ctx.emitError("Regalloc: could not set up the model runner");
      return false;
    }
    if (TrainingLog.empty())
      return false;
    std::error_code EC;
    auto OS = std::make_unique<raw_fd_ostream>(TrainingLog, EC);
    if (EC) {
      M.getContext().emitError(EC.message() + ":" + TrainingLog);
      return false;
    }
    std::vector<TensorSpec> LFS = InputFeatures;
    if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(Runner.get()))
      append_range(LFS, MUTR->extraOutputsForLoggingSpecs());
    // We always log the output; in particular, if we're not evaluating, we
    // don't have an output spec json file. That's why we handle the
    // 'normal' output separately.
    LFS.push_back(Output);

    Log = std::make_unique<Logger>(std::move(OS), LFS, Reward,
                                   /*IncludeReward*/ true);
    return false;
  }

  std::unique_ptr<RegAllocEvictionAdvisor>
  getAdvisor(const MachineFunction &MF, const RAGreedy &RA) override {
    if (!Runner)
      return nullptr;
    if (Log)
      Log->switchContext(MF.getName());
    return std::make_unique<DevelopmentModeEvictAdvisor>(
        MF, RA, Runner.get(), getAnalysis<MachineBlockFrequencyInfo>(),
        getAnalysis<MachineLoopInfo>(), Log.get());
  }

  std::unique_ptr<MLModelRunner> Runner;
  std::unique_ptr<Logger> Log;
};

#endif // #ifdef LLVM_HAVE_TFLITE
} // namespace

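// Approximate the initial size of the allocation queue by counting the
// virtual registers that have any non-debug uses or defs; this becomes the
// denominator of the 'progress' feature.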
float MLEvictAdvisor::getInitialQueueSize(const MachineFunction &MF) {
  auto &MRI = MF.getRegInfo();
  float Ret = 0.0;
  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
    Register Reg = Register::index2VirtReg(I);
    if (MRI.reg_nodbg_empty(Reg))
      continue;
    ++Ret;
  }
  return Ret;
}

MLEvictAdvisor::MLEvictAdvisor(const MachineFunction &MF, const RAGreedy &RA,
                               MLModelRunner *Runner,
                               const MachineBlockFrequencyInfo &MBFI,
                               const MachineLoopInfo &Loops)
    : RegAllocEvictionAdvisor(MF, RA), DefaultAdvisor(MF, RA),
      Runner(std::move(Runner)), MBFI(MBFI), Loops(Loops),
      InitialQSize(MLEvictAdvisor::getInitialQueueSize(MF)) {
  assert(this->Runner);
  DoNotNormalize.set(FeatureIDs::mask);
  DoNotNormalize.set(FeatureIDs::is_free);
  DoNotNormalize.set(FeatureIDs::is_hint);
  DoNotNormalize.set(FeatureIDs::is_local);
  DoNotNormalize.set(FeatureIDs::min_stage);
  DoNotNormalize.set(FeatureIDs::max_stage);
  DoNotNormalize.set(FeatureIDs::progress);
}

int64_t MLEvictAdvisor::tryFindEvictionCandidatePosition(
    const LiveInterval &, const AllocationOrder &, unsigned, uint8_t,
    const SmallVirtRegSet &) const {
  int64_t Ret = Runner->evaluate<int64_t>();
  assert(Ret >= 0);
  assert(Ret <= CandidateVirtRegPos);
  return Ret;
}

bool MLEvictAdvisor::loadInterferenceFeatures(
    const LiveInterval &VirtReg, MCRegister PhysReg, bool IsHint,
    const SmallVirtRegSet &FixedRegisters,
    llvm::SmallVectorImpl<float> &Largest, size_t Pos,
    llvm::SmallVectorImpl<LRStartEndInfo> &LRPosInfo) const {
  // It is only possible to evict virtual register interference.
  if (Matrix->checkInterference(VirtReg, PhysReg) > LiveRegMatrix::IK_VirtReg) {
    // leave unavailable
    return false;
  }

  const bool IsLocal = LIS->intervalIsInOneMBB(VirtReg);
  int64_t LocalIntfs = 0;
  float NrUrgent = 0.0f;

  // The cascade tracking is the same as in the default advisor
  unsigned Cascade = RA.getExtraInfo().getCascadeOrCurrentNext(VirtReg.reg());

  SmallVector<const LiveInterval *, MaxInterferences> InterferingIntervals;
  for (MCRegUnitIterator Units(PhysReg, TRI); Units.isValid(); ++Units) {
    LiveIntervalUnion::Query &Q = Matrix->query(VirtReg, *Units);
    // Different from the default heuristic, we don't make any assumptions
    // about what exceeding the interference cutoff in the query may mean.
    const auto &IFIntervals = Q.interferingVRegs(EvictInterferenceCutoff);
    if (IFIntervals.empty() && InterferingIntervals.empty())
      continue;
    if (IFIntervals.size() >= EvictInterferenceCutoff)
      return false;
    InterferingIntervals.append(IFIntervals.begin(), IFIntervals.end());
    for (const LiveInterval *Intf : reverse(IFIntervals)) {
      assert(Intf->reg().isVirtual() &&
             "Only expecting virtual register interference from query");
      // This is the same set of legality checks as in the default case: don't
      // try to evict fixed regs or 'done' ones. Also don't break cascades,
      // except in the urgent case, with the same nuances used in the default
      // heuristic.
      // We could try sharing this between the advisors, but it may end up
      // more complex than it is right now.
      if (FixedRegisters.count(Intf->reg()))
        return false;
      if (RA.getExtraInfo().getStage(*Intf) == RS_Done)
        return false;
      bool Urgent =
          !VirtReg.isSpillable() &&
          (Intf->isSpillable() ||
           RegClassInfo.getNumAllocatableRegs(MRI->getRegClass(VirtReg.reg())) <
               RegClassInfo.getNumAllocatableRegs(
                   MRI->getRegClass(Intf->reg())));
      // Only evict older cascades or live ranges without a cascade.
      unsigned IntfCascade = RA.getExtraInfo().getCascade(Intf->reg());
      if (Cascade <= IntfCascade) {
        if (!Urgent)
          return false;
        ++NrUrgent;
      }

      LocalIntfs += (IsLocal && LIS->intervalIsInOneMBB(*Intf) &&
                     (!EnableLocalReassign || !canReassign(*Intf, PhysReg)));
    }
  }
  // If we made it this far, this position is an eviction candidate; load its
  // features.
  extractFeatures(InterferingIntervals, Largest, Pos, IsHint, LocalIntfs,
                  NrUrgent, LRPosInfo);
  return true;
}

MCRegister MLEvictAdvisor::tryFindEvictionCandidate(
    const LiveInterval &VirtReg, const AllocationOrder &Order,
    uint8_t CostPerUseLimit, const SmallVirtRegSet &FixedRegisters) const {
  auto MaybeOrderLimit = getOrderLimit(VirtReg, Order, CostPerUseLimit);
  if (!MaybeOrderLimit)
    return MCRegister::NoRegister;
  unsigned OrderLimit = *MaybeOrderLimit;

  // The heuristic sets initial costs such that, if CostPerUseLimit is
  // max<uint8_t>, then any of the costs of the legally-evictable intervals
  // would be lower. When that happens, one of those will be selected.
  // Therefore, we allow the candidate to be selected, unless the candidate is
  // unspillable, in which case it would be incorrect not to find a register
  // for it.
  const bool MustFindEviction =
      (!VirtReg.isSpillable() && CostPerUseLimit == static_cast<uint8_t>(~0u));
  // Number of available candidates - if 0, no need to continue.
  size_t Available = 0;
  // Make sure we don't have leftover partial state from an attempt where we
  // had no available candidates and bailed out early.
  resetInputs(*Runner);

  // Track the index->register mapping because AllocationOrder doesn't do that
  // and we'd have to scan it.
  // Also track their mask, to write asserts/debug.
  CandidateRegList Regs;
  Regs.fill({0, false});

  // Track the largest value of features seen during this eviction session. We
  // only normalize (some of) the float features, but it's just simpler to
  // dimension 'Largest' to all the features, especially since we have the
  // 'DoNotNormalize' list.
  FeaturesListNormalizer Largest(FeatureIDs::FeatureCount, 0.0);

  // Same overall idea as in the default eviction policy - we visit the values
  // of AllocationOrder one at a time. If it's not legally available, we mask
  // off the corresponding feature column (== do nothing, because we already
  // reset all the features to 0). Use Pos to capture the column we load
  // features at - in AllocationOrder order.
  size_t Pos = 0;
  SmallVector<LRStartEndInfo, NumberOfInterferences> LRPosInfo;
  for (auto I = Order.begin(), E = Order.getOrderLimitEnd(OrderLimit); I != E;
       ++I, ++Pos) {
    MCRegister PhysReg = *I;
    assert(!Regs[Pos].second);
    assert(PhysReg);
    if (!canAllocatePhysReg(CostPerUseLimit, PhysReg)) {
      continue;
    }
    if (loadInterferenceFeatures(VirtReg, PhysReg, I.isHint(), FixedRegisters,
                                 Largest, Pos, LRPosInfo)) {
      ++Available;
      Regs[Pos] = std::make_pair(PhysReg, true);
    }
  }
  if (Available == 0) {
    // Nothing to decide, nothing to learn.
    assert(!MustFindEviction);
    return MCRegister::NoRegister;
  }
  const size_t ValidPosLimit = Pos;
  // If we must find eviction, the candidate should be masked out of the
  // decision making process.
  Regs[CandidateVirtRegPos].second = !MustFindEviction;
  if (!MustFindEviction)
    extractFeatures(SmallVector<const LiveInterval *, 1>(1, &VirtReg), Largest,
                    CandidateVirtRegPos, /*IsHint*/ 0,
                    /*LocalIntfsCount*/ 0,
                    /*NrUrgent*/ 0.0, LRPosInfo);
  assert(InitialQSize > 0.0 && "We couldn't have gotten here if we had "
                               "nothing to allocate initially.");
#ifdef LLVM_HAVE_TFLITE
  if (EnableDevelopmentFeatures) {
    extractInstructionFeatures(
        LRPosInfo, Runner,
        [this](SlotIndex InputIndex) -> int {
          auto *CurrentMachineInstruction =
              LIS->getInstructionFromIndex(InputIndex);
          if (!CurrentMachineInstruction) {
            return -1;
          }
          return CurrentMachineInstruction->getOpcode();
        },
        [this](SlotIndex InputIndex) -> float {
          auto *CurrentMachineInstruction =
              LIS->getInstructionFromIndex(InputIndex);
          return MBFI.getBlockFreqRelativeToEntryBlock(
              CurrentMachineInstruction->getParent());
        },
        [this](SlotIndex InputIndex) -> MachineBasicBlock * {
          auto *CurrentMachineInstruction =
              LIS->getInstructionFromIndex(InputIndex);
          return CurrentMachineInstruction->getParent();
        },
        FeatureIDs::instructions, FeatureIDs::instructions_mapping,
        FeatureIDs::mbb_frequencies, FeatureIDs::mbb_mapping,
        LIS->getSlotIndexes()->getLastIndex());
  }
#endif // #ifdef LLVM_HAVE_TFLITE
  // Normalize the features.
  for (auto &V : Largest)
    V = V ? V : 1.0;
  for (size_t FeatureIndex = 0; FeatureIndex < FeatureIDs::FeatureCount;
       ++FeatureIndex) {
    if (DoNotNormalize.test(FeatureIndex))
      continue;
    for (size_t Pos = 0; Pos < NumberOfInterferences; ++Pos) {
      Runner->getTensor<float>(FeatureIndex)[Pos] /= Largest[FeatureIndex];
    }
  }
  *Runner->getTensor<float>(FeatureIDs::progress) =
      static_cast<float>(RA.getQueueSize()) / InitialQSize;

  // Get a decision.
  size_t CandidatePos = tryFindEvictionCandidatePosition(
      VirtReg, Order, OrderLimit, CostPerUseLimit, FixedRegisters);
  // The contract with the ML side is that CandidatePos is mask == 1 (i.e.
  // Regs[CandidatePos].second)
  assert(Regs[CandidatePos].second);
  if (CandidatePos == CandidateVirtRegPos) {
    assert(!MustFindEviction);
    return MCRegister::NoRegister;
  }
  assert(CandidatePos < ValidPosLimit);
  (void)ValidPosLimit;
  return Regs[CandidatePos].first;
}

const LIFeatureComponents &
MLEvictAdvisor::getLIFeatureComponents(const LiveInterval &LI) const {
  RegID ID = LI.reg().id();
  LIFeatureComponents Empty;
  auto I = CachedFeatures.insert(std::make_pair(ID, Empty));
  LIFeatureComponents &Ret = I.first->getSecond();
  if (!I.second)
    return Ret;

  SmallPtrSet<MachineInstr *, 8> Visited;
  const TargetRegisterInfo &TRI = *MF.getSubtarget().getRegisterInfo();

  for (MachineRegisterInfo::reg_instr_nodbg_iterator
           I = MRI->reg_instr_nodbg_begin(LI.reg()),
           E = MRI->reg_instr_nodbg_end();
       I != E;) {
    MachineInstr *MI = &*(I++);

    ++Ret.NrDefsAndUses;
    if (!Visited.insert(MI).second)
      continue;

    if (MI->isIdentityCopy() || MI->isImplicitDef())
      continue;

    bool Reads, Writes;
    std::tie(Reads, Writes) = MI->readsWritesVirtualRegister(LI.reg());

    float Freq = MBFI.getBlockFreqRelativeToEntryBlock(MI->getParent());
    Ret.HottestBlockFreq = std::max(Freq, Ret.HottestBlockFreq);

    Ret.R += (Reads && !Writes) * Freq;
    Ret.W += (!Reads && Writes) * Freq;
    Ret.RW += (Reads && Writes) * Freq;

    auto *MBB = MI->getParent();
    auto *Loop = Loops.getLoopFor(MBB);
    bool IsExiting = Loop ? Loop->isLoopExiting(MBB) : false;

    if (Writes && IsExiting && LIS->isLiveOutOfMBB(LI, MBB))
      Ret.IndVarUpdates += Freq;

    if (MI->isCopy() && VirtRegAuxInfo::copyHint(MI, LI.reg(), TRI, *MRI))
      Ret.HintWeights += Freq;
  }
  Ret.IsRemat = VirtRegAuxInfo::isRematerializable(
      LI, *LIS, *VRM, *MF.getSubtarget().getInstrInfo());
  return Ret;
}

// Overall, this currently mimics what we do for weight calculation, but
// instead of accumulating the various features, we keep them separate.
void MLEvictAdvisor::extractFeatures(
    const SmallVectorImpl<const LiveInterval *> &Intervals,
    llvm::SmallVectorImpl<float> &Largest, size_t Pos, int64_t IsHint,
    int64_t LocalIntfsCount, float NrUrgent,
    SmallVectorImpl<LRStartEndInfo> &LRPosInfo) const {
  int64_t NrDefsAndUses = 0;
  int64_t NrBrokenHints = 0;
  double R = 0.0;
  double W = 0.0;
  double RW = 0.0;
  double IndVarUpdates = 0.0;
  double HintWeights = 0.0;
  float StartBBFreq = 0.0;
  float EndBBFreq = 0.0;
  float HottestBlockFreq = 0.0;
  int32_t NrRematerializable = 0;
  float TotalWeight = 0.0;

  SlotIndex EndSI = LIS->getSlotIndexes()->getZeroIndex();
  SlotIndex StartSI = LIS->getSlotIndexes()->getLastIndex();
  int64_t MaxStage = 0;
  int64_t MinStage =
      Intervals.empty() ? 0 : std::numeric_limits<int64_t>::max();

  for (const auto *L : Intervals) {
    const LiveInterval &LI = *L;
    MaxStage = std::max<int64_t>(
        MaxStage, static_cast<int64_t>(RA.getExtraInfo().getStage(LI)));
    MinStage = std::min<int64_t>(
        MinStage, static_cast<int64_t>(RA.getExtraInfo().getStage(LI)));

    TotalWeight = std::max(TotalWeight, LI.weight());

    if (LI.beginIndex() < StartSI)
      StartSI = LI.beginIndex();

    if (LI.endIndex() > EndSI)
      EndSI = LI.endIndex();
    const LIFeatureComponents &LIFC = getLIFeatureComponents(LI);
    NrBrokenHints += VRM->hasPreferredPhys(LI.reg());

    NrDefsAndUses += LIFC.NrDefsAndUses;
    HottestBlockFreq = std::max(HottestBlockFreq, LIFC.HottestBlockFreq);
    R += LIFC.R;
    W += LIFC.W;
    RW += LIFC.RW;

    IndVarUpdates += LIFC.IndVarUpdates;

    HintWeights += LIFC.HintWeights;
    NrRematerializable += LIFC.IsRemat;

    if (EnableDevelopmentFeatures) {
      for (auto CurrentSegment : LI) {
        LRPosInfo.push_back(
            LRStartEndInfo{CurrentSegment.start, CurrentSegment.end, Pos});
      }
    }
  }
  size_t Size = 0;
  if (!Intervals.empty()) {
    StartBBFreq =
        MBFI.getBlockFreqRelativeToEntryBlock(LIS->getMBBFromIndex(StartSI));
    if (EndSI >= LIS->getSlotIndexes()->getLastIndex())
      EndSI = LIS->getSlotIndexes()->getLastIndex().getPrevIndex();
    EndBBFreq =
        MBFI.getBlockFreqRelativeToEntryBlock(LIS->getMBBFromIndex(EndSI));
    Size = StartSI.distance(EndSI);
  }
  // Set the features at the column 'Pos'.
#define SET(ID, TYPE, VAL)                                                     \
  do {                                                                         \
    Runner->getTensor<TYPE>(FeatureIDs::ID)[Pos] = static_cast<TYPE>(VAL);     \
    if (!DoNotNormalize.test(FeatureIDs::ID))                                  \
      Largest[FeatureIDs::ID] =                                                \
          std::max(Largest[FeatureIDs::ID], static_cast<float>(VAL));          \
  } while (false)
  SET(mask, int64_t, 1);
  SET(is_free, int64_t, Intervals.empty());
  SET(nr_urgent, float, NrUrgent);
  SET(nr_broken_hints, float, NrBrokenHints);
  SET(is_hint, int64_t, IsHint);
  SET(is_local, int64_t, LocalIntfsCount);
  SET(nr_rematerializable, float, NrRematerializable);
  SET(nr_defs_and_uses, float, NrDefsAndUses);
  SET(weighed_reads_by_max, float, R);
  SET(weighed_writes_by_max, float, W);
  SET(weighed_read_writes_by_max, float, RW);
  SET(weighed_indvars_by_max, float, IndVarUpdates);
  SET(hint_weights_by_max, float, HintWeights);
  SET(start_bb_freq_by_max, float, StartBBFreq);
  SET(end_bb_freq_by_max, float, EndBBFreq);
  SET(hottest_bb_freq_by_max, float, HottestBlockFreq);
  SET(liverange_size, float, Size);
  SET(use_def_density, float, TotalWeight);
  SET(max_stage, int64_t, MaxStage);
  SET(min_stage, int64_t, MinStage);
#undef SET
}

void extractInstructionFeatures(
    SmallVectorImpl<LRStartEndInfo> &LRPosInfo, MLModelRunner *RegallocRunner,
    function_ref<int(SlotIndex)> GetOpcode,
    function_ref<float(SlotIndex)> GetMBBFreq,
    function_ref<MachineBasicBlock *(SlotIndex)> GetMBBReference,
    const int InstructionsIndex, const int InstructionsMappingIndex,
    const int MBBFreqIndex, const int MBBMappingIndex,
    const SlotIndex LastIndex) {
  // This function extracts instruction based features relevant to the
  // eviction problem currently being solved. It ends up extracting four
  // tensors:
  // 1 - A vector of size max instruction count. It contains the opcodes of
  // the instructions spanned by all the intervals in the current instance of
  // the eviction problem.
  // 2 - A binary mapping matrix of size (LR count * max instruction count)
  // which maps where the LRs are live to the actual opcodes for which they
  // are live.
  // 3 - A vector of size max supported MBB count storing MBB frequencies,
  // encompassing all of the MBBs covered by the eviction problem.
  // 4 - A vector of size max instruction count of indices to members of the
  // MBB frequency vector, mapping each instruction to its associated MBB.

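  // A small worked example (hypothetical values, for illustration only):
  // given two segments where the LR at position 0 covers instructions 0-2
  // and the LR at position 1 covers instructions 2-3, with a max instruction
  // count of 4 we'd get:
  //   instructions         = [op0, op1, op2, op3]
  //   instructions_mapping = [1, 1, 1, 0,   <- row for position 0
  //                           0, 0, 1, 1]   <- row for position 1
  // i.e. row i of the flattened matrix marks the instructions LR i spans, so
  // overlapping segments share columns.
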
  // Start off by sorting the segments based on the beginning slot index.
  std::sort(
      LRPosInfo.begin(), LRPosInfo.end(),
      [](LRStartEndInfo A, LRStartEndInfo B) { return A.Begin < B.Begin; });
  size_t InstructionIndex = 0;
  size_t CurrentSegmentIndex = 0;
  SlotIndex CurrentIndex = LRPosInfo[0].Begin;
  std::map<MachineBasicBlock *, size_t> VisitedMBBs;
  size_t CurrentMBBIndex = 0;
  // This loop processes all the segments sequentially by starting at the
  // beginning slot index of the first segment, iterating through all the
  // slot indices before the end slot index of that segment (while checking
  // for overlaps with segments that start at greater slot indices). After
  // hitting that end index, the current segment being processed gets bumped
  // until all segments are processed or the max instruction count is hit,
  // where everything is just truncated.
  while (true) {
    // If the index that we are currently at is within the current segment and
    // we haven't hit the max instruction count, continue processing the
    // current segment.
    while (CurrentIndex <= LRPosInfo[CurrentSegmentIndex].End &&
           InstructionIndex < ModelMaxSupportedInstructionCount) {
      int CurrentOpcode = GetOpcode(CurrentIndex);
      // If the current machine instruction is null, skip it
      if (CurrentOpcode == -1) {
        // If we're currently at the last index in the SlotIndex analysis,
        // we can't go any further, so return from the function
        if (CurrentIndex >= LastIndex) {
          return;
        }
        CurrentIndex = CurrentIndex.getNextIndex();
        continue;
      }
      MachineBasicBlock *CurrentMBBReference = GetMBBReference(CurrentIndex);
      if (VisitedMBBs.count(CurrentMBBReference) == 0) {
        VisitedMBBs[CurrentMBBReference] = CurrentMBBIndex;
        ++CurrentMBBIndex;
      }
      extractMBBFrequency(CurrentIndex, InstructionIndex, VisitedMBBs,
                          GetMBBFreq, CurrentMBBReference, RegallocRunner,
                          MBBFreqIndex, MBBMappingIndex);
      // Current code assumes we're not going to get any disjoint segments.
      assert(LRPosInfo[CurrentSegmentIndex].Begin <= CurrentIndex);
      RegallocRunner->getTensor<int64_t>(InstructionsIndex)[InstructionIndex] =
          CurrentOpcode < OpcodeValueCutoff ? CurrentOpcode : 0;
      // set value in the binary mapping matrix for the current instruction
      auto CurrentSegmentPosition = LRPosInfo[CurrentSegmentIndex].Pos;
      RegallocRunner->getTensor<int64_t>(
          InstructionsMappingIndex)[CurrentSegmentPosition *
                                        ModelMaxSupportedInstructionCount +
                                    InstructionIndex] = 1;
      // All of the segments are sorted based on the beginning slot index, but
      // this doesn't mean that the beginning slot index of the next segment is
      // after the end segment of the one being currently processed. This while
      // loop checks for overlapping segments and modifies the portion of the
      // column in the mapping matrix for the currently processed instruction
      // for the LR it is checking. Also make sure that the beginning of the
      // current segment we're checking for overlap in is less than the current
      // index, otherwise we're done checking overlaps.
      size_t OverlapCheckCurrentSegment = CurrentSegmentIndex + 1;
      while (OverlapCheckCurrentSegment < LRPosInfo.size() &&
             LRPosInfo[OverlapCheckCurrentSegment].Begin <= CurrentIndex) {
        auto OverlapCurrentSegmentPosition =
            LRPosInfo[OverlapCheckCurrentSegment].Pos;
        if (LRPosInfo[OverlapCheckCurrentSegment].End >= CurrentIndex) {
          RegallocRunner->getTensor<int64_t>(
              InstructionsMappingIndex)[OverlapCurrentSegmentPosition *
                                            ModelMaxSupportedInstructionCount +
                                        InstructionIndex] = 1;
        }
        ++OverlapCheckCurrentSegment;
      }
      ++InstructionIndex;
      if (CurrentIndex >= LastIndex) {
        return;
      }
      CurrentIndex = CurrentIndex.getNextIndex();
    }
    // if we've just finished processing through the last segment or if we've
    // hit the maximum number of instructions, break out of the loop.
    if (CurrentSegmentIndex == LRPosInfo.size() - 1 ||
        InstructionIndex >= ModelMaxSupportedInstructionCount) {
      break;
    }
    // If the segments are not overlapping, we need to move to the beginning
    // index of the next segment to avoid having instructions not attached to
    // any register.
    if (LRPosInfo[CurrentSegmentIndex + 1].Begin >
        LRPosInfo[CurrentSegmentIndex].End) {
      CurrentIndex = LRPosInfo[CurrentSegmentIndex + 1].Begin;
    }
    ++CurrentSegmentIndex;
  }
}

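// Records, for the instruction at CurrentInstructionIndex, the frequency of
// its parent MBB and the instruction-to-MBB mapping, but only while the
// number of distinct MBBs seen stays below ModelMaxSupportedMBBCount; past
// that, additional MBBs are silently dropped.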
void extractMBBFrequency(const SlotIndex CurrentIndex,
                         const size_t CurrentInstructionIndex,
                         std::map<MachineBasicBlock *, size_t> &VisitedMBBs,
                         function_ref<float(SlotIndex)> GetMBBFreq,
                         MachineBasicBlock *CurrentMBBReference,
                         MLModelRunner *RegallocRunner, const int MBBFreqIndex,
                         const int MBBMappingIndex) {
  size_t CurrentMBBIndex = VisitedMBBs[CurrentMBBReference];
  float CurrentMBBFreq = GetMBBFreq(CurrentIndex);
  if (CurrentMBBIndex < ModelMaxSupportedMBBCount) {
    RegallocRunner->getTensor<float>(MBBFreqIndex)[CurrentMBBIndex] =
        CurrentMBBFreq;
    RegallocRunner->getTensor<int64_t>(
        MBBMappingIndex)[CurrentInstructionIndex] = CurrentMBBIndex;
  }
}

// Development mode-specific implementations
#ifdef LLVM_HAVE_TFLITE

RegAllocEvictionAdvisorAnalysis *llvm::createDevelopmentModeAdvisor() {
  return new DevelopmentModeEvictionAdvisorAnalysis();
}

int64_t DevelopmentModeEvictAdvisor::tryFindEvictionCandidatePosition(
    const LiveInterval &VirtReg, const AllocationOrder &Order,
    unsigned OrderLimit, uint8_t CostPerUseLimit,
    const SmallVirtRegSet &FixedRegisters) const {
  int64_t Ret = 0;
  if (isa<ModelUnderTrainingRunner>(getRunner())) {
    Ret = MLEvictAdvisor::tryFindEvictionCandidatePosition(
        VirtReg, Order, OrderLimit, CostPerUseLimit, FixedRegisters);
  } else {
    MCRegister PhysReg = getDefaultAdvisor().tryFindEvictionCandidate(
        VirtReg, Order, CostPerUseLimit, FixedRegisters);
    // Find the index of the selected PhysReg. We need it for logging;
    // otherwise these are wasted cycles (but so would be starting development
    // mode without a model or logging).
    if (!PhysReg)
      Ret = CandidateVirtRegPos;
    else
      for (auto I = Order.begin(), E = Order.getOrderLimitEnd(OrderLimit);
           I != E; ++I, ++Ret)
        if (*I == PhysReg)
          break;
  }
  if (TrainingLog.empty())
    return Ret;
  // TODO(mtrofin): when we support optional rewards, this can go away. In the
  // meantime, we log the "pretend" reward (0) for the previous observation
  // before starting a new one.
  if (Log->hasObservationInProgress())
    Log->logReward<float>(0.0);

  Log->startObservation();
  size_t CurrentFeature = 0;
  size_t FeatureCount = EnableDevelopmentFeatures
                            ? FeatureIDs::FeaturesWithDevelopmentCount
                            : FeatureIDs::FeatureCount;
  for (; CurrentFeature < FeatureCount; ++CurrentFeature) {
    Log->logTensorValue(CurrentFeature,
                        reinterpret_cast<const char *>(
                            getRunner().getTensorUntyped(CurrentFeature)));
  }
  if (auto *MUTR = dyn_cast<ModelUnderTrainingRunner>(&getRunner()))
    for (size_t I = 0; I < MUTR->extraOutputsForLoggingSpecs().size();
         ++I, ++CurrentFeature)
      Log->logTensorValue(
          CurrentFeature,
          reinterpret_cast<const char *>(MUTR->getUntypedExtraOutputValue(I)));
  // The output is right after the features and the extra outputs.
  Log->logTensorValue(CurrentFeature, reinterpret_cast<const char *>(&Ret));
  Log->endObservation();
  return Ret;
}

bool RegAllocScoring::runOnMachineFunction(MachineFunction &MF) {
  std::optional<float> CachedReward;
  auto GetReward = [&]() {
    if (!CachedReward)
      CachedReward = static_cast<float>(
          calculateRegAllocScore(MF, getAnalysis<MachineBlockFrequencyInfo>())
              .getScore());
    return *CachedReward;
  };

  getAnalysis<RegAllocEvictionAdvisorAnalysis>().logRewardIfNeeded(MF,
                                                                   GetReward);
  getAnalysis<RegAllocPriorityAdvisorAnalysis>().logRewardIfNeeded(MF,
                                                                   GetReward);
  return false;
}
#endif // #ifdef LLVM_HAVE_TFLITE

RegAllocEvictionAdvisorAnalysis *llvm::createReleaseModeAdvisor() {
  return new ReleaseModeEvictionAdvisorAnalysis();
}

// In all cases except development mode, we don't need scoring.
#if !defined(LLVM_HAVE_TFLITE)
bool RegAllocScoring::runOnMachineFunction(MachineFunction &) { return false; }
#endif