1//===- RDFRegisters.cpp ---------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "llvm/ADT/BitVector.h"
10#include "llvm/CodeGen/MachineFunction.h"
11#include "llvm/CodeGen/MachineInstr.h"
12#include "llvm/CodeGen/MachineOperand.h"
13#include "llvm/CodeGen/RDFRegisters.h"
14#include "llvm/CodeGen/TargetRegisterInfo.h"
15#include "llvm/MC/LaneBitmask.h"
16#include "llvm/MC/MCRegisterInfo.h"
17#include "llvm/Support/ErrorHandling.h"
18#include "llvm/Support/raw_ostream.h"
19#include <cassert>
20#include <cstdint>
21#include <set>
22#include <utility>
23
24using namespace llvm;
25using namespace rdf;
26
27PhysicalRegisterInfo::PhysicalRegisterInfo(const TargetRegisterInfo &tri,
28      const MachineFunction &mf)
29    : TRI(tri) {
30  RegInfos.resize(TRI.getNumRegs());
31
32  BitVector BadRC(TRI.getNumRegs());
33  for (const TargetRegisterClass *RC : TRI.regclasses()) {
34    for (MCPhysReg R : *RC) {
35      RegInfo &RI = RegInfos[R];
36      if (RI.RegClass != nullptr && !BadRC[R]) {
37        if (RC->LaneMask != RI.RegClass->LaneMask) {
38          BadRC.set(R);
39          RI.RegClass = nullptr;
40        }
41      } else
42        RI.RegClass = RC;
43    }
44  }
45
46  UnitInfos.resize(TRI.getNumRegUnits());
47
48  for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
49    if (UnitInfos[U].Reg != 0)
50      continue;
51    MCRegUnitRootIterator R(U, &TRI);
52    assert(R.isValid());
53    RegisterId F = *R;
54    ++R;
55    if (R.isValid()) {
56      UnitInfos[U].Mask = LaneBitmask::getAll();
57      UnitInfos[U].Reg = F;
58    } else {
59      for (MCRegUnitMaskIterator I(F, &TRI); I.isValid(); ++I) {
60        std::pair<uint32_t,LaneBitmask> P = *I;
61        UnitInfo &UI = UnitInfos[P.first];
62        UI.Reg = F;
63        if (P.second.any()) {
64          UI.Mask = P.second;
65        } else {
66          if (const TargetRegisterClass *RC = RegInfos[F].RegClass)
67            UI.Mask = RC->LaneMask;
68          else
69            UI.Mask = LaneBitmask::getAll();
70        }
71      }
72    }
73  }
74
75  for (const uint32_t *RM : TRI.getRegMasks())
76    RegMasks.insert(RM);
77  for (const MachineBasicBlock &B : mf)
78    for (const MachineInstr &In : B)
79      for (const MachineOperand &Op : In.operands())
80        if (Op.isRegMask())
81          RegMasks.insert(Op.getRegMask());
82
83  MaskInfos.resize(RegMasks.size()+1);
84  for (uint32_t M = 1, NM = RegMasks.size(); M <= NM; ++M) {
85    BitVector PU(TRI.getNumRegUnits());
86    const uint32_t *MB = RegMasks.get(M);
87    for (unsigned I = 1, E = TRI.getNumRegs(); I != E; ++I) {
88      if (!(MB[I / 32] & (1u << (I % 32))))
89        continue;
90      for (MCRegUnitIterator U(MCRegister::from(I), &TRI); U.isValid(); ++U)
91        PU.set(*U);
92    }
93    MaskInfos[M].Units = PU.flip();
94  }
95
96  AliasInfos.resize(TRI.getNumRegUnits());
97  for (uint32_t U = 0, NU = TRI.getNumRegUnits(); U != NU; ++U) {
98    BitVector AS(TRI.getNumRegs());
99    for (MCRegUnitRootIterator R(U, &TRI); R.isValid(); ++R)
100      for (MCSuperRegIterator S(*R, &TRI, true); S.isValid(); ++S)
101        AS.set(*S);
102    AliasInfos[U].Regs = AS;
103  }
104}
105
106std::set<RegisterId> PhysicalRegisterInfo::getAliasSet(RegisterId Reg) const {
107  // Do not include RR in the alias set.
108  std::set<RegisterId> AS;
109  assert(isRegMaskId(Reg) || Register::isPhysicalRegister(Reg));
110  if (isRegMaskId(Reg)) {
111    // XXX SLOW
112    const uint32_t *MB = getRegMaskBits(Reg);
113    for (unsigned i = 1, e = TRI.getNumRegs(); i != e; ++i) {
114      if (MB[i/32] & (1u << (i%32)))
115        continue;
116      AS.insert(i);
117    }
118    for (const uint32_t *RM : RegMasks) {
119      RegisterId MI = getRegMaskId(RM);
120      if (MI != Reg && aliasMM(RegisterRef(Reg), RegisterRef(MI)))
121        AS.insert(MI);
122    }
123    return AS;
124  }
125
126  for (MCRegAliasIterator AI(Reg, &TRI, false); AI.isValid(); ++AI)
127    AS.insert(*AI);
128  for (const uint32_t *RM : RegMasks) {
129    RegisterId MI = getRegMaskId(RM);
130    if (aliasRM(RegisterRef(Reg), RegisterRef(MI)))
131      AS.insert(MI);
132  }
133  return AS;
134}
135
136bool PhysicalRegisterInfo::aliasRR(RegisterRef RA, RegisterRef RB) const {
137  assert(Register::isPhysicalRegister(RA.Reg));
138  assert(Register::isPhysicalRegister(RB.Reg));
139
140  MCRegUnitMaskIterator UMA(RA.Reg, &TRI);
141  MCRegUnitMaskIterator UMB(RB.Reg, &TRI);
142  // Reg units are returned in the numerical order.
143  while (UMA.isValid() && UMB.isValid()) {
144    // Skip units that are masked off in RA.
145    std::pair<RegisterId,LaneBitmask> PA = *UMA;
146    if (PA.second.any() && (PA.second & RA.Mask).none()) {
147      ++UMA;
148      continue;
149    }
150    // Skip units that are masked off in RB.
151    std::pair<RegisterId,LaneBitmask> PB = *UMB;
152    if (PB.second.any() && (PB.second & RB.Mask).none()) {
153      ++UMB;
154      continue;
155    }
156
157    if (PA.first == PB.first)
158      return true;
159    if (PA.first < PB.first)
160      ++UMA;
161    else if (PB.first < PA.first)
162      ++UMB;
163  }
164  return false;
165}
166
167bool PhysicalRegisterInfo::aliasRM(RegisterRef RR, RegisterRef RM) const {
168  assert(Register::isPhysicalRegister(RR.Reg) && isRegMaskId(RM.Reg));
169  const uint32_t *MB = getRegMaskBits(RM.Reg);
170  bool Preserved = MB[RR.Reg/32] & (1u << (RR.Reg%32));
171  // If the lane mask information is "full", e.g. when the given lane mask
172  // is a superset of the lane mask from the register class, check the regmask
173  // bit directly.
174  if (RR.Mask == LaneBitmask::getAll())
175    return !Preserved;
176  const TargetRegisterClass *RC = RegInfos[RR.Reg].RegClass;
177  if (RC != nullptr && (RR.Mask & RC->LaneMask) == RC->LaneMask)
178    return !Preserved;
179
180  // Otherwise, check all subregisters whose lane mask overlaps the given
181  // mask. For each such register, if it is preserved by the regmask, then
182  // clear the corresponding bits in the given mask. If at the end, all
183  // bits have been cleared, the register does not alias the regmask (i.e.
184  // is it preserved by it).
185  LaneBitmask M = RR.Mask;
186  for (MCSubRegIndexIterator SI(RR.Reg, &TRI); SI.isValid(); ++SI) {
187    LaneBitmask SM = TRI.getSubRegIndexLaneMask(SI.getSubRegIndex());
188    if ((SM & RR.Mask).none())
189      continue;
190    unsigned SR = SI.getSubReg();
191    if (!(MB[SR/32] & (1u << (SR%32))))
192      continue;
193    // The subregister SR is preserved.
194    M &= ~SM;
195    if (M.none())
196      return false;
197  }
198
199  return true;
200}
201
202bool PhysicalRegisterInfo::aliasMM(RegisterRef RM, RegisterRef RN) const {
203  assert(isRegMaskId(RM.Reg) && isRegMaskId(RN.Reg));
204  unsigned NumRegs = TRI.getNumRegs();
205  const uint32_t *BM = getRegMaskBits(RM.Reg);
206  const uint32_t *BN = getRegMaskBits(RN.Reg);
207
208  for (unsigned w = 0, nw = NumRegs/32; w != nw; ++w) {
209    // Intersect the negations of both words. Disregard reg=0,
210    // i.e. 0th bit in the 0th word.
211    uint32_t C = ~BM[w] & ~BN[w];
212    if (w == 0)
213      C &= ~1;
214    if (C)
215      return true;
216  }
217
218  // Check the remaining registers in the last word.
219  unsigned TailRegs = NumRegs % 32;
220  if (TailRegs == 0)
221    return false;
222  unsigned TW = NumRegs / 32;
223  uint32_t TailMask = (1u << TailRegs) - 1;
224  if (~BM[TW] & ~BN[TW] & TailMask)
225    return true;
226
227  return false;
228}
229
230RegisterRef PhysicalRegisterInfo::mapTo(RegisterRef RR, unsigned R) const {
231  if (RR.Reg == R)
232    return RR;
233  if (unsigned Idx = TRI.getSubRegIndex(R, RR.Reg))
234    return RegisterRef(R, TRI.composeSubRegIndexLaneMask(Idx, RR.Mask));
235  if (unsigned Idx = TRI.getSubRegIndex(RR.Reg, R)) {
236    const RegInfo &RI = RegInfos[R];
237    LaneBitmask RCM = RI.RegClass ? RI.RegClass->LaneMask
238                                  : LaneBitmask::getAll();
239    LaneBitmask M = TRI.reverseComposeSubRegIndexLaneMask(Idx, RR.Mask);
240    return RegisterRef(R, M & RCM);
241  }
242  llvm_unreachable("Invalid arguments: unrelated registers?");
243}
244
245bool RegisterAggr::hasAliasOf(RegisterRef RR) const {
246  if (PhysicalRegisterInfo::isRegMaskId(RR.Reg))
247    return Units.anyCommon(PRI.getMaskUnits(RR.Reg));
248
249  for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
250    std::pair<uint32_t,LaneBitmask> P = *U;
251    if (P.second.none() || (P.second & RR.Mask).any())
252      if (Units.test(P.first))
253        return true;
254  }
255  return false;
256}
257
258bool RegisterAggr::hasCoverOf(RegisterRef RR) const {
259  if (PhysicalRegisterInfo::isRegMaskId(RR.Reg)) {
260    BitVector T(PRI.getMaskUnits(RR.Reg));
261    return T.reset(Units).none();
262  }
263
264  for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
265    std::pair<uint32_t,LaneBitmask> P = *U;
266    if (P.second.none() || (P.second & RR.Mask).any())
267      if (!Units.test(P.first))
268        return false;
269  }
270  return true;
271}
272
273RegisterAggr &RegisterAggr::insert(RegisterRef RR) {
274  if (PhysicalRegisterInfo::isRegMaskId(RR.Reg)) {
275    Units |= PRI.getMaskUnits(RR.Reg);
276    return *this;
277  }
278
279  for (MCRegUnitMaskIterator U(RR.Reg, &PRI.getTRI()); U.isValid(); ++U) {
280    std::pair<uint32_t,LaneBitmask> P = *U;
281    if (P.second.none() || (P.second & RR.Mask).any())
282      Units.set(P.first);
283  }
284  return *this;
285}
286
287RegisterAggr &RegisterAggr::insert(const RegisterAggr &RG) {
288  Units |= RG.Units;
289  return *this;
290}
291
292RegisterAggr &RegisterAggr::intersect(RegisterRef RR) {
293  return intersect(RegisterAggr(PRI).insert(RR));
294}
295
296RegisterAggr &RegisterAggr::intersect(const RegisterAggr &RG) {
297  Units &= RG.Units;
298  return *this;
299}
300
301RegisterAggr &RegisterAggr::clear(RegisterRef RR) {
302  return clear(RegisterAggr(PRI).insert(RR));
303}
304
305RegisterAggr &RegisterAggr::clear(const RegisterAggr &RG) {
306  Units.reset(RG.Units);
307  return *this;
308}
309
310RegisterRef RegisterAggr::intersectWith(RegisterRef RR) const {
311  RegisterAggr T(PRI);
312  T.insert(RR).intersect(*this);
313  if (T.empty())
314    return RegisterRef();
315  RegisterRef NR = T.makeRegRef();
316  assert(NR);
317  return NR;
318}
319
320RegisterRef RegisterAggr::clearIn(RegisterRef RR) const {
321  return RegisterAggr(PRI).insert(RR).clear(*this).makeRegRef();
322}
323
324RegisterRef RegisterAggr::makeRegRef() const {
325  int U = Units.find_first();
326  if (U < 0)
327    return RegisterRef();
328
329  // Find the set of all registers that are aliased to all the units
330  // in this aggregate.
331
332  // Get all the registers aliased to the first unit in the bit vector.
333  BitVector Regs = PRI.getUnitAliases(U);
334  U = Units.find_next(U);
335
336  // For each other unit, intersect it with the set of all registers
337  // aliased that unit.
338  while (U >= 0) {
339    Regs &= PRI.getUnitAliases(U);
340    U = Units.find_next(U);
341  }
342
343  // If there is at least one register remaining, pick the first one,
344  // and consolidate the masks of all of its units contained in this
345  // aggregate.
346
347  int F = Regs.find_first();
348  if (F <= 0)
349    return RegisterRef();
350
351  LaneBitmask M;
352  for (MCRegUnitMaskIterator I(F, &PRI.getTRI()); I.isValid(); ++I) {
353    std::pair<uint32_t,LaneBitmask> P = *I;
354    if (Units.test(P.first))
355      M |= P.second.none() ? LaneBitmask::getAll() : P.second;
356  }
357  return RegisterRef(F, M);
358}
359
360void RegisterAggr::print(raw_ostream &OS) const {
361  OS << '{';
362  for (int U = Units.find_first(); U >= 0; U = Units.find_next(U))
363    OS << ' ' << printRegUnit(U, &PRI.getTRI());
364  OS << " }";
365}
366
367RegisterAggr::rr_iterator::rr_iterator(const RegisterAggr &RG,
368      bool End)
369    : Owner(&RG) {
370  for (int U = RG.Units.find_first(); U >= 0; U = RG.Units.find_next(U)) {
371    RegisterRef R = RG.PRI.getRefForUnit(U);
372    Masks[R.Reg] |= R.Mask;
373  }
374  Pos = End ? Masks.end() : Masks.begin();
375  Index = End ? Masks.size() : 0;
376}
377
/// Stream insertion for RegisterAggr: forwards to RegisterAggr::print and
/// returns \p OS so that insertions can be chained.
raw_ostream &rdf::operator<<(raw_ostream &OS, const RegisterAggr &A) {
  A.print(OS);
  return OS;
}
382