1//===- GCNRegPressure.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "GCNRegPressure.h"
10#include "AMDGPUSubtarget.h"
11#include "SIRegisterInfo.h"
12#include "llvm/ADT/SmallVector.h"
13#include "llvm/CodeGen/LiveInterval.h"
14#include "llvm/CodeGen/LiveIntervals.h"
15#include "llvm/CodeGen/MachineInstr.h"
16#include "llvm/CodeGen/MachineOperand.h"
17#include "llvm/CodeGen/MachineRegisterInfo.h"
18#include "llvm/CodeGen/RegisterPressure.h"
19#include "llvm/CodeGen/SlotIndexes.h"
20#include "llvm/CodeGen/TargetRegisterInfo.h"
21#include "llvm/Config/llvm-config.h"
22#include "llvm/MC/LaneBitmask.h"
23#include "llvm/Support/Compiler.h"
24#include "llvm/Support/Debug.h"
25#include "llvm/Support/ErrorHandling.h"
26#include "llvm/Support/raw_ostream.h"
27#include <algorithm>
28#include <cassert>
29
30using namespace llvm;
31
32#define DEBUG_TYPE "machine-scheduler"
33
34#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
35LLVM_DUMP_METHOD
36void llvm::printLivesAt(SlotIndex SI,
37                        const LiveIntervals &LIS,
38                        const MachineRegisterInfo &MRI) {
39  dbgs() << "Live regs at " << SI << ": "
40         << *LIS.getInstructionFromIndex(SI);
41  unsigned Num = 0;
42  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
43    const unsigned Reg = Register::index2VirtReg(I);
44    if (!LIS.hasInterval(Reg))
45      continue;
46    const auto &LI = LIS.getInterval(Reg);
47    if (LI.hasSubRanges()) {
48      bool firstTime = true;
49      for (const auto &S : LI.subranges()) {
50        if (!S.liveAt(SI)) continue;
51        if (firstTime) {
52          dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
53                 << '\n';
54          firstTime = false;
55        }
56        dbgs() << "  " << S << '\n';
57        ++Num;
58      }
59    } else if (LI.liveAt(SI)) {
60      dbgs() << "  " << LI << '\n';
61      ++Num;
62    }
63  }
64  if (!Num) dbgs() << "  <none>\n";
65}
66#endif
67
68bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
69                   const GCNRPTracker::LiveRegSet &S2) {
70  if (S1.size() != S2.size())
71    return false;
72
73  for (const auto &P : S1) {
74    auto I = S2.find(P.first);
75    if (I == S2.end() || I->second != P.second)
76      return false;
77  }
78  return true;
79}
80
81
82///////////////////////////////////////////////////////////////////////////////
83// GCNRegPressure
84
85unsigned GCNRegPressure::getRegKind(unsigned Reg,
86                                    const MachineRegisterInfo &MRI) {
87  assert(Register::isVirtualRegister(Reg));
88  const auto RC = MRI.getRegClass(Reg);
89  auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
90  return STI->isSGPRClass(RC) ?
91    (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
92    STI->hasAGPRs(RC) ?
93      (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
94      (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
95}
96
97void GCNRegPressure::inc(unsigned Reg,
98                         LaneBitmask PrevMask,
99                         LaneBitmask NewMask,
100                         const MachineRegisterInfo &MRI) {
101  if (NewMask == PrevMask)
102    return;
103
104  int Sign = 1;
105  if (NewMask < PrevMask) {
106    std::swap(NewMask, PrevMask);
107    Sign = -1;
108  }
109#ifndef NDEBUG
110  const auto MaxMask = MRI.getMaxLaneMaskForVReg(Reg);
111#endif
112  switch (auto Kind = getRegKind(Reg, MRI)) {
113  case SGPR32:
114  case VGPR32:
115  case AGPR32:
116    assert(PrevMask.none() && NewMask == MaxMask);
117    Value[Kind] += Sign;
118    break;
119
120  case SGPR_TUPLE:
121  case VGPR_TUPLE:
122  case AGPR_TUPLE:
123    assert(NewMask < MaxMask || NewMask == MaxMask);
124    assert(PrevMask < NewMask);
125
126    Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
127      Sign * (~PrevMask & NewMask).getNumLanes();
128
129    if (PrevMask.none()) {
130      assert(NewMask.any());
131      Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
132    }
133    break;
134
135  default: llvm_unreachable("Unknown register kind");
136  }
137}
138
139bool GCNRegPressure::less(const GCNSubtarget &ST,
140                          const GCNRegPressure& O,
141                          unsigned MaxOccupancy) const {
142  const auto SGPROcc = std::min(MaxOccupancy,
143                                ST.getOccupancyWithNumSGPRs(getSGPRNum()));
144  const auto VGPROcc = std::min(MaxOccupancy,
145                                ST.getOccupancyWithNumVGPRs(getVGPRNum()));
146  const auto OtherSGPROcc = std::min(MaxOccupancy,
147                                ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
148  const auto OtherVGPROcc = std::min(MaxOccupancy,
149                                ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
150
151  const auto Occ = std::min(SGPROcc, VGPROcc);
152  const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
153  if (Occ != OtherOcc)
154    return Occ > OtherOcc;
155
156  bool SGPRImportant = SGPROcc < VGPROcc;
157  const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
158
159  // if both pressures disagree on what is more important compare vgprs
160  if (SGPRImportant != OtherSGPRImportant) {
161    SGPRImportant = false;
162  }
163
164  // compare large regs pressure
165  bool SGPRFirst = SGPRImportant;
166  for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
167    if (SGPRFirst) {
168      auto SW = getSGPRTuplesWeight();
169      auto OtherSW = O.getSGPRTuplesWeight();
170      if (SW != OtherSW)
171        return SW < OtherSW;
172    } else {
173      auto VW = getVGPRTuplesWeight();
174      auto OtherVW = O.getVGPRTuplesWeight();
175      if (VW != OtherVW)
176        return VW < OtherVW;
177    }
178  }
179  return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
180                         (getVGPRNum() < O.getVGPRNum());
181}
182
183#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
184LLVM_DUMP_METHOD
185void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
186  OS << "VGPRs: " << Value[VGPR32] << ' ';
187  OS << "AGPRs: " << Value[AGPR32];
188  if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
189  OS << ", SGPRs: " << getSGPRNum();
190  if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
191  OS << ", LVGPR WT: " << getVGPRTuplesWeight()
192     << ", LSGPR WT: " << getSGPRTuplesWeight();
193  if (ST) OS << " -> Occ: " << getOccupancy(*ST);
194  OS << '\n';
195}
196#endif
197
198static LaneBitmask getDefRegMask(const MachineOperand &MO,
199                                 const MachineRegisterInfo &MRI) {
200  assert(MO.isDef() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
201
202  // We don't rely on read-undef flag because in case of tentative schedule
203  // tracking it isn't set correctly yet. This works correctly however since
204  // use mask has been tracked before using LIS.
205  return MO.getSubReg() == 0 ?
206    MRI.getMaxLaneMaskForVReg(MO.getReg()) :
207    MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
208}
209
210static LaneBitmask getUsedRegMask(const MachineOperand &MO,
211                                  const MachineRegisterInfo &MRI,
212                                  const LiveIntervals &LIS) {
213  assert(MO.isUse() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
214
215  if (auto SubReg = MO.getSubReg())
216    return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
217
218  auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
219  if (MaxMask == LaneBitmask::getLane(0)) // cannot have subregs
220    return MaxMask;
221
222  // For a tentative schedule LIS isn't updated yet but livemask should remain
223  // the same on any schedule. Subreg defs can be reordered but they all must
224  // dominate uses anyway.
225  auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
226  return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
227}
228
229static SmallVector<RegisterMaskPair, 8>
230collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
231                      const MachineRegisterInfo &MRI) {
232  SmallVector<RegisterMaskPair, 8> Res;
233  for (const auto &MO : MI.operands()) {
234    if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
235      continue;
236    if (!MO.isUse() || !MO.readsReg())
237      continue;
238
239    auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
240
241    auto Reg = MO.getReg();
242    auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
243      return RM.RegUnit == Reg;
244    });
245    if (I != Res.end())
246      I->LaneMask |= UsedMask;
247    else
248      Res.push_back(RegisterMaskPair(Reg, UsedMask));
249  }
250  return Res;
251}
252
253///////////////////////////////////////////////////////////////////////////////
254// GCNRPTracker
255
256LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
257                                  SlotIndex SI,
258                                  const LiveIntervals &LIS,
259                                  const MachineRegisterInfo &MRI) {
260  LaneBitmask LiveMask;
261  const auto &LI = LIS.getInterval(Reg);
262  if (LI.hasSubRanges()) {
263    for (const auto &S : LI.subranges())
264      if (S.liveAt(SI)) {
265        LiveMask |= S.LaneMask;
266        assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
267               LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
268      }
269  } else if (LI.liveAt(SI)) {
270    LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
271  }
272  return LiveMask;
273}
274
275GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
276                                           const LiveIntervals &LIS,
277                                           const MachineRegisterInfo &MRI) {
278  GCNRPTracker::LiveRegSet LiveRegs;
279  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
280    auto Reg = Register::index2VirtReg(I);
281    if (!LIS.hasInterval(Reg))
282      continue;
283    auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
284    if (LiveMask.any())
285      LiveRegs[Reg] = LiveMask;
286  }
287  return LiveRegs;
288}
289
290void GCNRPTracker::reset(const MachineInstr &MI,
291                         const LiveRegSet *LiveRegsCopy,
292                         bool After) {
293  const MachineFunction &MF = *MI.getMF();
294  MRI = &MF.getRegInfo();
295  if (LiveRegsCopy) {
296    if (&LiveRegs != LiveRegsCopy)
297      LiveRegs = *LiveRegsCopy;
298  } else {
299    LiveRegs = After ? getLiveRegsAfter(MI, LIS)
300                     : getLiveRegsBefore(MI, LIS);
301  }
302
303  MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
304}
305
306void GCNUpwardRPTracker::reset(const MachineInstr &MI,
307                               const LiveRegSet *LiveRegsCopy) {
308  GCNRPTracker::reset(MI, LiveRegsCopy, true);
309}
310
311void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
312  assert(MRI && "call reset first");
313
314  LastTrackedMI = &MI;
315
316  if (MI.isDebugInstr())
317    return;
318
319  auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
320
321  // calc pressure at the MI (defs + uses)
322  auto AtMIPressure = CurPressure;
323  for (const auto &U : RegUses) {
324    auto LiveMask = LiveRegs[U.RegUnit];
325    AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
326  }
327  // update max pressure
328  MaxPressure = max(AtMIPressure, MaxPressure);
329
330  for (const auto &MO : MI.defs()) {
331    if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()) || MO.isDead())
332      continue;
333
334    auto Reg = MO.getReg();
335    auto I = LiveRegs.find(Reg);
336    if (I == LiveRegs.end())
337      continue;
338    auto &LiveMask = I->second;
339    auto PrevMask = LiveMask;
340    LiveMask &= ~getDefRegMask(MO, *MRI);
341    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
342    if (LiveMask.none())
343      LiveRegs.erase(I);
344  }
345  for (const auto &U : RegUses) {
346    auto &LiveMask = LiveRegs[U.RegUnit];
347    auto PrevMask = LiveMask;
348    LiveMask |= U.LaneMask;
349    CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
350  }
351  assert(CurPressure == getRegPressure(*MRI, LiveRegs));
352}
353
354bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
355                                 const LiveRegSet *LiveRegsCopy) {
356  MRI = &MI.getParent()->getParent()->getRegInfo();
357  LastTrackedMI = nullptr;
358  MBBEnd = MI.getParent()->end();
359  NextMI = &MI;
360  NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
361  if (NextMI == MBBEnd)
362    return false;
363  GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
364  return true;
365}
366
367bool GCNDownwardRPTracker::advanceBeforeNext() {
368  assert(MRI && "call reset first");
369
370  NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
371  if (NextMI == MBBEnd)
372    return false;
373
374  SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
375  assert(SI.isValid());
376
377  // Remove dead registers or mask bits.
378  for (auto &It : LiveRegs) {
379    const LiveInterval &LI = LIS.getInterval(It.first);
380    if (LI.hasSubRanges()) {
381      for (const auto &S : LI.subranges()) {
382        if (!S.liveAt(SI)) {
383          auto PrevMask = It.second;
384          It.second &= ~S.LaneMask;
385          CurPressure.inc(It.first, PrevMask, It.second, *MRI);
386        }
387      }
388    } else if (!LI.liveAt(SI)) {
389      auto PrevMask = It.second;
390      It.second = LaneBitmask::getNone();
391      CurPressure.inc(It.first, PrevMask, It.second, *MRI);
392    }
393    if (It.second.none())
394      LiveRegs.erase(It.first);
395  }
396
397  MaxPressure = max(MaxPressure, CurPressure);
398
399  return true;
400}
401
402void GCNDownwardRPTracker::advanceToNext() {
403  LastTrackedMI = &*NextMI++;
404
405  // Add new registers or mask bits.
406  for (const auto &MO : LastTrackedMI->defs()) {
407    if (!MO.isReg())
408      continue;
409    Register Reg = MO.getReg();
410    if (!Register::isVirtualRegister(Reg))
411      continue;
412    auto &LiveMask = LiveRegs[Reg];
413    auto PrevMask = LiveMask;
414    LiveMask |= getDefRegMask(MO, *MRI);
415    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
416  }
417
418  MaxPressure = max(MaxPressure, CurPressure);
419}
420
421bool GCNDownwardRPTracker::advance() {
422  // If we have just called reset live set is actual.
423  if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
424    return false;
425  advanceToNext();
426  return true;
427}
428
429bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
430  while (NextMI != End)
431    if (!advance()) return false;
432  return true;
433}
434
435bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
436                                   MachineBasicBlock::const_iterator End,
437                                   const LiveRegSet *LiveRegsCopy) {
438  reset(*Begin, LiveRegsCopy);
439  return advance(End);
440}
441
442#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
443LLVM_DUMP_METHOD
444static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
445                           const GCNRPTracker::LiveRegSet &TrackedLR,
446                           const TargetRegisterInfo *TRI) {
447  for (auto const &P : TrackedLR) {
448    auto I = LISLR.find(P.first);
449    if (I == LISLR.end()) {
450      dbgs() << "  " << printReg(P.first, TRI)
451             << ":L" << PrintLaneMask(P.second)
452             << " isn't found in LIS reported set\n";
453    }
454    else if (I->second != P.second) {
455      dbgs() << "  " << printReg(P.first, TRI)
456        << " masks doesn't match: LIS reported "
457        << PrintLaneMask(I->second)
458        << ", tracked "
459        << PrintLaneMask(P.second)
460        << '\n';
461    }
462  }
463  for (auto const &P : LISLR) {
464    auto I = TrackedLR.find(P.first);
465    if (I == TrackedLR.end()) {
466      dbgs() << "  " << printReg(P.first, TRI)
467             << ":L" << PrintLaneMask(P.second)
468             << " isn't found in tracked set\n";
469    }
470  }
471}
472
473bool GCNUpwardRPTracker::isValid() const {
474  const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
475  const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
476  const auto &TrackedLR = LiveRegs;
477
478  if (!isEqual(LISLR, TrackedLR)) {
479    dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
480              " LIS reported livesets mismatch:\n";
481    printLivesAt(SI, LIS, *MRI);
482    reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
483    return false;
484  }
485
486  auto LISPressure = getRegPressure(*MRI, LISLR);
487  if (LISPressure != CurPressure) {
488    dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
489    CurPressure.print(dbgs());
490    dbgs() << "LIS rpt: ";
491    LISPressure.print(dbgs());
492    return false;
493  }
494  return true;
495}
496
497void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
498                                 const MachineRegisterInfo &MRI) {
499  const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
500  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
501    unsigned Reg = Register::index2VirtReg(I);
502    auto It = LiveRegs.find(Reg);
503    if (It != LiveRegs.end() && It->second.any())
504      OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
505         << PrintLaneMask(It->second);
506  }
507  OS << '\n';
508}
509#endif
510