1//===- GCNRegPressure.cpp -------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8///
9/// \file
10/// This file implements the GCNRegPressure class.
11///
12//===----------------------------------------------------------------------===//
13
14#include "GCNRegPressure.h"
15#include "AMDGPUSubtarget.h"
16#include "SIRegisterInfo.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/CodeGen/LiveInterval.h"
19#include "llvm/CodeGen/LiveIntervals.h"
20#include "llvm/CodeGen/MachineInstr.h"
21#include "llvm/CodeGen/MachineOperand.h"
22#include "llvm/CodeGen/MachineRegisterInfo.h"
23#include "llvm/CodeGen/RegisterPressure.h"
24#include "llvm/CodeGen/SlotIndexes.h"
25#include "llvm/CodeGen/TargetRegisterInfo.h"
26#include "llvm/Config/llvm-config.h"
27#include "llvm/MC/LaneBitmask.h"
28#include "llvm/Support/Compiler.h"
29#include "llvm/Support/Debug.h"
30#include "llvm/Support/ErrorHandling.h"
31#include "llvm/Support/raw_ostream.h"
32#include <algorithm>
33#include <cassert>
34
35using namespace llvm;
36
37#define DEBUG_TYPE "machine-scheduler"
38
39#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
40LLVM_DUMP_METHOD
41void llvm::printLivesAt(SlotIndex SI,
42                        const LiveIntervals &LIS,
43                        const MachineRegisterInfo &MRI) {
44  dbgs() << "Live regs at " << SI << ": "
45         << *LIS.getInstructionFromIndex(SI);
46  unsigned Num = 0;
47  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
48    const unsigned Reg = Register::index2VirtReg(I);
49    if (!LIS.hasInterval(Reg))
50      continue;
51    const auto &LI = LIS.getInterval(Reg);
52    if (LI.hasSubRanges()) {
53      bool firstTime = true;
54      for (const auto &S : LI.subranges()) {
55        if (!S.liveAt(SI)) continue;
56        if (firstTime) {
57          dbgs() << "  " << printReg(Reg, MRI.getTargetRegisterInfo())
58                 << '\n';
59          firstTime = false;
60        }
61        dbgs() << "  " << S << '\n';
62        ++Num;
63      }
64    } else if (LI.liveAt(SI)) {
65      dbgs() << "  " << LI << '\n';
66      ++Num;
67    }
68  }
69  if (!Num) dbgs() << "  <none>\n";
70}
71#endif
72
73bool llvm::isEqual(const GCNRPTracker::LiveRegSet &S1,
74                   const GCNRPTracker::LiveRegSet &S2) {
75  if (S1.size() != S2.size())
76    return false;
77
78  for (const auto &P : S1) {
79    auto I = S2.find(P.first);
80    if (I == S2.end() || I->second != P.second)
81      return false;
82  }
83  return true;
84}
85
86
87///////////////////////////////////////////////////////////////////////////////
88// GCNRegPressure
89
90unsigned GCNRegPressure::getRegKind(unsigned Reg,
91                                    const MachineRegisterInfo &MRI) {
92  assert(Register::isVirtualRegister(Reg));
93  const auto RC = MRI.getRegClass(Reg);
94  auto STI = static_cast<const SIRegisterInfo*>(MRI.getTargetRegisterInfo());
95  return STI->isSGPRClass(RC) ?
96    (STI->getRegSizeInBits(*RC) == 32 ? SGPR32 : SGPR_TUPLE) :
97    STI->hasAGPRs(RC) ?
98      (STI->getRegSizeInBits(*RC) == 32 ? AGPR32 : AGPR_TUPLE) :
99      (STI->getRegSizeInBits(*RC) == 32 ? VGPR32 : VGPR_TUPLE);
100}
101
102void GCNRegPressure::inc(unsigned Reg,
103                         LaneBitmask PrevMask,
104                         LaneBitmask NewMask,
105                         const MachineRegisterInfo &MRI) {
106  if (SIRegisterInfo::getNumCoveredRegs(NewMask) ==
107      SIRegisterInfo::getNumCoveredRegs(PrevMask))
108    return;
109
110  int Sign = 1;
111  if (NewMask < PrevMask) {
112    std::swap(NewMask, PrevMask);
113    Sign = -1;
114  }
115
116  switch (auto Kind = getRegKind(Reg, MRI)) {
117  case SGPR32:
118  case VGPR32:
119  case AGPR32:
120    Value[Kind] += Sign;
121    break;
122
123  case SGPR_TUPLE:
124  case VGPR_TUPLE:
125  case AGPR_TUPLE:
126    assert(PrevMask < NewMask);
127
128    Value[Kind == SGPR_TUPLE ? SGPR32 : Kind == AGPR_TUPLE ? AGPR32 : VGPR32] +=
129      Sign * SIRegisterInfo::getNumCoveredRegs(~PrevMask & NewMask);
130
131    if (PrevMask.none()) {
132      assert(NewMask.any());
133      Value[Kind] += Sign * MRI.getPressureSets(Reg).getWeight();
134    }
135    break;
136
137  default: llvm_unreachable("Unknown register kind");
138  }
139}
140
141bool GCNRegPressure::less(const GCNSubtarget &ST,
142                          const GCNRegPressure& O,
143                          unsigned MaxOccupancy) const {
144  const auto SGPROcc = std::min(MaxOccupancy,
145                                ST.getOccupancyWithNumSGPRs(getSGPRNum()));
146  const auto VGPROcc = std::min(MaxOccupancy,
147                                ST.getOccupancyWithNumVGPRs(getVGPRNum()));
148  const auto OtherSGPROcc = std::min(MaxOccupancy,
149                                ST.getOccupancyWithNumSGPRs(O.getSGPRNum()));
150  const auto OtherVGPROcc = std::min(MaxOccupancy,
151                                ST.getOccupancyWithNumVGPRs(O.getVGPRNum()));
152
153  const auto Occ = std::min(SGPROcc, VGPROcc);
154  const auto OtherOcc = std::min(OtherSGPROcc, OtherVGPROcc);
155  if (Occ != OtherOcc)
156    return Occ > OtherOcc;
157
158  bool SGPRImportant = SGPROcc < VGPROcc;
159  const bool OtherSGPRImportant = OtherSGPROcc < OtherVGPROcc;
160
161  // if both pressures disagree on what is more important compare vgprs
162  if (SGPRImportant != OtherSGPRImportant) {
163    SGPRImportant = false;
164  }
165
166  // compare large regs pressure
167  bool SGPRFirst = SGPRImportant;
168  for (int I = 2; I > 0; --I, SGPRFirst = !SGPRFirst) {
169    if (SGPRFirst) {
170      auto SW = getSGPRTuplesWeight();
171      auto OtherSW = O.getSGPRTuplesWeight();
172      if (SW != OtherSW)
173        return SW < OtherSW;
174    } else {
175      auto VW = getVGPRTuplesWeight();
176      auto OtherVW = O.getVGPRTuplesWeight();
177      if (VW != OtherVW)
178        return VW < OtherVW;
179    }
180  }
181  return SGPRImportant ? (getSGPRNum() < O.getSGPRNum()):
182                         (getVGPRNum() < O.getVGPRNum());
183}
184
185#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
186LLVM_DUMP_METHOD
187void GCNRegPressure::print(raw_ostream &OS, const GCNSubtarget *ST) const {
188  OS << "VGPRs: " << Value[VGPR32] << ' ';
189  OS << "AGPRs: " << Value[AGPR32];
190  if (ST) OS << "(O" << ST->getOccupancyWithNumVGPRs(getVGPRNum()) << ')';
191  OS << ", SGPRs: " << getSGPRNum();
192  if (ST) OS << "(O" << ST->getOccupancyWithNumSGPRs(getSGPRNum()) << ')';
193  OS << ", LVGPR WT: " << getVGPRTuplesWeight()
194     << ", LSGPR WT: " << getSGPRTuplesWeight();
195  if (ST) OS << " -> Occ: " << getOccupancy(*ST);
196  OS << '\n';
197}
198#endif
199
200static LaneBitmask getDefRegMask(const MachineOperand &MO,
201                                 const MachineRegisterInfo &MRI) {
202  assert(MO.isDef() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
203
204  // We don't rely on read-undef flag because in case of tentative schedule
205  // tracking it isn't set correctly yet. This works correctly however since
206  // use mask has been tracked before using LIS.
207  return MO.getSubReg() == 0 ?
208    MRI.getMaxLaneMaskForVReg(MO.getReg()) :
209    MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
210}
211
212static LaneBitmask getUsedRegMask(const MachineOperand &MO,
213                                  const MachineRegisterInfo &MRI,
214                                  const LiveIntervals &LIS) {
215  assert(MO.isUse() && MO.isReg() && Register::isVirtualRegister(MO.getReg()));
216
217  if (auto SubReg = MO.getSubReg())
218    return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
219
220  auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
221  if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
222    return MaxMask;
223
224  // For a tentative schedule LIS isn't updated yet but livemask should remain
225  // the same on any schedule. Subreg defs can be reordered but they all must
226  // dominate uses anyway.
227  auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
228  return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
229}
230
231static SmallVector<RegisterMaskPair, 8>
232collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
233                      const MachineRegisterInfo &MRI) {
234  SmallVector<RegisterMaskPair, 8> Res;
235  for (const auto &MO : MI.operands()) {
236    if (!MO.isReg() || !Register::isVirtualRegister(MO.getReg()))
237      continue;
238    if (!MO.isUse() || !MO.readsReg())
239      continue;
240
241    auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
242
243    auto Reg = MO.getReg();
244    auto I = std::find_if(Res.begin(), Res.end(), [Reg](const RegisterMaskPair &RM) {
245      return RM.RegUnit == Reg;
246    });
247    if (I != Res.end())
248      I->LaneMask |= UsedMask;
249    else
250      Res.push_back(RegisterMaskPair(Reg, UsedMask));
251  }
252  return Res;
253}
254
255///////////////////////////////////////////////////////////////////////////////
256// GCNRPTracker
257
258LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
259                                  SlotIndex SI,
260                                  const LiveIntervals &LIS,
261                                  const MachineRegisterInfo &MRI) {
262  LaneBitmask LiveMask;
263  const auto &LI = LIS.getInterval(Reg);
264  if (LI.hasSubRanges()) {
265    for (const auto &S : LI.subranges())
266      if (S.liveAt(SI)) {
267        LiveMask |= S.LaneMask;
268        assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
269               LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
270      }
271  } else if (LI.liveAt(SI)) {
272    LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
273  }
274  return LiveMask;
275}
276
277GCNRPTracker::LiveRegSet llvm::getLiveRegs(SlotIndex SI,
278                                           const LiveIntervals &LIS,
279                                           const MachineRegisterInfo &MRI) {
280  GCNRPTracker::LiveRegSet LiveRegs;
281  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
282    auto Reg = Register::index2VirtReg(I);
283    if (!LIS.hasInterval(Reg))
284      continue;
285    auto LiveMask = getLiveLaneMask(Reg, SI, LIS, MRI);
286    if (LiveMask.any())
287      LiveRegs[Reg] = LiveMask;
288  }
289  return LiveRegs;
290}
291
292void GCNRPTracker::reset(const MachineInstr &MI,
293                         const LiveRegSet *LiveRegsCopy,
294                         bool After) {
295  const MachineFunction &MF = *MI.getMF();
296  MRI = &MF.getRegInfo();
297  if (LiveRegsCopy) {
298    if (&LiveRegs != LiveRegsCopy)
299      LiveRegs = *LiveRegsCopy;
300  } else {
301    LiveRegs = After ? getLiveRegsAfter(MI, LIS)
302                     : getLiveRegsBefore(MI, LIS);
303  }
304
305  MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
306}
307
308void GCNUpwardRPTracker::reset(const MachineInstr &MI,
309                               const LiveRegSet *LiveRegsCopy) {
310  GCNRPTracker::reset(MI, LiveRegsCopy, true);
311}
312
313void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
314  assert(MRI && "call reset first");
315
316  LastTrackedMI = &MI;
317
318  if (MI.isDebugInstr())
319    return;
320
321  auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
322
323  // calc pressure at the MI (defs + uses)
324  auto AtMIPressure = CurPressure;
325  for (const auto &U : RegUses) {
326    auto LiveMask = LiveRegs[U.RegUnit];
327    AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
328  }
329  // update max pressure
330  MaxPressure = max(AtMIPressure, MaxPressure);
331
332  for (const auto &MO : MI.operands()) {
333    if (!MO.isReg() || !MO.isDef() ||
334        !Register::isVirtualRegister(MO.getReg()) || MO.isDead())
335      continue;
336
337    auto Reg = MO.getReg();
338    auto I = LiveRegs.find(Reg);
339    if (I == LiveRegs.end())
340      continue;
341    auto &LiveMask = I->second;
342    auto PrevMask = LiveMask;
343    LiveMask &= ~getDefRegMask(MO, *MRI);
344    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
345    if (LiveMask.none())
346      LiveRegs.erase(I);
347  }
348  for (const auto &U : RegUses) {
349    auto &LiveMask = LiveRegs[U.RegUnit];
350    auto PrevMask = LiveMask;
351    LiveMask |= U.LaneMask;
352    CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
353  }
354  assert(CurPressure == getRegPressure(*MRI, LiveRegs));
355}
356
357bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
358                                 const LiveRegSet *LiveRegsCopy) {
359  MRI = &MI.getParent()->getParent()->getRegInfo();
360  LastTrackedMI = nullptr;
361  MBBEnd = MI.getParent()->end();
362  NextMI = &MI;
363  NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
364  if (NextMI == MBBEnd)
365    return false;
366  GCNRPTracker::reset(*NextMI, LiveRegsCopy, false);
367  return true;
368}
369
370bool GCNDownwardRPTracker::advanceBeforeNext() {
371  assert(MRI && "call reset first");
372
373  NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
374  if (NextMI == MBBEnd)
375    return false;
376
377  SlotIndex SI = LIS.getInstructionIndex(*NextMI).getBaseIndex();
378  assert(SI.isValid());
379
380  // Remove dead registers or mask bits.
381  for (auto &It : LiveRegs) {
382    const LiveInterval &LI = LIS.getInterval(It.first);
383    if (LI.hasSubRanges()) {
384      for (const auto &S : LI.subranges()) {
385        if (!S.liveAt(SI)) {
386          auto PrevMask = It.second;
387          It.second &= ~S.LaneMask;
388          CurPressure.inc(It.first, PrevMask, It.second, *MRI);
389        }
390      }
391    } else if (!LI.liveAt(SI)) {
392      auto PrevMask = It.second;
393      It.second = LaneBitmask::getNone();
394      CurPressure.inc(It.first, PrevMask, It.second, *MRI);
395    }
396    if (It.second.none())
397      LiveRegs.erase(It.first);
398  }
399
400  MaxPressure = max(MaxPressure, CurPressure);
401
402  return true;
403}
404
405void GCNDownwardRPTracker::advanceToNext() {
406  LastTrackedMI = &*NextMI++;
407
408  // Add new registers or mask bits.
409  for (const auto &MO : LastTrackedMI->operands()) {
410    if (!MO.isReg() || !MO.isDef())
411      continue;
412    Register Reg = MO.getReg();
413    if (!Register::isVirtualRegister(Reg))
414      continue;
415    auto &LiveMask = LiveRegs[Reg];
416    auto PrevMask = LiveMask;
417    LiveMask |= getDefRegMask(MO, *MRI);
418    CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
419  }
420
421  MaxPressure = max(MaxPressure, CurPressure);
422}
423
424bool GCNDownwardRPTracker::advance() {
425  // If we have just called reset live set is actual.
426  if ((NextMI == MBBEnd) || (LastTrackedMI && !advanceBeforeNext()))
427    return false;
428  advanceToNext();
429  return true;
430}
431
432bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator End) {
433  while (NextMI != End)
434    if (!advance()) return false;
435  return true;
436}
437
438bool GCNDownwardRPTracker::advance(MachineBasicBlock::const_iterator Begin,
439                                   MachineBasicBlock::const_iterator End,
440                                   const LiveRegSet *LiveRegsCopy) {
441  reset(*Begin, LiveRegsCopy);
442  return advance(End);
443}
444
445#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
446LLVM_DUMP_METHOD
447static void reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
448                           const GCNRPTracker::LiveRegSet &TrackedLR,
449                           const TargetRegisterInfo *TRI) {
450  for (auto const &P : TrackedLR) {
451    auto I = LISLR.find(P.first);
452    if (I == LISLR.end()) {
453      dbgs() << "  " << printReg(P.first, TRI)
454             << ":L" << PrintLaneMask(P.second)
455             << " isn't found in LIS reported set\n";
456    }
457    else if (I->second != P.second) {
458      dbgs() << "  " << printReg(P.first, TRI)
459        << " masks doesn't match: LIS reported "
460        << PrintLaneMask(I->second)
461        << ", tracked "
462        << PrintLaneMask(P.second)
463        << '\n';
464    }
465  }
466  for (auto const &P : LISLR) {
467    auto I = TrackedLR.find(P.first);
468    if (I == TrackedLR.end()) {
469      dbgs() << "  " << printReg(P.first, TRI)
470             << ":L" << PrintLaneMask(P.second)
471             << " isn't found in tracked set\n";
472    }
473  }
474}
475
476bool GCNUpwardRPTracker::isValid() const {
477  const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
478  const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);
479  const auto &TrackedLR = LiveRegs;
480
481  if (!isEqual(LISLR, TrackedLR)) {
482    dbgs() << "\nGCNUpwardRPTracker error: Tracked and"
483              " LIS reported livesets mismatch:\n";
484    printLivesAt(SI, LIS, *MRI);
485    reportMismatch(LISLR, TrackedLR, MRI->getTargetRegisterInfo());
486    return false;
487  }
488
489  auto LISPressure = getRegPressure(*MRI, LISLR);
490  if (LISPressure != CurPressure) {
491    dbgs() << "GCNUpwardRPTracker error: Pressure sets different\nTracked: ";
492    CurPressure.print(dbgs());
493    dbgs() << "LIS rpt: ";
494    LISPressure.print(dbgs());
495    return false;
496  }
497  return true;
498}
499
500void GCNRPTracker::printLiveRegs(raw_ostream &OS, const LiveRegSet& LiveRegs,
501                                 const MachineRegisterInfo &MRI) {
502  const TargetRegisterInfo *TRI = MRI.getTargetRegisterInfo();
503  for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) {
504    unsigned Reg = Register::index2VirtReg(I);
505    auto It = LiveRegs.find(Reg);
506    if (It != LiveRegs.end() && It->second.any())
507      OS << ' ' << printVRegOrUnit(Reg, TRI) << ':'
508         << PrintLaneMask(It->second);
509  }
510  OS << '\n';
511}
512#endif
513