1//===- RegisterPressure.h - Dynamic Register Pressure -----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the RegisterPressure class which can be used to track
10// MachineInstr level register pressure.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef LLVM_CODEGEN_REGISTERPRESSURE_H
15#define LLVM_CODEGEN_REGISTERPRESSURE_H
16
17#include "llvm/ADT/ArrayRef.h"
18#include "llvm/ADT/SmallVector.h"
19#include "llvm/ADT/SparseSet.h"
20#include "llvm/CodeGen/MachineBasicBlock.h"
21#include "llvm/CodeGen/SlotIndexes.h"
22#include "llvm/CodeGen/TargetRegisterInfo.h"
23#include "llvm/MC/LaneBitmask.h"
24#include <cassert>
25#include <cstddef>
26#include <cstdint>
27#include <cstdlib>
28#include <limits>
29#include <vector>
30
31namespace llvm {
32
33class LiveIntervals;
34class MachineFunction;
35class MachineInstr;
36class MachineRegisterInfo;
37class RegisterClassInfo;
38
39struct RegisterMaskPair {
40  unsigned RegUnit; ///< Virtual register or register unit.
41  LaneBitmask LaneMask;
42
43  RegisterMaskPair(unsigned RegUnit, LaneBitmask LaneMask)
44      : RegUnit(RegUnit), LaneMask(LaneMask) {}
45};
46
47/// Base class for register pressure results.
48struct RegisterPressure {
49  /// Map of max reg pressure indexed by pressure set ID, not class ID.
50  std::vector<unsigned> MaxSetPressure;
51
52  /// List of live in virtual registers or physical register units.
53  SmallVector<RegisterMaskPair,8> LiveInRegs;
54  SmallVector<RegisterMaskPair,8> LiveOutRegs;
55
56  void dump(const TargetRegisterInfo *TRI) const;
57};
58
59/// RegisterPressure computed within a region of instructions delimited by
60/// TopIdx and BottomIdx.  During pressure computation, the maximum pressure per
61/// register pressure set is increased. Once pressure within a region is fully
62/// computed, the live-in and live-out sets are recorded.
63///
64/// This is preferable to RegionPressure when LiveIntervals are available,
65/// because delimiting regions by SlotIndex is more robust and convenient than
66/// holding block iterators. The block contents can change without invalidating
67/// the pressure result.
68struct IntervalPressure : RegisterPressure {
69  /// Record the boundary of the region being tracked.
70  SlotIndex TopIdx;
71  SlotIndex BottomIdx;
72
73  void reset();
74
75  void openTop(SlotIndex NextTop);
76
77  void openBottom(SlotIndex PrevBottom);
78};
79
80/// RegisterPressure computed within a region of instructions delimited by
81/// TopPos and BottomPos. This is a less precise version of IntervalPressure for
82/// use when LiveIntervals are unavailable.
83struct RegionPressure : RegisterPressure {
84  /// Record the boundary of the region being tracked.
85  MachineBasicBlock::const_iterator TopPos;
86  MachineBasicBlock::const_iterator BottomPos;
87
88  void reset();
89
90  void openTop(MachineBasicBlock::const_iterator PrevTop);
91
92  void openBottom(MachineBasicBlock::const_iterator PrevBottom);
93};
94
/// Capture a change in pressure for a single pressure set. UnitInc may be
/// expressed in terms of upward or downward pressure depending on the client
/// and will be dynamically adjusted for current liveness.
///
/// Pressure increments are tiny, typically 1-2 units, and this is only for
/// heuristics, so we don't check UnitInc overflow. Instead, we may have a
/// higher level assert that pressure is consistent within a region. We also
/// effectively ignore dead defs which don't affect heuristics much.
class PressureChange {
  /// Pressure set ID biased by one, so that 0 can encode "invalid".
  uint16_t PSetID = 0;
  /// Signed change in register units (silently truncated to 16 bits).
  int16_t UnitInc = 0;

public:
  /// An invalid change: no pressure set, zero increment.
  PressureChange() = default;

  /// A zero-increment change for pressure set \p ID; the biased ID must
  /// still fit in 16 bits.
  PressureChange(unsigned ID) : PSetID(ID + 1) {
    assert(ID < std::numeric_limits<uint16_t>::max() && "PSetID overflow.");
  }

  /// True when this change names a real pressure set.
  bool isValid() const { return PSetID != 0; }

  /// The pressure set ID; only meaningful when isValid().
  unsigned getPSet() const {
    assert(isValid() && "invalid PressureChange");
    unsigned Biased = PSetID;
    return Biased - 1;
  }

  /// The pressure set ID, or UINT16_MAX for an invalid change so that it gets
  /// the lowest priority.
  unsigned getPSetOrMax() const {
    // Undo the +1 bias; an invalid ID (0) wraps around to UINT16_MAX.
    return static_cast<uint16_t>(PSetID - 1);
  }

  int getUnitInc() const { return UnitInc; }

  /// Store \p Inc; values outside the int16_t range are not checked.
  void setUnitInc(int Inc) { UnitInc = Inc; }

  bool operator==(const PressureChange &RHS) const {
    if (PSetID != RHS.PSetID)
      return false;
    return UnitInc == RHS.UnitInc;
  }

  void dump() const;
};
135
136/// List of PressureChanges in order of increasing, unique PSetID.
137///
138/// Use a small fixed number, because we can fit more PressureChanges in an
139/// empty SmallVector than ever need to be tracked per register class. If more
140/// PSets are affected, then we only track the most constrained.
141class PressureDiff {
142  // The initial design was for MaxPSets=4, but that requires PSet partitions,
143  // which are not yet implemented. (PSet partitions are equivalent PSets given
144  // the register classes actually in use within the scheduling region.)
145  enum { MaxPSets = 16 };
146
147  PressureChange PressureChanges[MaxPSets];
148
149  using iterator = PressureChange *;
150
151  iterator nonconst_begin() { return &PressureChanges[0]; }
152  iterator nonconst_end() { return &PressureChanges[MaxPSets]; }
153
154public:
155  using const_iterator = const PressureChange *;
156
157  const_iterator begin() const { return &PressureChanges[0]; }
158  const_iterator end() const { return &PressureChanges[MaxPSets]; }
159
160  void addPressureChange(unsigned RegUnit, bool IsDec,
161                         const MachineRegisterInfo *MRI);
162
163  void dump(const TargetRegisterInfo &TRI) const;
164};
165
166/// List of registers defined and used by a machine instruction.
167class RegisterOperands {
168public:
169  /// List of virtual registers and register units read by the instruction.
170  SmallVector<RegisterMaskPair, 8> Uses;
171  /// List of virtual registers and register units defined by the
172  /// instruction which are not dead.
173  SmallVector<RegisterMaskPair, 8> Defs;
174  /// List of virtual registers and register units defined by the
175  /// instruction but dead.
176  SmallVector<RegisterMaskPair, 8> DeadDefs;
177
178  /// Analyze the given instruction \p MI and fill in the Uses, Defs and
179  /// DeadDefs list based on the MachineOperand flags.
180  void collect(const MachineInstr &MI, const TargetRegisterInfo &TRI,
181               const MachineRegisterInfo &MRI, bool TrackLaneMasks,
182               bool IgnoreDead);
183
184  /// Use liveness information to find dead defs not marked with a dead flag
185  /// and move them to the DeadDefs vector.
186  void detectDeadDefs(const MachineInstr &MI, const LiveIntervals &LIS);
187
188  /// Use liveness information to find out which uses/defs are partially
189  /// undefined/dead and adjust the RegisterMaskPairs accordingly.
190  /// If \p AddFlagsMI is given then missing read-undef and dead flags will be
191  /// added to the instruction.
192  void adjustLaneLiveness(const LiveIntervals &LIS,
193                          const MachineRegisterInfo &MRI, SlotIndex Pos,
194                          MachineInstr *AddFlagsMI = nullptr);
195};
196
197/// Array of PressureDiffs.
198class PressureDiffs {
199  PressureDiff *PDiffArray = nullptr;
200  unsigned Size = 0;
201  unsigned Max = 0;
202
203public:
204  PressureDiffs() = default;
205  ~PressureDiffs() { free(PDiffArray); }
206
207  void clear() { Size = 0; }
208
209  void init(unsigned N);
210
211  PressureDiff &operator[](unsigned Idx) {
212    assert(Idx < Size && "PressureDiff index out of bounds");
213    return PDiffArray[Idx];
214  }
215  const PressureDiff &operator[](unsigned Idx) const {
216    return const_cast<PressureDiffs*>(this)->operator[](Idx);
217  }
218
219  /// Record pressure difference induced by the given operand list to
220  /// node with index \p Idx.
221  void addInstruction(unsigned Idx, const RegisterOperands &RegOpers,
222                      const MachineRegisterInfo &MRI);
223};
224
225/// Store the effects of a change in pressure on things that MI scheduler cares
226/// about.
227///
228/// Excess records the value of the largest difference in register units beyond
229/// the target's pressure limits across the affected pressure sets, where
230/// largest is defined as the absolute value of the difference. Negative
231/// ExcessUnits indicates a reduction in pressure that had already exceeded the
232/// target's limits.
233///
234/// CriticalMax records the largest increase in the tracker's max pressure that
235/// exceeds the critical limit for some pressure set determined by the client.
236///
237/// CurrentMax records the largest increase in the tracker's max pressure that
238/// exceeds the current limit for some pressure set determined by the client.
239struct RegPressureDelta {
240  PressureChange Excess;
241  PressureChange CriticalMax;
242  PressureChange CurrentMax;
243
244  RegPressureDelta() = default;
245
246  bool operator==(const RegPressureDelta &RHS) const {
247    return Excess == RHS.Excess && CriticalMax == RHS.CriticalMax
248      && CurrentMax == RHS.CurrentMax;
249  }
250  bool operator!=(const RegPressureDelta &RHS) const {
251    return !operator==(RHS);
252  }
253  void dump() const;
254};
255
256/// A set of live virtual registers and physical register units.
257///
258/// This is a wrapper around a SparseSet which deals with mapping register unit
259/// and virtual register indexes to an index usable by the sparse set.
260class LiveRegSet {
261private:
262  struct IndexMaskPair {
263    unsigned Index;
264    LaneBitmask LaneMask;
265
266    IndexMaskPair(unsigned Index, LaneBitmask LaneMask)
267        : Index(Index), LaneMask(LaneMask) {}
268
269    unsigned getSparseSetIndex() const {
270      return Index;
271    }
272  };
273
274  using RegSet = SparseSet<IndexMaskPair>;
275  RegSet Regs;
276  unsigned NumRegUnits;
277
278  unsigned getSparseIndexFromReg(unsigned Reg) const {
279    if (Register::isVirtualRegister(Reg))
280      return Register::virtReg2Index(Reg) + NumRegUnits;
281    assert(Reg < NumRegUnits);
282    return Reg;
283  }
284
285  unsigned getRegFromSparseIndex(unsigned SparseIndex) const {
286    if (SparseIndex >= NumRegUnits)
287      return Register::index2VirtReg(SparseIndex-NumRegUnits);
288    return SparseIndex;
289  }
290
291public:
292  void clear();
293  void init(const MachineRegisterInfo &MRI);
294
295  LaneBitmask contains(unsigned Reg) const {
296    unsigned SparseIndex = getSparseIndexFromReg(Reg);
297    RegSet::const_iterator I = Regs.find(SparseIndex);
298    if (I == Regs.end())
299      return LaneBitmask::getNone();
300    return I->LaneMask;
301  }
302
303  /// Mark the \p Pair.LaneMask lanes of \p Pair.Reg as live.
304  /// Returns the previously live lanes of \p Pair.Reg.
305  LaneBitmask insert(RegisterMaskPair Pair) {
306    unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
307    auto InsertRes = Regs.insert(IndexMaskPair(SparseIndex, Pair.LaneMask));
308    if (!InsertRes.second) {
309      LaneBitmask PrevMask = InsertRes.first->LaneMask;
310      InsertRes.first->LaneMask |= Pair.LaneMask;
311      return PrevMask;
312    }
313    return LaneBitmask::getNone();
314  }
315
316  /// Clears the \p Pair.LaneMask lanes of \p Pair.Reg (mark them as dead).
317  /// Returns the previously live lanes of \p Pair.Reg.
318  LaneBitmask erase(RegisterMaskPair Pair) {
319    unsigned SparseIndex = getSparseIndexFromReg(Pair.RegUnit);
320    RegSet::iterator I = Regs.find(SparseIndex);
321    if (I == Regs.end())
322      return LaneBitmask::getNone();
323    LaneBitmask PrevMask = I->LaneMask;
324    I->LaneMask &= ~Pair.LaneMask;
325    return PrevMask;
326  }
327
328  size_t size() const {
329    return Regs.size();
330  }
331
332  template<typename ContainerT>
333  void appendTo(ContainerT &To) const {
334    for (const IndexMaskPair &P : Regs) {
335      unsigned Reg = getRegFromSparseIndex(P.Index);
336      if (P.LaneMask.any())
337        To.push_back(RegisterMaskPair(Reg, P.LaneMask));
338    }
339  }
340};
341
342/// Track the current register pressure at some position in the instruction
343/// stream, and remember the high water mark within the region traversed. This
344/// does not automatically consider live-through ranges. The client may
345/// independently adjust for global liveness.
346///
347/// Each RegPressureTracker only works within a MachineBasicBlock. Pressure can
348/// be tracked across a larger region by storing a RegisterPressure result at
349/// each block boundary and explicitly adjusting pressure to account for block
350/// live-in and live-out register sets.
351///
352/// RegPressureTracker holds a reference to a RegisterPressure result that it
353/// computes incrementally. During downward tracking, P.BottomIdx or P.BottomPos
354/// is invalid until it reaches the end of the block or closeRegion() is
355/// explicitly called. Similarly, P.TopIdx is invalid during upward
356/// tracking. Changing direction has the side effect of closing region, and
357/// traversing past TopIdx or BottomIdx reopens it.
358class RegPressureTracker {
359  const MachineFunction *MF = nullptr;
360  const TargetRegisterInfo *TRI = nullptr;
361  const RegisterClassInfo *RCI = nullptr;
362  const MachineRegisterInfo *MRI;
363  const LiveIntervals *LIS = nullptr;
364
365  /// We currently only allow pressure tracking within a block.
366  const MachineBasicBlock *MBB = nullptr;
367
368  /// Track the max pressure within the region traversed so far.
369  RegisterPressure &P;
370
371  /// Run in two modes dependending on whether constructed with IntervalPressure
372  /// or RegisterPressure. If requireIntervals is false, LIS are ignored.
373  bool RequireIntervals;
374
375  /// True if UntiedDefs will be populated.
376  bool TrackUntiedDefs = false;
377
378  /// True if lanemasks should be tracked.
379  bool TrackLaneMasks = false;
380
381  /// Register pressure corresponds to liveness before this instruction
382  /// iterator. It may point to the end of the block or a DebugValue rather than
383  /// an instruction.
384  MachineBasicBlock::const_iterator CurrPos;
385
386  /// Pressure map indexed by pressure set ID, not class ID.
387  std::vector<unsigned> CurrSetPressure;
388
389  /// Set of live registers.
390  LiveRegSet LiveRegs;
391
392  /// Set of vreg defs that start a live range.
393  SparseSet<unsigned, VirtReg2IndexFunctor> UntiedDefs;
394  /// Live-through pressure.
395  std::vector<unsigned> LiveThruPressure;
396
397public:
398  RegPressureTracker(IntervalPressure &rp) : P(rp), RequireIntervals(true) {}
399  RegPressureTracker(RegionPressure &rp) : P(rp), RequireIntervals(false) {}
400
401  void reset();
402
403  void init(const MachineFunction *mf, const RegisterClassInfo *rci,
404            const LiveIntervals *lis, const MachineBasicBlock *mbb,
405            MachineBasicBlock::const_iterator pos,
406            bool TrackLaneMasks, bool TrackUntiedDefs);
407
408  /// Force liveness of virtual registers or physical register
409  /// units. Particularly useful to initialize the livein/out state of the
410  /// tracker before the first call to advance/recede.
411  void addLiveRegs(ArrayRef<RegisterMaskPair> Regs);
412
413  /// Get the MI position corresponding to this register pressure.
414  MachineBasicBlock::const_iterator getPos() const { return CurrPos; }
415
416  // Reset the MI position corresponding to the register pressure. This allows
417  // schedulers to move instructions above the RegPressureTracker's
418  // CurrPos. Since the pressure is computed before CurrPos, the iterator
419  // position changes while pressure does not.
420  void setPos(MachineBasicBlock::const_iterator Pos) { CurrPos = Pos; }
421
422  /// Recede across the previous instruction.
423  void recede(SmallVectorImpl<RegisterMaskPair> *LiveUses = nullptr);
424
425  /// Recede across the previous instruction.
426  /// This "low-level" variant assumes that recedeSkipDebugValues() was
427  /// called previously and takes precomputed RegisterOperands for the
428  /// instruction.
429  void recede(const RegisterOperands &RegOpers,
430              SmallVectorImpl<RegisterMaskPair> *LiveUses = nullptr);
431
432  /// Recede until we find an instruction which is not a DebugValue.
433  void recedeSkipDebugValues();
434
435  /// Advance across the current instruction.
436  void advance();
437
438  /// Advance across the current instruction.
439  /// This is a "low-level" variant of advance() which takes precomputed
440  /// RegisterOperands of the instruction.
441  void advance(const RegisterOperands &RegOpers);
442
443  /// Finalize the region boundaries and recored live ins and live outs.
444  void closeRegion();
445
446  /// Initialize the LiveThru pressure set based on the untied defs found in
447  /// RPTracker.
448  void initLiveThru(const RegPressureTracker &RPTracker);
449
450  /// Copy an existing live thru pressure result.
451  void initLiveThru(ArrayRef<unsigned> PressureSet) {
452    LiveThruPressure.assign(PressureSet.begin(), PressureSet.end());
453  }
454
455  ArrayRef<unsigned> getLiveThru() const { return LiveThruPressure; }
456
457  /// Get the resulting register pressure over the traversed region.
458  /// This result is complete if closeRegion() was explicitly invoked.
459  RegisterPressure &getPressure() { return P; }
460  const RegisterPressure &getPressure() const { return P; }
461
462  /// Get the register set pressure at the current position, which may be less
463  /// than the pressure across the traversed region.
464  const std::vector<unsigned> &getRegSetPressureAtPos() const {
465    return CurrSetPressure;
466  }
467
468  bool isTopClosed() const;
469  bool isBottomClosed() const;
470
471  void closeTop();
472  void closeBottom();
473
474  /// Consider the pressure increase caused by traversing this instruction
475  /// bottom-up. Find the pressure set with the most change beyond its pressure
476  /// limit based on the tracker's current pressure, and record the number of
477  /// excess register units of that pressure set introduced by this instruction.
478  void getMaxUpwardPressureDelta(const MachineInstr *MI,
479                                 PressureDiff *PDiff,
480                                 RegPressureDelta &Delta,
481                                 ArrayRef<PressureChange> CriticalPSets,
482                                 ArrayRef<unsigned> MaxPressureLimit);
483
484  void getUpwardPressureDelta(const MachineInstr *MI,
485                              /*const*/ PressureDiff &PDiff,
486                              RegPressureDelta &Delta,
487                              ArrayRef<PressureChange> CriticalPSets,
488                              ArrayRef<unsigned> MaxPressureLimit) const;
489
490  /// Consider the pressure increase caused by traversing this instruction
491  /// top-down. Find the pressure set with the most change beyond its pressure
492  /// limit based on the tracker's current pressure, and record the number of
493  /// excess register units of that pressure set introduced by this instruction.
494  void getMaxDownwardPressureDelta(const MachineInstr *MI,
495                                   RegPressureDelta &Delta,
496                                   ArrayRef<PressureChange> CriticalPSets,
497                                   ArrayRef<unsigned> MaxPressureLimit);
498
499  /// Find the pressure set with the most change beyond its pressure limit after
500  /// traversing this instruction either upward or downward depending on the
501  /// closed end of the current region.
502  void getMaxPressureDelta(const MachineInstr *MI,
503                           RegPressureDelta &Delta,
504                           ArrayRef<PressureChange> CriticalPSets,
505                           ArrayRef<unsigned> MaxPressureLimit) {
506    if (isTopClosed())
507      return getMaxDownwardPressureDelta(MI, Delta, CriticalPSets,
508                                         MaxPressureLimit);
509
510    assert(isBottomClosed() && "Uninitialized pressure tracker");
511    return getMaxUpwardPressureDelta(MI, nullptr, Delta, CriticalPSets,
512                                     MaxPressureLimit);
513  }
514
515  /// Get the pressure of each PSet after traversing this instruction bottom-up.
516  void getUpwardPressure(const MachineInstr *MI,
517                         std::vector<unsigned> &PressureResult,
518                         std::vector<unsigned> &MaxPressureResult);
519
520  /// Get the pressure of each PSet after traversing this instruction top-down.
521  void getDownwardPressure(const MachineInstr *MI,
522                           std::vector<unsigned> &PressureResult,
523                           std::vector<unsigned> &MaxPressureResult);
524
525  void getPressureAfterInst(const MachineInstr *MI,
526                            std::vector<unsigned> &PressureResult,
527                            std::vector<unsigned> &MaxPressureResult) {
528    if (isTopClosed())
529      return getUpwardPressure(MI, PressureResult, MaxPressureResult);
530
531    assert(isBottomClosed() && "Uninitialized pressure tracker");
532    return getDownwardPressure(MI, PressureResult, MaxPressureResult);
533  }
534
535  bool hasUntiedDef(unsigned VirtReg) const {
536    return UntiedDefs.count(VirtReg);
537  }
538
539  void dump() const;
540
541protected:
542  /// Add Reg to the live out set and increase max pressure.
543  void discoverLiveOut(RegisterMaskPair Pair);
544  /// Add Reg to the live in set and increase max pressure.
545  void discoverLiveIn(RegisterMaskPair Pair);
546
547  /// Get the SlotIndex for the first nondebug instruction including or
548  /// after the current position.
549  SlotIndex getCurrSlot() const;
550
551  void increaseRegPressure(unsigned RegUnit, LaneBitmask PreviousMask,
552                           LaneBitmask NewMask);
553  void decreaseRegPressure(unsigned RegUnit, LaneBitmask PreviousMask,
554                           LaneBitmask NewMask);
555
556  void bumpDeadDefs(ArrayRef<RegisterMaskPair> DeadDefs);
557
558  void bumpUpwardPressure(const MachineInstr *MI);
559  void bumpDownwardPressure(const MachineInstr *MI);
560
561  void discoverLiveInOrOut(RegisterMaskPair Pair,
562                           SmallVectorImpl<RegisterMaskPair> &LiveInOrOut);
563
564  LaneBitmask getLastUsedLanes(unsigned RegUnit, SlotIndex Pos) const;
565  LaneBitmask getLiveLanesAt(unsigned RegUnit, SlotIndex Pos) const;
566  LaneBitmask getLiveThroughAt(unsigned RegUnit, SlotIndex Pos) const;
567};
568
569void dumpRegSetPressure(ArrayRef<unsigned> SetPressure,
570                        const TargetRegisterInfo *TRI);
571
572} // end namespace llvm
573
574#endif // LLVM_CODEGEN_REGISTERPRESSURE_H
575