1//===- GCNRegPressure.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the GCNRegPressure class.
11///
12//===----------------------------------------------------------------------===//
13
14#include "GCNRegPressure.h"
15#include "llvm/CodeGen/RegisterPressure.h"
16
17using namespace llvm;
18
19#define DEBUG_TYPE "machine-scheduler"
20
21bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
22                   const GCNRPTracker::LiveRegSet &S2) {
23  if (S1.size() != S2.size())
24    return false;
25
26  for (const auto &P : S1) {
27    auto I = S2.find(P.first);
28    if (I == S2.end() || I->second != P.second)
29      return false;
30  }
31  return true;
32}
33
34
35///////////////////////////////////////////////////////////////////////////////
36// GCNRegPressure
37
38unsigned GCNRegPressure::getRegKind(Register Reg,
39                                    const MachineRegisterInfo &MRI) {
40  assert(Reg.isVirtual());
41  const auto RC = MRI.getRegClass(Reg);
42  auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
43  return STI->isSGPRClass(RC)
44             ? (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE)
45         : STI->isAGPRClass(RC)
46             ? (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE)
47             : (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
48}
49
50void GCNRegPressure::inc(unsigned Reg,
51                         LaneBitmask PrevMask,
52                         LaneBitmask NewMask,
53                         const MachineRegisterInfo &MRI) {
54  if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
55      SIRegisterInfo::getNumCoveredRegs(PrevMask))
56    return;
57
58  int Sign = 1;
59  if (NewMask < PrevMask) {
60    std::swap(NewMask, PrevMask);
61    Sign = -1;
62  }
63
64  switch (auto Kind = getRegKind(Reg, MRI)) {
65  case SGPR32:
66  case VGPR32:
67  case AGPR32:
68    Value[Kind] += Sign;
69    break;
70
71  case SGPR_TUPLE:
72  case VGPR_TUPLE:
73  case AGPR_TUPLE:
74    assert(PrevMask < NewMask);
75
76    Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
77      Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
78
79    if (PrevMask.none()) {
80      assert(NewMask.any());
81      Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
82    }
83    break;
84
85  default: llvm_unreachable("Unknown register kind");
86  }
87}
88
89bool GCNRegPressure::less(const GCNSubtarget &ST,
90                          const GCNRegPressure& O,
91                          unsigned MaxOccupancy) const {
92  const auto SGPROcc = std::min(MaxOccupancy,
93                                ST.getOccupancyWithNumSGPRs(getSGPRNum()));
94  const auto VGPROcc =
95    std::min(MaxOccupancy,
96             ST.getOccupancyWithNumVGPRs(getVGPRNum(ST.hasGFX90AInsts())));
97  const auto OtherSGPROcc = std::min(MaxOccupancy,
98                                ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
99  const auto OtherVGPROcc =
100    std::min(MaxOccupancy,
101             ST.getOccupancyWithNumVGPRs(O.getVGPRNum(ST.hasGFX90AInsts())));
102
103  const auto Occ = std::min(SGPROcc, VGPROcc);
104  const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
105  if (Occ != OtherOcc)
106    return Occ > OtherOcc;
107
108  bool SGPRImportant = SGPROcc < VGPROcc;
109  const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
110
111  // if both pressures disagree on what is more important compare vgprs
112  if (SGPRImportant != OtherSGPRImportant) {
113    SGPRImportant = false;
114  }
115
116  // compare large regs pressure
117  bool SGPRFirst = SGPRImportant;
118  for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
119    if (SGPRFirst) {
120      auto SW = getSGPRTuplesWeight();
121      auto OtherSW = O.getSGPRTuplesWeight();
122      if (SW != OtherSW)
123        return SW < OtherSW;
124    } else {
125      auto VW = getVGPRTuplesWeight();
126      auto OtherVW = O.getVGPRTuplesWeight();
127      if (VW != OtherVW)
128        return VW < OtherVW;
129    }
130  }
131  return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
132                         (getVGPRNum(ST.hasGFX90AInsts()) <
133                          O.getVGPRNum(ST.hasGFX90AInsts()));
134}
135
136#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
137LLVM_DUMP_METHOD
138Printable llvm::print(const GCNRegPressure &RP, const GCNSubtarget *ST) {
139  return Printable([&RP, ST](raw_ostream &OS) {
140    OS << "VGPRs: " << RP.Value[GCNRegPressure::VGPR32] << ' '
141       << "AGPRs: " << RP.getAGPRNum();
142    if (ST)
143      OS << "(O"
144         << ST->getOccupancyWithNumVGPRs(RP.getVGPRNum(ST->hasGFX90AInsts()))
145         << ')';
146    OS << ", SGPRs: " << RP.getSGPRNum();
147    if (ST)
148      OS << "(O" << ST->getOccupancyWithNumSGPRs(RP.getSGPRNum()) << ')';
149    OS << ", LVGPR WT: " << RP.getVGPRTuplesWeight()
150       << ", LSGPR WT: " << RP.getSGPRTuplesWeight();
151    if (ST)
152      OS << " -> Occ: " << RP.getOccupancy(*ST);
153    OS << '\n';
154  });
155}
156#endif
157
158static LaneBitmask getDefRegMask(const MachineOperand &MO,
159                                 const MachineRegisterInfo &MRI) {
160  assert(MO.isDef() && MO.isReg() && MO.getReg().isVirtual());
161
162  // We don't rely on read-undef flag because in case of tentative schedule
163  // tracking it isn't set correctly yet. This works correctly however since
164  // use mask has been tracked before using LIS.
165  return MO.getSubReg() == 0 ?
166    MRI.getMaxLaneMaskForVReg(MO.getReg()) :
167    MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
168}
169
170static LaneBitmask getUsedRegMask(const MachineOperand &MO,
171                                  const MachineRegisterInfo &MRI,
172                                  const LiveIntervals &LIS) {
173  assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
174
175  if (auto SubReg = MO.getSubReg())
176    return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
177
178  auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
179  if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
180    return MaxMask;
181
182  // For a tentative schedule LIS isn't updated yet but livemask should remain
183  // the same on any schedule. Subreg defs can be reordered but they all must
184  // dominate uses anyway.
185  auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
186  return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
187}
188
189static SmallVector<RegisterMaskPair, 8>
190collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
191                      const MachineRegisterInfo &MRI) {
192  SmallVector<RegisterMaskPair, 8> Res;
193  for (const auto &MO : MI.operands()) {
194    if (!MO.isReg() || !MO.getReg().isVirtual())
195      continue;
196    if (!MO.isUse() || !MO.readsReg())
197      continue;
198
199    auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
200
201    auto Reg = MO.getReg();
202    auto I = llvm::find_if(
203        Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
204    if (I != Res.end())
205      I->LaneMask |= UsedMask;
206    else
207      Res.push_back(RegisterMaskPair(Reg, UsedMask));
208  }
209  return Res;
210}
211
212///////////////////////////////////////////////////////////////////////////////
213// GCNRPTracker
214
215LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
216                                  SlotIndex SI,
217                                  const LiveIntervals &LIS,
218                                  const MachineRegisterInfo &MRI) {
219  LaneBitmask LiveMask;
220  const auto &LI = LIS.getInterval(Reg);
221  if (LI.hasSubRanges()) {
222    for (const auto &S : LI.subranges())
223      if (S.liveAt(SI)) {
224        LiveMask |= S.LaneMask;
225        assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
226               LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
227      }
228  } else if (LI.liveAt(SI)) {
229    LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
230  }
231  return LiveMask;
232}
233
234GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
235                                           const LiveIntervals &LIS,
236                                           const MachineRegisterInfo &MRI) {
237  GCNRPTracker::LiveRegSet LiveRegs;
238  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
239    auto Reg = Register::index2VirtReg(I);
240    if (!LIS.hasInterval(Reg))
241      continue;
242    auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
243    if (LiveMask.any())
244      LiveRegs[Reg] = LiveMask;
245  }
246  return LiveRegs;
247}
248
249void GCNRPTracker::reset(const MachineInstr &MI,
250                         const LiveRegSet *LiveRegsCopy,
251                         bool After) {
252  const MachineFunction &MF = *MI.getMF();
253  MRI = &MF.getRegInfo();
254  if (LiveRegsCopy) {
255    if (&LiveRegs != LiveRegsCopy)
256      LiveRegs = *LiveRegsCopy;
257  } else {
258    LiveRegs = After ? getLiveRegsAfter(MI, LIS)
259                     : getLiveRegsBefore(MI, LIS);
260  }
261
262  MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
263}
264
265void GCNUpwardRPTracker::reset(const MachineInstr &MI,
266                               const LiveRegSet *LiveRegsCopy) {
267  GCNRPTracker::reset(MI, LiveRegsCopy, true);
268}
269
270void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
271  assert(MRI && "call reset first");
272
273  LastTrackedMI = &MI;
274
275  if (MI.isDebugInstr())
276    return;
277
278  auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
279
280  // calc pressure at the MI (defs + uses)
281  auto AtMIPressure = CurPressure;
282  for (const auto &U : RegUses) {
283    auto LiveMask = LiveRegs[U.RegUnit];
284    AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
285  }
286  // update max pressure
287  MaxPressure = max(AtMIPressure, MaxPressure);
288
289  for (const auto &MO : MI.operands()) {
290    if (!MO.isReg() || !MO.isDef() || !MO.getReg().isVirtual() || MO.isDead())
291      continue;
292
293    auto Reg = MO.getReg();
294    auto I = LiveRegs.find(Reg);
295    if (I == LiveRegs.end())
296      continue;
297    auto &LiveMask = I->second;
298    auto PrevMask = LiveMask;
299    LiveMask &= ~getDefRegMask(MO, *MRI);
300    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
301    if (LiveMask.none())
302      LiveRegs.erase(I);
303  }
304  for (const auto &U : RegUses) {
305    auto &LiveMask = LiveRegs[U.RegUnit];
306    auto PrevMask = LiveMask;
307    LiveMask |= U.LaneMask;
308    CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
309  }
310  assert(CurPressure == getRegPressure(*MRI, LiveRegs));
311}
312
313bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
314                                 const LiveRegSet *LiveRegsCopy) {
315  MRI = &MI.getParent()->getParent()->getRegInfo();
316  LastTrackedMI = nullptr;
317  MBBEnd = MI.getParent()->end();
318  NextMI = &MI;
319  NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
320  if (NextMI == MBBEnd)
321    return false;
322  GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
323  return true;
324}
325
326bool GCNDownwardRPTracker::advanceBeforeNext() {
327  assert(MRI && "call reset first");
328  if (!LastTrackedMI)
329    return NextMI == MBBEnd;
330
331  assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
332
333  SlotIndex SI = NextMI == MBBEnd
334                     ? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
335                     : LIS.getInstructionIndex(*NextMI).getBaseIndex();
336  assert(SI.isValid());
337
338  // Remove dead registers or mask bits.
339  for (auto &It : LiveRegs) {
340    const LiveInterval &LI = LIS.getInterval(It.first);
341    if (LI.hasSubRanges()) {
342      for (const auto &S : LI.subranges()) {
343        if (!S.liveAt(SI)) {
344          auto PrevMask = It.second;
345          It.second &= ~S.LaneMask;
346          CurPressure.inc(It.first, PrevMask, It.second, *MRI);
347        }
348      }
349    } else if (!LI.liveAt(SI)) {
350      auto PrevMask = It.second;
351      It.second = LaneBitmask::getNone();
352      CurPressure.inc(It.first, PrevMask, It.second, *MRI);
353    }
354    if (It.second.none())
355      LiveRegs.erase(It.first);
356  }
357
358  MaxPressure = max(MaxPressure, CurPressure);
359
360  LastTrackedMI = nullptr;
361
362  return NextMI == MBBEnd;
363}
364
365void GCNDownwardRPTracker::advanceToNext() {
366  LastTrackedMI = &*NextMI++;
367  NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
368
369  // Add new registers or mask bits.
370  for (const auto &MO : LastTrackedMI->operands()) {
371    if (!MO.isReg() || !MO.isDef())
372      continue;
373    Register Reg = MO.getReg();
374    if (!Reg.isVirtual())
375      continue;
376    auto &LiveMask = LiveRegs[Reg];
377    auto PrevMask = LiveMask;
378    LiveMask |= getDefRegMask(MO, *MRI);
379    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
380  }
381
382  MaxPressure = max(MaxPressure, CurPressure);
383}
384
385bool GCNDownwardRPTracker::advance() {
386  if (NextMI == MBBEnd)
387    return false;
388  advanceBeforeNext();
389  advanceToNext();
390  return true;
391}
392
393bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
394  while (NextMI != End)
395    if (!advance()) return false;
396  return true;
397}
398
399bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
400                                   MachineBasicBlock::const_iterator End,
401                                   const LiveRegSet *LiveRegsCopy) {
402  reset(*Begin, LiveRegsCopy);
403  return advance(End);
404}
405
406#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
407LLVM_DUMP_METHOD
408Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
409                               const GCNRPTracker::LiveRegSet &TrackedLR,
410                               const TargetRegisterInfo *TRI) {
411  return Printable([&LISLR, &TrackedLR, TRI](raw_ostream &OS) {
412    for (auto const &P : TrackedLR) {
413      auto I = LISLR.find(P.first);
414      if (I == LISLR.end()) {
415        OS << "  " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
416           << " isn't found in LIS reported set\n";
417      } else if (I->second != P.second) {
418        OS << "  " << printReg(P.first, TRI)
419           << " masks doesn't match: LIS reported " << PrintLaneMask(I->second)
420           << ", tracked " << PrintLaneMask(P.second) << '\n';
421      }
422    }
423    for (auto const &P : LISLR) {
424      auto I = TrackedLR.find(P.first);
425      if (I == TrackedLR.end()) {
426        OS << "  " << printReg(P.first, TRI) << ":L" << PrintLaneMask(P.second)
427           << " isn't found in tracked set\n";
428      }
429    }
430  });
431}
432
433bool GCNUpwardRPTracker::isValid() const {
434  const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
435  const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
436  const auto &TrackedLR = LiveRegs;
437
438  if (!isEqual(LISLR, TrackedLR)) {
439    dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
440              " LIS reported livesets mismatch:\n"
441           << print(LISLR, *MRI);
442    reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
443    return false;
444  }
445
446  auto LISPressure = getRegPressure(*MRI, LISLR);
447  if (LISPressure != CurPressure) {
448    dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: "
449           << print(CurPressure) << "LIS rpt: " << print(LISPressure);
450    return false;
451  }
452  return true;
453}
454
455LLVM_DUMP_METHOD
456Printable llvm::print(const GCNRPTracker::LiveRegSet &LiveRegs,
457                      const MachineRegisterInfo &MRI) {
458  return Printable([&LiveRegs, &MRI](raw_ostream &OS) {
459    const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
460    for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
461      Register Reg = Register::index2VirtReg(I);
462      auto It = LiveRegs.find(Reg);
463      if (It != LiveRegs.end() && It->second.any())
464        OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
465           << PrintLaneMask(It->second);
466    }
467    OS << '\n';
468  });
469}
470
471LLVM_DUMP_METHOD
472void GCNRegPressure::dump() const { dbgs() << print(*this); }
473
474#endif
475